Merge "ping bottlenecks failed when security group rule do not support ipv6 - dovetai...
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import io
66 import logging
67 import os
68 import re
69 import select
70 import socket
71 import time
72
73 import paramiko
74 from chainmap import ChainMap
75 from oslo_utils import encodeutils
76 from scp import SCPClient
77 import six
78
79 from yardstick.common import exceptions
80 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
81 from yardstick.network_services.utils import provision_tool
82
83
84 def convert_key_to_str(key):
85     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
86         return key
87     k = io.StringIO()
88     key.write_private_key(k)
89     return k.getvalue()
90
91
92 # class SSHError(Exception):
93 #     pass
94 #
95 #
96 # class SSHTimeout(SSHError):
97 #     pass
98
99
100 class SSH(object):
101     """Represent ssh connection."""
102
103     SSH_PORT = paramiko.config.SSH_PORT
104     DEFAULT_WAIT_TIMEOUT = 120
105
106     @staticmethod
107     def gen_keys(key_filename, bit_count=2048):
108         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
109         rsa_key.write_private_key_file(key_filename)
110         print("Writing %s ..." % key_filename)
111         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
112             pubkey_file.write(rsa_key.get_name())
113             pubkey_file.write(' ')
114             pubkey_file.write(rsa_key.get_base64())
115             pubkey_file.write('\n')
116
117     @staticmethod
118     def get_class():
119         # must return static class name, anything else refers to the calling class
120         # i.e. the subclass, not the superclass
121         return SSH
122
123     @classmethod
124     def get_arg_key_map(cls):
125         return {
126             'user': ('user', NON_NONE_DEFAULT),
127             'host': ('ip', NON_NONE_DEFAULT),
128             'port': ('ssh_port', cls.SSH_PORT),
129             'pkey': ('pkey', None),
130             'key_filename': ('key_filename', None),
131             'password': ('password', None),
132             'name': ('name', None),
133         }
134
135     def __init__(self, user, host, port=None, pkey=None,
136                  key_filename=None, password=None, name=None):
137         """Initialize SSH client.
138
139         :param user: ssh username
140         :param host: hostname or ip address of remote ssh server
141         :param port: remote ssh port
142         :param pkey: RSA or DSS private key string or file object
143         :param key_filename: private key filename
144         :param password: password
145         """
146         self.name = name
147         if name:
148             self.log = logging.getLogger(__name__ + '.' + self.name)
149         else:
150             self.log = logging.getLogger(__name__)
151
152         self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
153         self.user = user
154         self.host = host
155         # everybody wants to debug this in the caller, do it here instead
156         self.log.debug("user:%s host:%s", user, host)
157
158         # we may get text port from YAML, convert to int
159         self.port = try_int(port, self.SSH_PORT)
160         self.pkey = self._get_pkey(pkey) if pkey else None
161         self.password = password
162         self.key_filename = key_filename
163         self._client = False
164         # paramiko loglevel debug will output ssh protocl debug
165         # we don't ever really want that unless we are debugging paramiko
166         # ssh issues
167         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
168             logging.getLogger("paramiko").setLevel(logging.DEBUG)
169         else:
170             logging.getLogger("paramiko").setLevel(logging.WARN)
171
172     @classmethod
173     def args_from_node(cls, node, overrides=None, defaults=None):
174         if overrides is None:
175             overrides = {}
176         if defaults is None:
177             defaults = {}
178
179         params = ChainMap(overrides, node, defaults)
180         return make_dict_from_map(params, cls.get_arg_key_map())
181
182     @classmethod
183     def from_node(cls, node, overrides=None, defaults=None):
184         return cls(**cls.args_from_node(node, overrides, defaults))
185
186     def _get_pkey(self, key):
187         if isinstance(key, six.string_types):
188             key = six.moves.StringIO(key)
189         errors = []
190         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
191             try:
192                 return key_class.from_private_key(key)
193             except paramiko.SSHException as e:
194                 errors.append(e)
195         raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
196
197     @property
198     def is_connected(self):
199         return bool(self._client)
200
201     def _get_client(self):
202         if self.is_connected:
203             return self._client
204         try:
205             self._client = paramiko.SSHClient()
206             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
207             self._client.connect(self.host, username=self.user,
208                                  port=self.port, pkey=self.pkey,
209                                  key_filename=self.key_filename,
210                                  password=self.password,
211                                  allow_agent=False, look_for_keys=False,
212                                  timeout=1)
213             return self._client
214         except Exception as e:
215             message = ("Exception %(exception_type)s was raised "
216                        "during connect. Exception value is: %(exception)r" %
217                        {"exception": e, "exception_type": type(e)})
218             self._client = False
219             raise exceptions.SSHError(error_msg=message)
220
221     def _make_dict(self):
222         return {
223             'user': self.user,
224             'host': self.host,
225             'port': self.port,
226             'pkey': self.pkey,
227             'key_filename': self.key_filename,
228             'password': self.password,
229             'name': self.name,
230         }
231
232     def copy(self):
233         return self.get_class()(**self._make_dict())
234
235     def close(self):
236         if self._client:
237             self._client.close()
238             self._client = False
239
240     def run(self, cmd, stdin=None, stdout=None, stderr=None,
241             raise_on_error=True, timeout=3600,
242             keep_stdin_open=False, pty=False):
243         """Execute specified command on the server.
244
245         :param cmd:             Command to be executed.
246         :type cmd:              str
247         :param stdin:           Open file or string to pass to stdin.
248         :param stdout:          Open file to connect to stdout.
249         :param stderr:          Open file to connect to stderr.
250         :param raise_on_error:  If False then exit code will be return. If True
251                                 then exception will be raized if non-zero code.
252         :param timeout:         Timeout in seconds for command execution.
253                                 Default 1 hour. No timeout if set to 0.
254         :param keep_stdin_open: don't close stdin on empty reads
255         :type keep_stdin_open:  bool
256         :param pty:             Request a pseudo terminal for this connection.
257                                 This allows passing control characters.
258                                 Default False.
259         :type pty:              bool
260         """
261
262         client = self._get_client()
263
264         if isinstance(stdin, six.string_types):
265             stdin = six.moves.StringIO(stdin)
266
267         return self._run(client, cmd, stdin=stdin, stdout=stdout,
268                          stderr=stderr, raise_on_error=raise_on_error,
269                          timeout=timeout,
270                          keep_stdin_open=keep_stdin_open, pty=pty)
271
272     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
273              raise_on_error=True, timeout=3600,
274              keep_stdin_open=False, pty=False):
275
276         transport = client.get_transport()
277         session = transport.open_session()
278         if pty:
279             session.get_pty()
280         session.exec_command(cmd)
281         start_time = time.time()
282
283         # encode on transmit, decode on receive
284         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
285         stderr_data = None
286
287         # If we have data to be sent to stdin then `select' should also
288         # check for stdin availability.
289         if stdin and not stdin.closed:
290             writes = [session]
291         else:
292             writes = []
293
294         while True:
295             # Block until data can be read/write.
296             e = select.select([session], writes, [session], 1)[2]
297
298             if session.recv_ready():
299                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
300                 self.log.debug("stdout: %r", data)
301                 if stdout is not None:
302                     stdout.write(data)
303                 continue
304
305             if session.recv_stderr_ready():
306                 stderr_data = encodeutils.safe_decode(
307                     session.recv_stderr(4096), 'utf-8')
308                 self.log.debug("stderr: %r", stderr_data)
309                 if stderr is not None:
310                     stderr.write(stderr_data)
311                 continue
312
313             if session.send_ready():
314                 if stdin is not None and not stdin.closed:
315                     if not data_to_send:
316                         stdin_txt = stdin.read(4096)
317                         if stdin_txt is None:
318                             stdin_txt = ''
319                         data_to_send = encodeutils.safe_encode(
320                             stdin_txt, incoming='utf-8')
321                         if not data_to_send:
322                             # we may need to keep stdin open
323                             if not keep_stdin_open:
324                                 stdin.close()
325                                 session.shutdown_write()
326                                 writes = []
327                     if data_to_send:
328                         sent_bytes = session.send(data_to_send)
329                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
330                         data_to_send = data_to_send[sent_bytes:]
331
332             if session.exit_status_ready():
333                 break
334
335             if timeout and (time.time() - timeout) > start_time:
336                 message = ('Timeout executing command %(cmd)s on host %(host)s'
337                            % {"cmd": cmd, "host": self.host})
338                 raise exceptions.SSHTimeout(error_msg=message)
339             if e:
340                 raise exceptions.SSHError(error_msg='Socket error')
341
342         exit_status = session.recv_exit_status()
343         if exit_status != 0 and raise_on_error:
344             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
345             details = fmt % {"cmd": cmd, "status": exit_status}
346             if stderr_data:
347                 details += " Last stderr data: '%s'." % stderr_data
348             raise exceptions.SSHError(error_msg=details)
349         return exit_status
350
351     def execute(self, cmd, stdin=None, timeout=3600):
352         """Execute the specified command on the server.
353
354         :param cmd:     Command to be executed.
355         :param stdin:   Open file to be sent on process stdin.
356         :param timeout: Timeout for execution of the command.
357
358         :returns: tuple (exit_status, stdout, stderr)
359         """
360         stdout = six.moves.StringIO()
361         stderr = six.moves.StringIO()
362
363         exit_status = self.run(cmd, stderr=stderr,
364                                stdout=stdout, stdin=stdin,
365                                timeout=timeout, raise_on_error=False)
366         stdout.seek(0)
367         stderr.seek(0)
368         return exit_status, stdout.read(), stderr.read()
369
370     def wait(self, timeout=None, interval=1):
371         """Wait for the host will be available via ssh."""
372         if timeout is None:
373             timeout = self.wait_timeout
374
375         end_time = time.time() + timeout
376         while True:
377             try:
378                 return self.execute("uname")
379             except (socket.error, exceptions.SSHError) as e:
380                 self.log.debug("Ssh is still unavailable: %r", e)
381                 time.sleep(interval)
382             if time.time() > end_time:
383                 raise exceptions.SSHTimeout(
384                     error_msg='Timeout waiting for "%s"' % self.host)
385
386     def put(self, files, remote_path=b'.', recursive=False):
387         client = self._get_client()
388
389         with SCPClient(client.get_transport()) as scp:
390             scp.put(files, remote_path, recursive)
391
392     def get(self, remote_path, local_path='/tmp/', recursive=True):
393         client = self._get_client()
394
395         with SCPClient(client.get_transport()) as scp:
396             scp.get(remote_path, local_path, recursive)
397
398     # keep shell running in the background, e.g. screen
399     def send_command(self, command):
400         client = self._get_client()
401         client.exec_command(command, get_pty=True)
402
403     def _put_file_sftp(self, localpath, remotepath, mode=None):
404         client = self._get_client()
405
406         with client.open_sftp() as sftp:
407             sftp.put(localpath, remotepath)
408             if mode is None:
409                 mode = 0o777 & os.stat(localpath).st_mode
410             sftp.chmod(remotepath, mode)
411
412     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
413
414     def _put_file_shell(self, localpath, remotepath, mode=None):
415         # quote to stop wordpslit
416         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
417         if not tilde:
418             tilde = ''
419         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
420         if mode is not None:
421             # use -- so no options
422             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
423
424         with open(localpath, "rb") as localfile:
425             # only chmod on successful cat
426             self.run("&& ".join(cmd), stdin=localfile)
427
428     def put_file(self, localpath, remotepath, mode=None):
429         """Copy specified local file to the server.
430
431         :param localpath:   Local filename.
432         :param remotepath:  Remote filename.
433         :param mode:        Permissions to set after upload
434         """
435         try:
436             self._put_file_sftp(localpath, remotepath, mode=mode)
437         except (paramiko.SSHException, socket.error):
438             self._put_file_shell(localpath, remotepath, mode=mode)
439
440     def provision_tool(self, tool_path, tool_file=None):
441         return provision_tool(self, tool_path, tool_file)
442
443     def put_file_obj(self, file_obj, remotepath, mode=None):
444         client = self._get_client()
445
446         with client.open_sftp() as sftp:
447             sftp.putfo(file_obj, remotepath)
448             if mode is not None:
449                 sftp.chmod(remotepath, mode)
450
451     def get_file_obj(self, remotepath, file_obj):
452         client = self._get_client()
453
454         with client.open_sftp() as sftp:
455             sftp.getfo(remotepath, file_obj)
456
457
458 class AutoConnectSSH(SSH):
459
460     @classmethod
461     def get_arg_key_map(cls):
462         arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
463         arg_key_map['wait'] = ('wait', True)
464         return arg_key_map
465
466     # always wait or we will get OpenStack SSH errors
467     def __init__(self, user, host, port=None, pkey=None,
468                  key_filename=None, password=None, name=None, wait=True):
469         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
470         if wait and wait is not True:
471             self.wait_timeout = int(wait)
472
473     def _make_dict(self):
474         data = super(AutoConnectSSH, self)._make_dict()
475         data.update({
476             'wait': self.wait_timeout
477         })
478         return data
479
480     def _connect(self):
481         if not self.is_connected:
482             interval = 1
483             timeout = self.wait_timeout
484
485             end_time = time.time() + timeout
486             while True:
487                 try:
488                     return self._get_client()
489                 except (socket.error, exceptions.SSHError) as e:
490                     self.log.debug("Ssh is still unavailable: %r", e)
491                     time.sleep(interval)
492                 if time.time() > end_time:
493                     raise exceptions.SSHTimeout(
494                         error_msg='Timeout waiting for "%s"' % self.host)
495
496     def drop_connection(self):
497         """ Don't close anything, just force creation of a new client """
498         self._client = False
499
500     def execute(self, cmd, stdin=None, timeout=3600):
501         self._connect()
502         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
503
504     def run(self, cmd, stdin=None, stdout=None, stderr=None,
505             raise_on_error=True, timeout=3600,
506             keep_stdin_open=False, pty=False):
507         self._connect()
508         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
509                                                timeout, keep_stdin_open, pty)
510
511     def put(self, files, remote_path=b'.', recursive=False):
512         self._connect()
513         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
514
515     def put_file(self, local_path, remote_path, mode=None):
516         self._connect()
517         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
518
519     def put_file_obj(self, file_obj, remote_path, mode=None):
520         self._connect()
521         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
522
523     def get_file_obj(self, remote_path, file_obj):
524         self._connect()
525         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
526
527     def provision_tool(self, tool_path, tool_file=None):
528         self._connect()
529         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
530
531     @staticmethod
532     def get_class():
533         # must return static class name, anything else refers to the calling class
534         # i.e. the subclass, not the superclass
535         return AutoConnectSSH