Merge "Unify Firewall testcases TG and VNF names"
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import io
66 import logging
67 import os
68 import re
69 import select
70 import socket
71 import time
72
73 import paramiko
74 from chainmap import ChainMap
75 from oslo_utils import encodeutils
76 from scp import SCPClient
77 import six
78
79 from yardstick.common import exceptions
80 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
81 from yardstick.network_services.utils import provision_tool
82
83
84 def convert_key_to_str(key):
85     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
86         return key
87     k = io.StringIO()
88     key.write_private_key(k)
89     return k.getvalue()
90
91
92 class SSH(object):
93     """Represent ssh connection."""
94
95     SSH_PORT = paramiko.config.SSH_PORT
96     DEFAULT_WAIT_TIMEOUT = 120
97
98     @staticmethod
99     def gen_keys(key_filename, bit_count=2048):
100         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
101         rsa_key.write_private_key_file(key_filename)
102         print("Writing %s ..." % key_filename)
103         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
104             pubkey_file.write(rsa_key.get_name())
105             pubkey_file.write(' ')
106             pubkey_file.write(rsa_key.get_base64())
107             pubkey_file.write('\n')
108
109     @staticmethod
110     def get_class():
111         # must return static class name, anything else refers to the calling class
112         # i.e. the subclass, not the superclass
113         return SSH
114
115     @classmethod
116     def get_arg_key_map(cls):
117         return {
118             'user': ('user', NON_NONE_DEFAULT),
119             'host': ('ip', NON_NONE_DEFAULT),
120             'port': ('ssh_port', cls.SSH_PORT),
121             'pkey': ('pkey', None),
122             'key_filename': ('key_filename', None),
123             'password': ('password', None),
124             'name': ('name', None),
125         }
126
127     def __init__(self, user, host, port=None, pkey=None,
128                  key_filename=None, password=None, name=None):
129         """Initialize SSH client.
130
131         :param user: ssh username
132         :param host: hostname or ip address of remote ssh server
133         :param port: remote ssh port
134         :param pkey: RSA or DSS private key string or file object
135         :param key_filename: private key filename
136         :param password: password
137         """
138         self.name = name
139         if name:
140             self.log = logging.getLogger(__name__ + '.' + self.name)
141         else:
142             self.log = logging.getLogger(__name__)
143
144         self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
145         self.user = user
146         self.host = host
147         # everybody wants to debug this in the caller, do it here instead
148         self.log.debug("user:%s host:%s", user, host)
149
150         # we may get text port from YAML, convert to int
151         self.port = try_int(port, self.SSH_PORT)
152         self.pkey = self._get_pkey(pkey) if pkey else None
153         self.password = password
154         self.key_filename = key_filename
155         self._client = False
156         # paramiko loglevel debug will output ssh protocl debug
157         # we don't ever really want that unless we are debugging paramiko
158         # ssh issues
159         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
160             logging.getLogger("paramiko").setLevel(logging.DEBUG)
161         else:
162             logging.getLogger("paramiko").setLevel(logging.WARN)
163
164     @classmethod
165     def args_from_node(cls, node, overrides=None, defaults=None):
166         if overrides is None:
167             overrides = {}
168         if defaults is None:
169             defaults = {}
170
171         params = ChainMap(overrides, node, defaults)
172         return make_dict_from_map(params, cls.get_arg_key_map())
173
174     @classmethod
175     def from_node(cls, node, overrides=None, defaults=None):
176         return cls(**cls.args_from_node(node, overrides, defaults))
177
178     def _get_pkey(self, key):
179         if isinstance(key, six.string_types):
180             key = six.moves.StringIO(key)
181         errors = []
182         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
183             try:
184                 return key_class.from_private_key(key)
185             except paramiko.SSHException as e:
186                 errors.append(e)
187         raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
188
189     @property
190     def is_connected(self):
191         return bool(self._client)
192
193     def _get_client(self):
194         if self.is_connected:
195             return self._client
196         try:
197             self._client = paramiko.SSHClient()
198             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
199             self._client.connect(self.host, username=self.user,
200                                  port=self.port, pkey=self.pkey,
201                                  key_filename=self.key_filename,
202                                  password=self.password,
203                                  allow_agent=False, look_for_keys=False,
204                                  timeout=1)
205             return self._client
206         except Exception as e:
207             message = ("Exception %(exception_type)s was raised "
208                        "during connect. Exception value is: %(exception)r" %
209                        {"exception": e, "exception_type": type(e)})
210             self._client = False
211             raise exceptions.SSHError(error_msg=message)
212
213     def _make_dict(self):
214         return {
215             'user': self.user,
216             'host': self.host,
217             'port': self.port,
218             'pkey': self.pkey,
219             'key_filename': self.key_filename,
220             'password': self.password,
221             'name': self.name,
222         }
223
224     def copy(self):
225         return self.get_class()(**self._make_dict())
226
227     def close(self):
228         if self._client:
229             self._client.close()
230             self._client = False
231
232     def run(self, cmd, stdin=None, stdout=None, stderr=None,
233             raise_on_error=True, timeout=3600,
234             keep_stdin_open=False, pty=False):
235         """Execute specified command on the server.
236
237         :param cmd:             Command to be executed.
238         :type cmd:              str
239         :param stdin:           Open file or string to pass to stdin.
240         :param stdout:          Open file to connect to stdout.
241         :param stderr:          Open file to connect to stderr.
242         :param raise_on_error:  If False then exit code will be return. If True
243                                 then exception will be raized if non-zero code.
244         :param timeout:         Timeout in seconds for command execution.
245                                 Default 1 hour. No timeout if set to 0.
246         :param keep_stdin_open: don't close stdin on empty reads
247         :type keep_stdin_open:  bool
248         :param pty:             Request a pseudo terminal for this connection.
249                                 This allows passing control characters.
250                                 Default False.
251         :type pty:              bool
252         """
253
254         client = self._get_client()
255
256         if isinstance(stdin, six.string_types):
257             stdin = six.moves.StringIO(stdin)
258
259         return self._run(client, cmd, stdin=stdin, stdout=stdout,
260                          stderr=stderr, raise_on_error=raise_on_error,
261                          timeout=timeout,
262                          keep_stdin_open=keep_stdin_open, pty=pty)
263
264     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
265              raise_on_error=True, timeout=3600,
266              keep_stdin_open=False, pty=False):
267
268         transport = client.get_transport()
269         session = transport.open_session()
270         if pty:
271             session.get_pty()
272         session.exec_command(cmd)
273         start_time = time.time()
274
275         # encode on transmit, decode on receive
276         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
277         stderr_data = None
278
279         # If we have data to be sent to stdin then `select' should also
280         # check for stdin availability.
281         if stdin and not stdin.closed:
282             writes = [session]
283         else:
284             writes = []
285
286         while True:
287             # Block until data can be read/write.
288             e = select.select([session], writes, [session], 1)[2]
289
290             if session.recv_ready():
291                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
292                 self.log.debug("stdout: %r", data)
293                 if stdout is not None:
294                     stdout.write(data)
295                 continue
296
297             if session.recv_stderr_ready():
298                 stderr_data = encodeutils.safe_decode(
299                     session.recv_stderr(4096), 'utf-8')
300                 self.log.debug("stderr: %r", stderr_data)
301                 if stderr is not None:
302                     stderr.write(stderr_data)
303                 continue
304
305             if session.send_ready():
306                 if stdin is not None and not stdin.closed:
307                     if not data_to_send:
308                         stdin_txt = stdin.read(4096)
309                         if stdin_txt is None:
310                             stdin_txt = ''
311                         data_to_send = encodeutils.safe_encode(
312                             stdin_txt, incoming='utf-8')
313                         if not data_to_send:
314                             # we may need to keep stdin open
315                             if not keep_stdin_open:
316                                 stdin.close()
317                                 session.shutdown_write()
318                                 writes = []
319                     if data_to_send:
320                         sent_bytes = session.send(data_to_send)
321                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
322                         data_to_send = data_to_send[sent_bytes:]
323
324             if session.exit_status_ready():
325                 break
326
327             if timeout and (time.time() - timeout) > start_time:
328                 message = ('Timeout executing command %(cmd)s on host %(host)s'
329                            % {"cmd": cmd, "host": self.host})
330                 raise exceptions.SSHTimeout(error_msg=message)
331             if e:
332                 raise exceptions.SSHError(error_msg='Socket error')
333
334         exit_status = session.recv_exit_status()
335         if exit_status != 0 and raise_on_error:
336             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
337             details = fmt % {"cmd": cmd, "status": exit_status}
338             if stderr_data:
339                 details += " Last stderr data: '%s'." % stderr_data
340             raise exceptions.SSHError(error_msg=details)
341         return exit_status
342
343     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
344         """Execute the specified command on the server.
345
346         :param cmd: (str)             Command to be executed.
347         :param stdin: (StringIO)      Open file to be sent on process stdin.
348         :param timeout: (int)         Timeout for execution of the command.
349         :param raise_on_error: (bool) If True, then an SSHError will be raised
350                                       when non-zero exit code.
351
352         :returns: tuple (exit_status, stdout, stderr)
353         """
354         stdout = six.moves.StringIO()
355         stderr = six.moves.StringIO()
356
357         exit_status = self.run(cmd, stderr=stderr,
358                                stdout=stdout, stdin=stdin,
359                                timeout=timeout, raise_on_error=raise_on_error)
360         stdout.seek(0)
361         stderr.seek(0)
362         return exit_status, stdout.read(), stderr.read()
363
364     def wait(self, timeout=None, interval=1):
365         """Wait for the host will be available via ssh."""
366         if timeout is None:
367             timeout = self.wait_timeout
368
369         end_time = time.time() + timeout
370         while True:
371             try:
372                 return self.execute("uname")
373             except (socket.error, exceptions.SSHError) as e:
374                 self.log.debug("Ssh is still unavailable: %r", e)
375                 time.sleep(interval)
376             if time.time() > end_time:
377                 raise exceptions.SSHTimeout(
378                     error_msg='Timeout waiting for "%s"' % self.host)
379
380     def put(self, files, remote_path=b'.', recursive=False):
381         client = self._get_client()
382
383         with SCPClient(client.get_transport()) as scp:
384             scp.put(files, remote_path, recursive)
385
386     def get(self, remote_path, local_path='/tmp/', recursive=True):
387         client = self._get_client()
388
389         with SCPClient(client.get_transport()) as scp:
390             scp.get(remote_path, local_path, recursive)
391
392     # keep shell running in the background, e.g. screen
393     def send_command(self, command):
394         client = self._get_client()
395         client.exec_command(command, get_pty=True)
396
397     def _put_file_sftp(self, localpath, remotepath, mode=None):
398         client = self._get_client()
399
400         with client.open_sftp() as sftp:
401             sftp.put(localpath, remotepath)
402             if mode is None:
403                 mode = 0o777 & os.stat(localpath).st_mode
404             sftp.chmod(remotepath, mode)
405
406     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
407
408     def _put_file_shell(self, localpath, remotepath, mode=None):
409         # quote to stop wordpslit
410         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
411         if not tilde:
412             tilde = ''
413         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
414         if mode is not None:
415             # use -- so no options
416             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
417
418         with open(localpath, "rb") as localfile:
419             # only chmod on successful cat
420             self.run("&& ".join(cmd), stdin=localfile)
421
422     def put_file(self, localpath, remotepath, mode=None):
423         """Copy specified local file to the server.
424
425         :param localpath:   Local filename.
426         :param remotepath:  Remote filename.
427         :param mode:        Permissions to set after upload
428         """
429         try:
430             self._put_file_sftp(localpath, remotepath, mode=mode)
431         except (paramiko.SSHException, socket.error):
432             self._put_file_shell(localpath, remotepath, mode=mode)
433
434     def provision_tool(self, tool_path, tool_file=None):
435         return provision_tool(self, tool_path, tool_file)
436
437     def put_file_obj(self, file_obj, remotepath, mode=None):
438         client = self._get_client()
439
440         with client.open_sftp() as sftp:
441             sftp.putfo(file_obj, remotepath)
442             if mode is not None:
443                 sftp.chmod(remotepath, mode)
444
445     def get_file_obj(self, remotepath, file_obj):
446         client = self._get_client()
447
448         with client.open_sftp() as sftp:
449             sftp.getfo(remotepath, file_obj)
450
451     def interactive_terminal_open(self, time_out=45):
452         """Open interactive terminal on a SSH channel.
453
454         :param time_out: Timeout in seconds.
455         :returns: SSH channel with opened terminal.
456
457         .. warning:: Interruptingcow is used here, and it uses
458            signal(SIGALRM) to let the operating system interrupt program
459            execution. This has the following limitations: Python signal
460            handlers only apply to the main thread, so you cannot use this
461            from other threads. You must not use this in a program that
462            uses SIGALRM itself (this includes certain profilers)
463         """
464         chan = self._get_client().get_transport().open_session()
465         chan.get_pty()
466         chan.invoke_shell()
467         chan.settimeout(int(time_out))
468         chan.set_combine_stderr(True)
469
470         buf = ''
471         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
472             try:
473                 chunk = chan.recv(10 * 1024 * 1024)
474                 if not chunk:
475                     break
476                 buf += chunk
477                 if chan.exit_status_ready():
478                     self.log.error('Channel exit status ready')
479                     break
480             except socket.timeout:
481                 raise exceptions.SSHTimeout(error_msg='Socket timeout: %s' % buf)
482         return chan
483
484     def interactive_terminal_exec_command(self, chan, cmd, prompt):
485         """Execute command on interactive terminal.
486
487         interactive_terminal_open() method has to be called first!
488
489         :param chan: SSH channel with opened terminal.
490         :param cmd: Command to be executed.
491         :param prompt: Command prompt, sequence of characters used to
492         indicate readiness to accept commands.
493         :returns: Command output.
494
495         .. warning:: Interruptingcow is used here, and it uses
496            signal(SIGALRM) to let the operating system interrupt program
497            execution. This has the following limitations: Python signal
498            handlers only apply to the main thread, so you cannot use this
499            from other threads. You must not use this in a program that
500            uses SIGALRM itself (this includes certain profilers)
501         """
502         chan.sendall('{c}\n'.format(c=cmd))
503         buf = ''
504         while not buf.endswith(prompt):
505             try:
506                 chunk = chan.recv(10 * 1024 * 1024)
507                 if not chunk:
508                     break
509                 buf += chunk
510                 if chan.exit_status_ready():
511                     self.log.error('Channel exit status ready')
512                     break
513             except socket.timeout:
514                 message = ("Socket timeout during execution of command: "
515                            "%(cmd)s\nBuffer content:\n%(buf)s" % {"cmd": cmd,
516                                                                   "buf": buf})
517                 raise exceptions.SSHTimeout(error_msg=message)
518         tmp = buf.replace(cmd.replace('\n', ''), '')
519         for item in prompt:
520             tmp.replace(item, '')
521         return tmp
522
523     @staticmethod
524     def interactive_terminal_close(chan):
525         """Close interactive terminal SSH channel.
526
527         :param: chan: SSH channel to be closed.
528         """
529         chan.close()
530
531
532 class AutoConnectSSH(SSH):
533
534     @classmethod
535     def get_arg_key_map(cls):
536         arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
537         arg_key_map['wait'] = ('wait', True)
538         return arg_key_map
539
540     # always wait or we will get OpenStack SSH errors
541     def __init__(self, user, host, port=None, pkey=None,
542                  key_filename=None, password=None, name=None, wait=True):
543         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
544         if wait and wait is not True:
545             self.wait_timeout = int(wait)
546
547     def _make_dict(self):
548         data = super(AutoConnectSSH, self)._make_dict()
549         data.update({
550             'wait': self.wait_timeout
551         })
552         return data
553
554     def _connect(self):
555         if not self.is_connected:
556             interval = 1
557             timeout = self.wait_timeout
558
559             end_time = time.time() + timeout
560             while True:
561                 try:
562                     return self._get_client()
563                 except (socket.error, exceptions.SSHError) as e:
564                     self.log.debug("Ssh is still unavailable: %r", e)
565                     time.sleep(interval)
566                 if time.time() > end_time:
567                     raise exceptions.SSHTimeout(
568                         error_msg='Timeout waiting for "%s"' % self.host)
569
570     def drop_connection(self):
571         """ Don't close anything, just force creation of a new client """
572         self._client = False
573
574     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
575         self._connect()
576         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
577                                                    raise_on_error)
578
579     def run(self, cmd, stdin=None, stdout=None, stderr=None,
580             raise_on_error=True, timeout=3600,
581             keep_stdin_open=False, pty=False):
582         self._connect()
583         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
584                                                timeout, keep_stdin_open, pty)
585
586     def put(self, files, remote_path=b'.', recursive=False):
587         self._connect()
588         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
589
590     def put_file(self, local_path, remote_path, mode=None):
591         self._connect()
592         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
593
594     def put_file_obj(self, file_obj, remote_path, mode=None):
595         self._connect()
596         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
597
598     def get_file_obj(self, remote_path, file_obj):
599         self._connect()
600         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
601
602     def provision_tool(self, tool_path, tool_file=None):
603         self._connect()
604         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
605
606     @staticmethod
607     def get_class():
608         # must return static class name, anything else refers to the calling class
609         # i.e. the subclass, not the superclass
610         return AutoConnectSSH