NFVBENCH-27 Search vm image under project folder
[nfvbench.git] / nfvbench / connection.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16
17 """High level ssh library.
18 Usage examples:
19 Execute command and get output:
20     ssh = sshclient.SSH('root', 'example.com', port=33)
21     status, stdout, stderr = ssh.execute('ps ax')
22     if status:
23         raise Exception('Command failed with non-zero status.')
24     print stdout.splitlines()
25 Execute command with huge output:
26     class PseudoFile(object):
27         def write(chunk):
28             if 'error' in chunk:
29                 email_admin(chunk)
30     ssh = sshclient.SSH('root', 'example.com')
31     ssh.run('tail -f /var/log/syslog', stdout=PseudoFile(), timeout=False)
32 Execute local script on remote side:
33     ssh = sshclient.SSH('user', 'example.com')
34     status, out, err = ssh.execute('/bin/sh -s arg1 arg2',
35                                    stdin=open('~/myscript.sh', 'r'))
36 Upload file:
37     ssh = sshclient.SSH('user', 'example.com')
38     ssh.run('cat > ~/upload/file.gz', stdin=open('/store/file.gz', 'rb'))
39 Eventlet:
40     eventlet.monkey_patch(select=True, time=True)
41     or
42     eventlet.monkey_patch()
43     or
44     sshclient = eventlet.import_patched("opentstack.common.sshclient")
45 """
46
47 import re
48 import select
49 import shlex
50 import socket
51 import StringIO
52 import subprocess
53 import sys
54 import threading
55 import time
56
57 from log import LOG
58 import paramiko
59
60 # from rally.openstack.common.gettextutils import _
61
62
63 class ConnectionError(Exception):
64     pass
65
66
67 class Connection(object):
68
69     '''
70     A base connection class. Not intended to be constructed.
71     '''
72
73     def __init__(self):
74         self.distro_id = None
75         self.distro_id_like = None
76         self.distro_version = None
77         self.__get_distro()
78
79     def close(self):
80         pass
81
82     def execute(self, cmd, stdin=None, timeout=3600):
83         pass
84
85     def __extract_property(self, name, input_str):
86         expr = name + r'="?([\w\.]*)"?'
87         match = re.search(expr, input_str)
88         if match:
89             return match.group(1)
90         return 'Unknown'
91
92     # Get the linux distro
93     def __get_distro(self):
94         '''cat /etc/*-release | grep ID
95         Ubuntu:
96             DISTRIB_ID=Ubuntu
97             ID=ubuntu
98             ID_LIKE=debian
99             VERSION_ID="14.04"
100         RHEL:
101             ID="rhel"
102             ID_LIKE="fedora"
103             VERSION_ID="7.0"
104         '''
105         distro_cmd = "grep ID /etc/*-release"
106         (status, distro_out, _) = self.execute(distro_cmd)
107         if status:
108             distro_out = ''
109         self.distro_id = self.__extract_property('ID', distro_out)
110         self.distro_id_like = self.__extract_property('ID_LIKE', distro_out)
111         self.distro_version = self.__extract_property('VERSION_ID', distro_out)
112
113     def pidof(self, proc_name):
114         '''
115         Return a list containing the pids of all processes of a given name
116         the list is empty if there is no pid
117         '''
118         # the path update is necessary for RHEL
119         cmd = "PATH=$PATH:/usr/sbin pidof " + proc_name
120         (status, cmd_output, _) = self.execute(cmd)
121         if status:
122             return []
123         cmd_output = cmd_output.strip()
124         result = cmd_output.split()
125         return result
126
127     # kill pids in the given list of pids
128     def kill_proc(self, pid_list):
129         cmd = "kill -9 " + ' '.join(pid_list)
130         self.execute(cmd)
131
132     # check stats for a given path
133     def stat(self, path):
134         (status, cmd_output, _) = self.execute('stat ' + path)
135         if status:
136             return None
137         return cmd_output
138
139     def ping_check(self, target_ip, ping_count=2, pass_threshold=80):
140         '''helper function to ping from one host to an IP address,
141             for a given count and pass_threshold;
142            Steps:
143             ssh to the host and then ping to the target IP
144             then match the output and verify that the loss% is
145             less than the pass_threshold%
146             Return 1 if the criteria passes
147             Return 0, if it fails
148         '''
149         cmd = "ping -c " + str(ping_count) + " " + str(target_ip)
150         (_, cmd_output, _) = self.execute(cmd)
151
152         match = re.search(r'(\d*)% packet loss', cmd_output)
153         pkt_loss = match.group(1)
154         if int(pkt_loss) < int(pass_threshold):
155             return 1
156         else:
157             LOG.error('Ping to %s failed: %s', target_ip, cmd_output)
158             return 0
159
160     def read_remote_file(self, from_path):
161         '''
162         Read a remote file and save it to a buffer.
163         '''
164         cmd = "cat " + from_path
165         (status, cmd_output, _) = self.execute(cmd)
166         if status:
167             return None
168         return cmd_output
169
170     def get_host_os_version(self):
171         '''
172         Identify the host distribution/relase.
173         '''
174         os_release_file = "/etc/os-release"
175         sys_release_file = "/etc/system-release"
176         name = ""
177         version = ""
178
179         if self.stat(os_release_file):
180             data = self.read_remote_file(os_release_file)
181             if data is None:
182                 LOG.error("Failed to read file %s", os_release_file)
183                 return None
184
185             for line in data.splitlines():
186                 mobj = re.match(r'NAME=(.*)', line)
187                 if mobj:
188                     name = mobj.group(1).strip("\"")
189
190                 mobj = re.match(r'VERSION_ID=(.*)', line)
191                 if mobj:
192                     version = mobj.group(1).strip("\"")
193
194             os_name = name + " " + version
195             return os_name
196
197         if self.stat(sys_release_file):
198             data = self.read_remote_file(sys_release_file)
199             if data is None:
200                 LOG.error("Failed to read file %s", sys_release_file)
201                 return None
202
203             for line in data.splitlines():
204                 mobj = re.match(r'Red Hat.*', line)
205                 if mobj:
206                     return mobj.group(0)
207
208         return None
209
210     def check_rpm_package_installed(self, rpm_pkg):
211         '''
212         Given a host and a package name, check if it is installed on the
213         system.
214         '''
215         check_pkg_cmd = "rpm -qa | grep " + rpm_pkg
216
217         (status, cmd_output, _) = self.execute(check_pkg_cmd)
218         if status:
219             return None
220
221         pkg_pattern = ".*" + rpm_pkg + ".*"
222         rpm_pattern = re.compile(pkg_pattern, re.IGNORECASE)
223
224         for line in cmd_output.splitlines():
225             mobj = rpm_pattern.match(line)
226             if mobj:
227                 return mobj.group(0)
228
229         LOG.info("%s pkg installed ", rpm_pkg)
230
231         return None
232
233     def get_openstack_release(self, ver_str):
234         '''
235         Get the release series name from the package version
236         Refer to here for release tables:
237         https://wiki.openstack.org/wiki/Releases
238         '''
239         ver_table = {"2015.1": "Kilo",
240                      "2014.2": "Juno",
241                      "2014.1": "Icehouse",
242                      "2013.2": "Havana",
243                      "2013.1": "Grizzly",
244                      "2012.2": "Folsom",
245                      "2012.1": "Essex",
246                      "2011.3": "Diablo",
247                      "2011.2": "Cactus",
248                      "2011.1": "Bexar",
249                      "2010.1": "Austin"}
250
251         ver_prefix = re.search(r"20\d\d\.\d", ver_str).group(0)
252         if ver_prefix in ver_table:
253             return ver_table[ver_prefix]
254         else:
255             return "Unknown"
256
257     def check_openstack_version(self):
258         '''
259         Identify the openstack version running on the controller.
260         '''
261         nova_cmd = "nova-manage --version"
262         (status, _, err_output) = self.execute(nova_cmd)
263
264         if status:
265             return "Unknown"
266
267         ver_str = err_output.strip()
268         release_str = self.get_openstack_release(err_output)
269         return release_str + " (" + ver_str + ")"
270
271     def get_cpu_info(self):
272         '''
273         Get the CPU info of the controller.
274         Note: Here we are assuming the controller node has the exact
275               hardware as the compute nodes.
276         '''
277
278         cmd = 'cat /proc/cpuinfo | grep -m1 "model name"'
279         (status, std_output, _) = self.execute(cmd)
280         if status:
281             return "Unknown"
282         model_name = re.search(r":\s(.*)", std_output).group(1)
283
284         cmd = 'cat /proc/cpuinfo | grep "model name" | wc -l'
285         (status, std_output, _) = self.execute(cmd)
286         if status:
287             return "Unknown"
288         cores = std_output.strip()
289
290         return (cores + " * " + model_name)
291
292     def get_nic_name(self, agent_type, encap, internal_iface_dict):
293         '''
294         Get the NIC info of the controller.
295         Note: Here we are assuming the controller node has the exact
296               hardware as the compute nodes.
297         '''
298
299         # The internal_ifac_dict is a dictionary contains the mapping between
300         # hostname and the internal interface name like below:
301         # {u'hh23-4': u'eth1', u'hh23-5': u'eth1', u'hh23-6': u'eth1'}
302
303         cmd = "hostname"
304         (status, std_output, _) = self.execute(cmd)
305         if status:
306             return "Unknown"
307         hostname = std_output.strip()
308
309         if hostname in internal_iface_dict:
310             iface = internal_iface_dict[hostname]
311         else:
312             return "Unknown"
313
314         # Figure out which interface is for internal traffic
315         if 'Linux bridge' in agent_type:
316             ifname = iface
317         elif 'Open vSwitch' in agent_type:
318             if encap == 'vlan':
319                 # [root@hh23-10 ~]# ovs-vsctl list-ports br-inst
320                 # eth1
321                 # phy-br-inst
322                 cmd = 'ovs-vsctl list-ports ' + \
323                     iface + ' | grep -E "^[^phy].*"'
324                 (status, std_output, _) = self.execute(cmd)
325                 if status:
326                     return "Unknown"
327                 ifname = std_output.strip()
328             elif encap == 'vxlan' or encap == 'gre':
329                 # This is complicated. We need to first get the local IP address on
330                 # br-tun, then do a reverse lookup to get the physical interface.
331                 #
332                 # [root@hh23-4 ~]# ip addr show to "23.23.2.14"
333                 # 3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP qlen 1000
334                 #    inet 23.23.2.14/24 brd 23.23.2.255 scope global eth1
335                 #       valid_lft forever preferred_lft forever
336                 cmd = "ip addr show to " + iface + " | awk -F: '{print $2}'"
337                 (status, std_output, _) = self.execute(cmd)
338                 if status:
339                     return "Unknown"
340                 ifname = std_output.strip()
341         else:
342             return "Unknown"
343
344         cmd = 'ethtool -i ' + ifname + ' | grep bus-info'
345         (status, std_output, _) = self.execute(cmd)
346         if status:
347             return "Unknown"
348         bus_info = re.search(r":\s(.*)", std_output).group(1)
349
350         cmd = 'lspci -s ' + bus_info
351         (status, std_output, _) = self.execute(cmd)
352         if status:
353             return "Unknown"
354         nic_name = re.search(
355             r"Ethernet controller:\s(.*)",
356             std_output).group(1)
357
358         return (nic_name)
359
360     def get_l2agent_version(self, agent_type):
361         '''
362         Get the L2 agent version of the controller.
363         Note: Here we are assuming the controller node has the exact
364               hardware as the compute nodes.
365         '''
366         if 'Linux bridge' in agent_type:
367             cmd = "brctl --version | awk -F',' '{print $2}'"
368             ver_string = "Linux Bridge "
369         elif 'Open vSwitch' in agent_type:
370             cmd = "ovs-vsctl --version | awk -F')' '{print $2}'"
371             ver_string = "OVS "
372         else:
373             return "Unknown"
374
375         (status, std_output, _) = self.execute(cmd)
376         if status:
377             return "Unknown"
378
379         return ver_string + std_output.strip()
380
381
382 class SSHError(Exception):
383     pass
384
385
386 class SSHTimeout(SSHError):
387     pass
388
389 # Check IPv4 address syntax - not completely fool proof but will catch
390 # some invalid formats
391
392
393 def is_ipv4(address):
394     try:
395         socket.inet_aton(address)
396     except socket.error:
397         return False
398     return True
399
400
401 class SSHAccess(object):
402
403     '''
404     A class to contain all the information needed to access a host
405     (native or virtual) using SSH
406     '''
407
408     def __init__(self, arg_value=None):
409         '''
410             decode user@host[:pwd]
411             'hugo@1.1.1.1:secret' -> ('hugo', '1.1.1.1', 'secret', None)
412             'huggy@2.2.2.2' -> ('huggy', '2.2.2.2', None, None)
413             None ->(None, None, None, None)
414             Examples of fatal errors (will call exit):
415                 'hutch@q.1.1.1' (invalid IP)
416                 '@3.3.3.3' (missing username)
417                 'hiro@' or 'buggy' (missing host IP)
418             The error field will be None in case of success or will
419             contain a string describing the error
420         '''
421         self.username = None
422         self.host = None
423         self.password = None
424         # name of the file that contains the private key
425         self.private_key_file = None
426         # this is the private key itself (a long string starting with
427         # -----BEGIN RSA PRIVATE KEY-----
428         # used when the private key is not saved in any file
429         self.private_key = None
430         self.public_key_file = None
431         self.port = 22
432         self.error = None
433
434         if not arg_value:
435             return
436         match = re.search(r'^([^@]+)@([0-9\.]+):?(.*)$', arg_value)
437         if not match:
438             self.error = 'Invalid argument: ' + arg_value
439             return
440         if not is_ipv4(match.group(2)):
441             self.error = 'Invalid IPv4 address ' + match.group(2)
442             return
443         (self.username, self.host, self.password) = match.groups()
444
445     def copy_from(self, ssh_access):
446         self.username = ssh_access.username
447         self.host = ssh_access.host
448         self.port = ssh_access.port
449         self.password = ssh_access.password
450         self.private_key = ssh_access.private_key
451         self.public_key_file = ssh_access.public_key_file
452         self.private_key_file = ssh_access.private_key_file
453
454
455 class SSH(Connection):
456
457     """Represent ssh connection."""
458
459     def __init__(self, ssh_access,
460                  connect_timeout=60,
461                  connect_retry_count=30,
462                  connect_retry_wait_sec=2):
463         """Initialize SSH client.
464         :param user: ssh username
465         :param host: hostname or ip address of remote ssh server
466         :param port: remote ssh port
467         :param pkey: RSA or DSS private key string or file object
468         :param key_filename: private key filename
469         :param password: password
470         :param connect_timeout: timeout when connecting ssh
471         :param connect_retry_count: how many times to retry connecting
472         :param connect_retry_wait_sec: seconds to wait between retries
473         """
474
475         self.ssh_access = ssh_access
476         if ssh_access.private_key:
477             self.pkey = self._get_pkey(ssh_access.private_key)
478         else:
479             self.pkey = None
480         self._client = False
481         self.connect_timeout = connect_timeout
482         self.connect_retry_count = connect_retry_count
483         self.connect_retry_wait_sec = connect_retry_wait_sec
484         super(SSH, self).__init__()
485
486     def _get_pkey(self, key):
487         '''Get the binary form of the private key
488         from the text form
489         '''
490         if isinstance(key, basestring):
491             key = StringIO.StringIO(key)
492         errors = []
493         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
494             try:
495                 return key_class.from_private_key(key)
496             except paramiko.SSHException as exc:
497                 errors.append(exc)
498         raise SSHError('Invalid pkey: %s' % (errors))
499
500     def _is_active(self):
501         if self._client:
502             try:
503                 transport = self._client.get_transport()
504                 session = transport.open_session()
505                 session.close()
506                 return True
507             except Exception:
508                 return False
509         else:
510             return False
511
512     def _get_client(self, force=False):
513         if not force and self._is_active():
514             return self._client
515         if self._client:
516             LOG.info('Re-establishing ssh connection with %s' % (self.ssh_access.host))
517             self._client.close()
518         self._client = paramiko.SSHClient()
519         self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
520         for _ in range(self.connect_retry_count):
521             try:
522                 self._client.connect(self.ssh_access.host,
523                                      username=self.ssh_access.username,
524                                      port=self.ssh_access.port,
525                                      pkey=self.pkey,
526                                      key_filename=self.ssh_access.private_key_file,
527                                      password=self.ssh_access.password,
528                                      timeout=self.connect_timeout)
529                 self._client.get_transport().set_keepalive(5)
530                 return self._client
531             except (paramiko.AuthenticationException,
532                     paramiko.BadHostKeyException,
533                     paramiko.SSHException,
534                     socket.error,
535                     Exception):
536                 time.sleep(self.connect_retry_wait_sec)
537
538         self._client = None
539         msg = '[%s] SSH Connection failed after %s attempts' % (self.ssh_access.host,
540                                                                 self.connect_retry_count)
541         raise SSHError(msg)
542
543     def _get_session(self):
544         client = self._get_client()
545         for _ in range(self.connect_retry_count):
546             try:
547                 transport = client.get_transport()
548                 session = transport.open_session()
549                 return session
550             except Exception:
551                 client = self._get_client(force=True)
552         return None
553
554     def close(self):
555         super(SSH, self).close()
556         if self._client:
557             self._client.close()
558             self._client = False
559
560     def run(self, cmd, stdin=None, stdout=None, stderr=None,
561             raise_on_error=True, timeout=3600, sudo=False):
562         """Execute specified command on the server.
563         :param cmd:             Command to be executed.
564         :param stdin:           Open file or string to pass to stdin.
565         :param stdout:          Open file to connect to stdout.
566         :param stderr:          Open file to connect to stderr.
567         :param raise_on_error:  If False then exit code will be return. If True
568                                 then exception will be raized if non-zero code.
569         :param timeout:         Timeout in seconds for command execution.
570                                 Default 1 hour. No timeout if set to 0.
571         :param sudo:            Executes command as sudo with default password
572         """
573
574         if isinstance(stdin, basestring):
575             stdin = StringIO.StringIO(stdin)
576
577         return self._run(cmd, stdin=stdin, stdout=stdout,
578                          stderr=stderr, raise_on_error=raise_on_error,
579                          timeout=timeout, sudo=sudo)
580
581     def _run(self, cmd, stdin=None, stdout=None, stderr=None,
582              raise_on_error=True, timeout=3600, sudo=False):
583
584         session = self._get_session()
585
586         if session is None:
587             raise SSHError('Unable to open session to ssh connection')
588
589         if sudo:
590             cmd = "echo " + self.ssh_access.password + " | sudo -S -p '' " + cmd
591             session.get_pty()
592
593         session.exec_command(cmd)
594         start_time = time.time()
595
596         data_to_send = ''
597         stderr_data = None
598
599         # If we have data to be sent to stdin then `select' should also
600         # check for stdin availability.
601         if stdin and not stdin.closed:
602             writes = [session]
603         else:
604             writes = []
605
606         while True:
607             # Block until data can be read/write.
608             select.select([session], writes, [session], 1)
609
610             if session.recv_ready():
611                 data = session.recv(4096)
612                 if stdout is not None:
613                     stdout.write(data)
614                 continue
615
616             if session.recv_stderr_ready():
617                 stderr_data = session.recv_stderr(4096)
618                 if stderr is not None:
619                     stderr.write(stderr_data)
620                 continue
621
622             if session.send_ready():
623                 if stdin is not None and not stdin.closed:
624                     if not data_to_send:
625                         data_to_send = stdin.read(4096)
626                         if not data_to_send:
627                             stdin.close()
628                             session.shutdown_write()
629                             writes = []
630                             continue
631                     sent_bytes = session.send(data_to_send)
632                     data_to_send = data_to_send[sent_bytes:]
633
634             if session.exit_status_ready():
635                 break
636
637             if timeout and (time.time() - timeout) > start_time:
638                 args = {'cmd': cmd, 'host': self.ssh_access.host}
639                 raise SSHTimeout(('Timeout executing command '
640                                   '"%(cmd)s" on host %(host)s') % args)
641             # if e:
642             #    raise SSHError('Socket error.')
643
644         exit_status = session.recv_exit_status()
645         if 0 != exit_status and raise_on_error:
646             fmt = ('Command "%(cmd)s" failed with exit_status %(status)d.')
647             details = fmt % {'cmd': cmd, 'status': exit_status}
648             if stderr_data:
649                 details += (' Last stderr data: "%s".') % stderr_data
650             raise SSHError(details)
651         return exit_status
652
653     def execute(self, cmd, stdin=None, timeout=3600, sudo=False):
654         """Execute the specified command on the server.
655         :param cmd:     Command to be executed.
656         :param stdin:   Open file to be sent on process stdin.
657         :param timeout: Timeout for execution of the command.
658         Return tuple (exit_status, stdout, stderr)
659         """
660         stdout = StringIO.StringIO()
661         stderr = StringIO.StringIO()
662
663         exit_status = self.run(cmd, stderr=stderr,
664                                stdout=stdout, stdin=stdin,
665                                timeout=timeout, raise_on_error=False, sudo=sudo)
666         stdout.seek(0)
667         stderr.seek(0)
668         return (exit_status, stdout.read(), stderr.read())
669
670     def wait(self, timeout=120, interval=1):
671         """Wait for the host will be available via ssh."""
672         start_time = time.time()
673         while True:
674             try:
675                 return self.execute('uname')
676             except (socket.error, SSHError):
677                 time.sleep(interval)
678             if time.time() > (start_time + timeout):
679                 raise SSHTimeout(
680                     ('Timeout waiting for "%s"') %
681                     self.ssh_access.host)
682
683
684 class SubprocessTimeout(Exception):
685     pass
686
687
688 class Subprocess(Connection):
689
690     """Represent subprocess connection."""
691
692     def execute(self, cmd, stdin=None, timeout=3600):
693         process = subprocess.Popen(shlex.split(cmd), stderr=subprocess.PIPE,
694                                    stdout=subprocess.PIPE,
695                                    shell=True)
696         timer = threading.Timer(timeout, process.kill)
697         stdout, stderr = process.communicate(input=stdin)
698         status = process.wait()
699         if timer.is_alive():
700             timer.cancel()
701             raise SubprocessTimeout('Timeout executing command "%(cmd)s"')
702         return (status, stdout, stderr)
703
704
705 ##################################################
706 # Only invoke the module directly for test purposes. Should be
707 # invoked from pns script.
708 ##################################################
709 def main():
710     # As argument pass the SSH access string, e.g. "localadmin@1.1.1.1:secret"
711     test_ssh = SSH(SSHAccess(sys.argv[1]))
712
713     print 'ID=' + test_ssh.distro_id
714     print 'ID_LIKE=' + test_ssh.distro_id_like
715     print 'VERSION_ID=' + test_ssh.distro_version
716
717     # ssh.wait()
718     # print ssh.pidof('bash')
719     # print ssh.stat('/tmp')
720     print test_ssh.check_openstack_version()
721     print test_ssh.get_cpu_info()
722     print test_ssh.get_l2agent_version("Open vSwitch agent")
723
724 if __name__ == "__main__":
725     main()