Merge "Implement Virtual Switch resilience test case"
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 from __future__ import absolute_import
66 import os
67 import io
68 import select
69 import socket
70 import time
71 import re
72
73 import logging
74
75 import paramiko
76 from chainmap import ChainMap
77 from oslo_utils import encodeutils
78 from scp import SCPClient
79 import six
80
81 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
82 from yardstick.network_services.utils import provision_tool
83
84
85 def convert_key_to_str(key):
86     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
87         return key
88     k = io.StringIO()
89     key.write_private_key(k)
90     return k.getvalue()
91
92
93 class SSHError(Exception):
94     pass
95
96
97 class SSHTimeout(SSHError):
98     pass
99
100
101 class SSH(object):
102     """Represent ssh connection."""
103
104     SSH_PORT = paramiko.config.SSH_PORT
105     DEFAULT_WAIT_TIMEOUT = 120
106
107     @staticmethod
108     def gen_keys(key_filename, bit_count=2048):
109         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
110         rsa_key.write_private_key_file(key_filename)
111         print("Writing %s ..." % key_filename)
112         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
113             pubkey_file.write(rsa_key.get_name())
114             pubkey_file.write(' ')
115             pubkey_file.write(rsa_key.get_base64())
116             pubkey_file.write('\n')
117
118     @staticmethod
119     def get_class():
120         # must return static class name, anything else refers to the calling class
121         # i.e. the subclass, not the superclass
122         return SSH
123
124     @classmethod
125     def get_arg_key_map(cls):
126         return {
127             'user': ('user', NON_NONE_DEFAULT),
128             'host': ('ip', NON_NONE_DEFAULT),
129             'port': ('ssh_port', cls.SSH_PORT),
130             'pkey': ('pkey', None),
131             'key_filename': ('key_filename', None),
132             'password': ('password', None),
133             'name': ('name', None),
134         }
135
136     def __init__(self, user, host, port=None, pkey=None,
137                  key_filename=None, password=None, name=None):
138         """Initialize SSH client.
139
140         :param user: ssh username
141         :param host: hostname or ip address of remote ssh server
142         :param port: remote ssh port
143         :param pkey: RSA or DSS private key string or file object
144         :param key_filename: private key filename
145         :param password: password
146         """
147         self.name = name
148         if name:
149             self.log = logging.getLogger(__name__ + '.' + self.name)
150         else:
151             self.log = logging.getLogger(__name__)
152
153         self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
154         self.user = user
155         self.host = host
156         # everybody wants to debug this in the caller, do it here instead
157         self.log.debug("user:%s host:%s", user, host)
158
159         # we may get text port from YAML, convert to int
160         self.port = try_int(port, self.SSH_PORT)
161         self.pkey = self._get_pkey(pkey) if pkey else None
162         self.password = password
163         self.key_filename = key_filename
164         self._client = False
165         # paramiko loglevel debug will output ssh protocl debug
166         # we don't ever really want that unless we are debugging paramiko
167         # ssh issues
168         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
169             logging.getLogger("paramiko").setLevel(logging.DEBUG)
170         else:
171             logging.getLogger("paramiko").setLevel(logging.WARN)
172
173     @classmethod
174     def args_from_node(cls, node, overrides=None, defaults=None):
175         if overrides is None:
176             overrides = {}
177         if defaults is None:
178             defaults = {}
179
180         params = ChainMap(overrides, node, defaults)
181         return make_dict_from_map(params, cls.get_arg_key_map())
182
183     @classmethod
184     def from_node(cls, node, overrides=None, defaults=None):
185         return cls(**cls.args_from_node(node, overrides, defaults))
186
187     def _get_pkey(self, key):
188         if isinstance(key, six.string_types):
189             key = six.moves.StringIO(key)
190         errors = []
191         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
192             try:
193                 return key_class.from_private_key(key)
194             except paramiko.SSHException as e:
195                 errors.append(e)
196         raise SSHError("Invalid pkey: %s" % errors)
197
198     @property
199     def is_connected(self):
200         return bool(self._client)
201
202     def _get_client(self):
203         if self.is_connected:
204             return self._client
205         try:
206             self._client = paramiko.SSHClient()
207             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
208             self._client.connect(self.host, username=self.user,
209                                  port=self.port, pkey=self.pkey,
210                                  key_filename=self.key_filename,
211                                  password=self.password,
212                                  allow_agent=False, look_for_keys=False,
213                                  timeout=1)
214             return self._client
215         except Exception as e:
216             message = ("Exception %(exception_type)s was raised "
217                        "during connect. Exception value is: %(exception)r")
218             self._client = False
219             raise SSHError(message % {"exception": e,
220                                       "exception_type": type(e)})
221
222     def _make_dict(self):
223         return {
224             'user': self.user,
225             'host': self.host,
226             'port': self.port,
227             'pkey': self.pkey,
228             'key_filename': self.key_filename,
229             'password': self.password,
230             'name': self.name,
231         }
232
233     def copy(self):
234         return self.get_class()(**self._make_dict())
235
236     def close(self):
237         if self._client:
238             self._client.close()
239             self._client = False
240
241     def run(self, cmd, stdin=None, stdout=None, stderr=None,
242             raise_on_error=True, timeout=3600,
243             keep_stdin_open=False, pty=False):
244         """Execute specified command on the server.
245
246         :param cmd:             Command to be executed.
247         :type cmd:              str
248         :param stdin:           Open file or string to pass to stdin.
249         :param stdout:          Open file to connect to stdout.
250         :param stderr:          Open file to connect to stderr.
251         :param raise_on_error:  If False then exit code will be return. If True
252                                 then exception will be raized if non-zero code.
253         :param timeout:         Timeout in seconds for command execution.
254                                 Default 1 hour. No timeout if set to 0.
255         :param keep_stdin_open: don't close stdin on empty reads
256         :type keep_stdin_open:  bool
257         :param pty:             Request a pseudo terminal for this connection.
258                                 This allows passing control characters.
259                                 Default False.
260         :type pty:              bool
261         """
262
263         client = self._get_client()
264
265         if isinstance(stdin, six.string_types):
266             stdin = six.moves.StringIO(stdin)
267
268         return self._run(client, cmd, stdin=stdin, stdout=stdout,
269                          stderr=stderr, raise_on_error=raise_on_error,
270                          timeout=timeout,
271                          keep_stdin_open=keep_stdin_open, pty=pty)
272
273     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
274              raise_on_error=True, timeout=3600,
275              keep_stdin_open=False, pty=False):
276
277         transport = client.get_transport()
278         session = transport.open_session()
279         if pty:
280             session.get_pty()
281         session.exec_command(cmd)
282         start_time = time.time()
283
284         # encode on transmit, decode on receive
285         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
286         stderr_data = None
287
288         # If we have data to be sent to stdin then `select' should also
289         # check for stdin availability.
290         if stdin and not stdin.closed:
291             writes = [session]
292         else:
293             writes = []
294
295         while True:
296             # Block until data can be read/write.
297             e = select.select([session], writes, [session], 1)[2]
298
299             if session.recv_ready():
300                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
301                 self.log.debug("stdout: %r", data)
302                 if stdout is not None:
303                     stdout.write(data)
304                 continue
305
306             if session.recv_stderr_ready():
307                 stderr_data = encodeutils.safe_decode(
308                     session.recv_stderr(4096), 'utf-8')
309                 self.log.debug("stderr: %r", stderr_data)
310                 if stderr is not None:
311                     stderr.write(stderr_data)
312                 continue
313
314             if session.send_ready():
315                 if stdin is not None and not stdin.closed:
316                     if not data_to_send:
317                         stdin_txt = stdin.read(4096)
318                         if stdin_txt is None:
319                             stdin_txt = ''
320                         data_to_send = encodeutils.safe_encode(
321                             stdin_txt, incoming='utf-8')
322                         if not data_to_send:
323                             # we may need to keep stdin open
324                             if not keep_stdin_open:
325                                 stdin.close()
326                                 session.shutdown_write()
327                                 writes = []
328                     if data_to_send:
329                         sent_bytes = session.send(data_to_send)
330                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
331                         data_to_send = data_to_send[sent_bytes:]
332
333             if session.exit_status_ready():
334                 break
335
336             if timeout and (time.time() - timeout) > start_time:
337                 args = {"cmd": cmd, "host": self.host}
338                 raise SSHTimeout("Timeout executing command "
339                                  "'%(cmd)s' on host %(host)s" % args)
340             if e:
341                 raise SSHError("Socket error.")
342
343         exit_status = session.recv_exit_status()
344         if exit_status != 0 and raise_on_error:
345             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
346             details = fmt % {"cmd": cmd, "status": exit_status}
347             if stderr_data:
348                 details += " Last stderr data: '%s'." % stderr_data
349             raise SSHError(details)
350         return exit_status
351
352     def execute(self, cmd, stdin=None, timeout=3600):
353         """Execute the specified command on the server.
354
355         :param cmd:     Command to be executed.
356         :param stdin:   Open file to be sent on process stdin.
357         :param timeout: Timeout for execution of the command.
358
359         :returns: tuple (exit_status, stdout, stderr)
360         """
361         stdout = six.moves.StringIO()
362         stderr = six.moves.StringIO()
363
364         exit_status = self.run(cmd, stderr=stderr,
365                                stdout=stdout, stdin=stdin,
366                                timeout=timeout, raise_on_error=False)
367         stdout.seek(0)
368         stderr.seek(0)
369         return exit_status, stdout.read(), stderr.read()
370
371     def wait(self, timeout=None, interval=1):
372         """Wait for the host will be available via ssh."""
373         if timeout is None:
374             timeout = self.wait_timeout
375
376         end_time = time.time() + timeout
377         while True:
378             try:
379                 return self.execute("uname")
380             except (socket.error, SSHError) as e:
381                 self.log.debug("Ssh is still unavailable: %r", e)
382                 time.sleep(interval)
383             if time.time() > end_time:
384                 raise SSHTimeout("Timeout waiting for '%s'" % self.host)
385
386     def put(self, files, remote_path=b'.', recursive=False):
387         client = self._get_client()
388
389         with SCPClient(client.get_transport()) as scp:
390             scp.put(files, remote_path, recursive)
391
392     def get(self, remote_path, local_path='/tmp/', recursive=True):
393         client = self._get_client()
394
395         with SCPClient(client.get_transport()) as scp:
396             scp.get(remote_path, local_path, recursive)
397
398     # keep shell running in the background, e.g. screen
399     def send_command(self, command):
400         client = self._get_client()
401         client.exec_command(command, get_pty=True)
402
403     def _put_file_sftp(self, localpath, remotepath, mode=None):
404         client = self._get_client()
405
406         with client.open_sftp() as sftp:
407             sftp.put(localpath, remotepath)
408             if mode is None:
409                 mode = 0o777 & os.stat(localpath).st_mode
410             sftp.chmod(remotepath, mode)
411
412     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
413
414     def _put_file_shell(self, localpath, remotepath, mode=None):
415         # quote to stop wordpslit
416         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
417         if not tilde:
418             tilde = ''
419         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
420         if mode is not None:
421             # use -- so no options
422             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
423
424         with open(localpath, "rb") as localfile:
425             # only chmod on successful cat
426             self.run("&& ".join(cmd), stdin=localfile)
427
428     def put_file(self, localpath, remotepath, mode=None):
429         """Copy specified local file to the server.
430
431         :param localpath:   Local filename.
432         :param remotepath:  Remote filename.
433         :param mode:        Permissions to set after upload
434         """
435         try:
436             self._put_file_sftp(localpath, remotepath, mode=mode)
437         except (paramiko.SSHException, socket.error):
438             self._put_file_shell(localpath, remotepath, mode=mode)
439
440     def provision_tool(self, tool_path, tool_file=None):
441         return provision_tool(self, tool_path, tool_file)
442
443     def put_file_obj(self, file_obj, remotepath, mode=None):
444         client = self._get_client()
445
446         with client.open_sftp() as sftp:
447             sftp.putfo(file_obj, remotepath)
448             if mode is not None:
449                 sftp.chmod(remotepath, mode)
450
451     def get_file_obj(self, remotepath, file_obj):
452         client = self._get_client()
453
454         with client.open_sftp() as sftp:
455             sftp.getfo(remotepath, file_obj)
456
457
458 class AutoConnectSSH(SSH):
459
460     @classmethod
461     def get_arg_key_map(cls):
462         arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
463         arg_key_map['wait'] = ('wait', True)
464         return arg_key_map
465
466     # always wait or we will get OpenStack SSH errors
467     def __init__(self, user, host, port=None, pkey=None,
468                  key_filename=None, password=None, name=None, wait=True):
469         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
470         if wait and wait is not True:
471             self.wait_timeout = int(wait)
472
473     def _make_dict(self):
474         data = super(AutoConnectSSH, self)._make_dict()
475         data.update({
476             'wait': self.wait_timeout
477         })
478         return data
479
480     def _connect(self):
481         if not self.is_connected:
482             interval = 1
483             timeout = self.wait_timeout
484
485             end_time = time.time() + timeout
486             while True:
487                 try:
488                     return self._get_client()
489                 except (socket.error, SSHError) as e:
490                     self.log.debug("Ssh is still unavailable: %r", e)
491                     time.sleep(interval)
492                 if time.time() > end_time:
493                     raise SSHTimeout("Timeout waiting for '%s'" % self.host)
494
495     def drop_connection(self):
496         """ Don't close anything, just force creation of a new client """
497         self._client = False
498
499     def execute(self, cmd, stdin=None, timeout=3600):
500         self._connect()
501         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
502
503     def run(self, cmd, stdin=None, stdout=None, stderr=None,
504             raise_on_error=True, timeout=3600,
505             keep_stdin_open=False, pty=False):
506         self._connect()
507         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
508                                                timeout, keep_stdin_open, pty)
509
510     def put(self, files, remote_path=b'.', recursive=False):
511         self._connect()
512         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
513
514     def put_file(self, local_path, remote_path, mode=None):
515         self._connect()
516         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
517
518     def put_file_obj(self, file_obj, remote_path, mode=None):
519         self._connect()
520         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
521
522     def get_file_obj(self, remote_path, file_obj):
523         self._connect()
524         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
525
526     def provision_tool(self, tool_path, tool_file=None):
527         self._connect()
528         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
529
530     @staticmethod
531     def get_class():
532         # must return static class name, anything else refers to the calling class
533         # i.e. the subclass, not the superclass
534         return AutoConnectSSH