X-Git-Url: https://gerrit.opnfv.org/gerrit/gitweb?a=blobdiff_plain;f=yardstick%2Fssh.py;h=e6a26ab6b19548819123e6f4af86e0d9c530df1c;hb=6f450c7a8fc94d4db2e11da9708a8e1835f9233b;hp=cf9adf0dc6d5c994fbc57301b18e94a3fd459885;hpb=87559e992c1f68559950b258146a530c49db9df0;p=yardstick.git diff --git a/yardstick/ssh.py b/yardstick/ssh.py index cf9adf0dc..e6a26ab6b 100644 --- a/yardstick/ssh.py +++ b/yardstick/ssh.py @@ -62,14 +62,13 @@ Eventlet: sshclient = eventlet.import_patched("yardstick.ssh") """ -from __future__ import absolute_import +import io +import logging import os +import re import select import socket import time -import re - -import logging import paramiko from chainmap import ChainMap @@ -77,22 +76,63 @@ from oslo_utils import encodeutils from scp import SCPClient import six - -SSH_PORT = paramiko.config.SSH_PORT +from yardstick.common import exceptions +from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map +from yardstick.network_services.utils import provision_tool -class SSHError(Exception): - pass +def convert_key_to_str(key): + if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)): + return key + k = io.StringIO() + key.write_private_key(k) + return k.getvalue() -class SSHTimeout(SSHError): - pass +# class SSHError(Exception): +# pass +# +# +# class SSHTimeout(SSHError): +# pass class SSH(object): """Represent ssh connection.""" - def __init__(self, user, host, port=SSH_PORT, pkey=None, + SSH_PORT = paramiko.config.SSH_PORT + DEFAULT_WAIT_TIMEOUT = 120 + + @staticmethod + def gen_keys(key_filename, bit_count=2048): + rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None) + rsa_key.write_private_key_file(key_filename) + print("Writing %s ..." % key_filename) + with open('.'.join([key_filename, "pub"]), "w") as pubkey_file: + pubkey_file.write(rsa_key.get_name()) + pubkey_file.write(' ') + pubkey_file.write(rsa_key.get_base64()) + pubkey_file.write('\n') + + @staticmethod + def get_class(): + # must return static class name, anything else refers to the calling class + # i.e. the subclass, not the superclass + return SSH + + @classmethod + def get_arg_key_map(cls): + return { + 'user': ('user', NON_NONE_DEFAULT), + 'host': ('ip', NON_NONE_DEFAULT), + 'port': ('ssh_port', cls.SSH_PORT), + 'pkey': ('pkey', None), + 'key_filename': ('key_filename', None), + 'password': ('password', None), + 'name': ('name', None), + } + + def __init__(self, user, host, port=None, pkey=None, key_filename=None, password=None, name=None): """Initialize SSH client. @@ -109,13 +149,14 @@ class SSH(object): else: self.log = logging.getLogger(__name__) + self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT self.user = user self.host = host # everybody wants to debug this in the caller, do it here instead self.log.debug("user:%s host:%s", user, host) # we may get text port from YAML, convert to int - self.port = int(port) + self.port = try_int(port, self.SSH_PORT) self.pkey = self._get_pkey(pkey) if pkey else None self.password = password self.key_filename = key_filename @@ -129,21 +170,18 @@ class SSH(object): logging.getLogger("paramiko").setLevel(logging.WARN) @classmethod - def from_node(cls, node, overrides=None, defaults=None): + def args_from_node(cls, node, overrides=None, defaults=None): if overrides is None: overrides = {} if defaults is None: defaults = {} + params = ChainMap(overrides, node, defaults) - return cls( - user=params['user'], - host=params['ip'], - # paramiko doesn't like None default, requires SSH_PORT default - port=params.get('ssh_port', SSH_PORT), - pkey=params.get('pkey'), - key_filename=params.get('key_filename'), - password=params.get('password'), - name=params.get('name')) + return make_dict_from_map(params, cls.get_arg_key_map()) + + @classmethod + def from_node(cls, node, overrides=None, defaults=None): + return cls(**cls.args_from_node(node, overrides, defaults)) def _get_pkey(self, key): if isinstance(key, six.string_types): @@ -154,10 +192,14 @@ class SSH(object): return key_class.from_private_key(key) except paramiko.SSHException as e: errors.append(e) - raise SSHError("Invalid pkey: %s" % (errors)) + raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors) + + @property + def is_connected(self): + return bool(self._client) def _get_client(self): - if self._client: + if self.is_connected: return self._client try: self._client = paramiko.SSHClient() @@ -171,14 +213,29 @@ class SSH(object): return self._client except Exception as e: message = ("Exception %(exception_type)s was raised " - "during connect. Exception value is: %(exception)r") + "during connect. Exception value is: %(exception)r" % + {"exception": e, "exception_type": type(e)}) self._client = False - raise SSHError(message % {"exception": e, - "exception_type": type(e)}) + raise exceptions.SSHError(error_msg=message) + + def _make_dict(self): + return { + 'user': self.user, + 'host': self.host, + 'port': self.port, + 'pkey': self.pkey, + 'key_filename': self.key_filename, + 'password': self.password, + 'name': self.name, + } + + def copy(self): + return self.get_class()(**self._make_dict()) def close(self): - self._client.close() - self._client = False + if self._client: + self._client.close() + self._client = False def run(self, cmd, stdin=None, stdout=None, stderr=None, raise_on_error=True, timeout=3600, @@ -236,7 +293,7 @@ class SSH(object): while True: # Block until data can be read/write. - r, w, e = select.select([session], writes, [session], 1) + e = select.select([session], writes, [session], 1)[2] if session.recv_ready(): data = encodeutils.safe_decode(session.recv(4096), 'utf-8') @@ -276,11 +333,11 @@ class SSH(object): break if timeout and (time.time() - timeout) > start_time: - args = {"cmd": cmd, "host": self.host} - raise SSHTimeout("Timeout executing command " - "'%(cmd)s' on host %(host)s" % args) + message = ('Timeout executing command %(cmd)s on host %(host)s' + % {"cmd": cmd, "host": self.host}) + raise exceptions.SSHTimeout(error_msg=message) if e: - raise SSHError("Socket error.") + raise exceptions.SSHError(error_msg='Socket error') exit_status = session.recv_exit_status() if exit_status != 0 and raise_on_error: @@ -288,15 +345,17 @@ class SSH(object): details = fmt % {"cmd": cmd, "status": exit_status} if stderr_data: details += " Last stderr data: '%s'." % stderr_data - raise SSHError(details) + raise exceptions.SSHError(error_msg=details) return exit_status - def execute(self, cmd, stdin=None, timeout=3600): + def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False): """Execute the specified command on the server. - :param cmd: Command to be executed. - :param stdin: Open file to be sent on process stdin. - :param timeout: Timeout for execution of the command. + :param cmd: (str) Command to be executed. + :param stdin: (StringIO) Open file to be sent on process stdin. + :param timeout: (int) Timeout for execution of the command. + :param raise_on_error: (bool) If True, then an SSHError will be raised + when non-zero exit code. :returns: tuple (exit_status, stdout, stderr) """ @@ -305,22 +364,26 @@ class SSH(object): exit_status = self.run(cmd, stderr=stderr, stdout=stdout, stdin=stdin, - timeout=timeout, raise_on_error=False) + timeout=timeout, raise_on_error=raise_on_error) stdout.seek(0) stderr.seek(0) - return (exit_status, stdout.read(), stderr.read()) + return exit_status, stdout.read(), stderr.read() - def wait(self, timeout=120, interval=1): + def wait(self, timeout=None, interval=1): """Wait for the host will be available via ssh.""" - start_time = time.time() + if timeout is None: + timeout = self.wait_timeout + + end_time = time.time() + timeout while True: try: return self.execute("uname") - except (socket.error, SSHError) as e: + except (socket.error, exceptions.SSHError) as e: self.log.debug("Ssh is still unavailable: %r", e) time.sleep(interval) - if time.time() > (start_time + timeout): - raise SSHTimeout("Timeout waiting for '%s'", self.host) + if time.time() > end_time: + raise exceptions.SSHTimeout( + error_msg='Timeout waiting for "%s"' % self.host) def put(self, files, remote_path=b'.', recursive=False): client = self._get_client() @@ -328,6 +391,12 @@ class SSH(object): with SCPClient(client.get_transport()) as scp: scp.put(files, remote_path, recursive) + def get(self, remote_path, local_path='/tmp/', recursive=True): + client = self._get_client() + + with SCPClient(client.get_transport()) as scp: + scp.get(remote_path, local_path, recursive) + # keep shell running in the background, e.g. screen def send_command(self, command): client = self._get_client() @@ -369,3 +438,100 @@ class SSH(object): self._put_file_sftp(localpath, remotepath, mode=mode) except (paramiko.SSHException, socket.error): self._put_file_shell(localpath, remotepath, mode=mode) + + def provision_tool(self, tool_path, tool_file=None): + return provision_tool(self, tool_path, tool_file) + + def put_file_obj(self, file_obj, remotepath, mode=None): + client = self._get_client() + + with client.open_sftp() as sftp: + sftp.putfo(file_obj, remotepath) + if mode is not None: + sftp.chmod(remotepath, mode) + + def get_file_obj(self, remotepath, file_obj): + client = self._get_client() + + with client.open_sftp() as sftp: + sftp.getfo(remotepath, file_obj) + + +class AutoConnectSSH(SSH): + + @classmethod + def get_arg_key_map(cls): + arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map() + arg_key_map['wait'] = ('wait', True) + return arg_key_map + + # always wait or we will get OpenStack SSH errors + def __init__(self, user, host, port=None, pkey=None, + key_filename=None, password=None, name=None, wait=True): + super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name) + if wait and wait is not True: + self.wait_timeout = int(wait) + + def _make_dict(self): + data = super(AutoConnectSSH, self)._make_dict() + data.update({ + 'wait': self.wait_timeout + }) + return data + + def _connect(self): + if not self.is_connected: + interval = 1 + timeout = self.wait_timeout + + end_time = time.time() + timeout + while True: + try: + return self._get_client() + except (socket.error, exceptions.SSHError) as e: + self.log.debug("Ssh is still unavailable: %r", e) + time.sleep(interval) + if time.time() > end_time: + raise exceptions.SSHTimeout( + error_msg='Timeout waiting for "%s"' % self.host) + + def drop_connection(self): + """ Don't close anything, just force creation of a new client """ + self._client = False + + def execute(self, cmd, stdin=None, timeout=3600): + self._connect() + return super(AutoConnectSSH, self).execute(cmd, stdin, timeout) + + def run(self, cmd, stdin=None, stdout=None, stderr=None, + raise_on_error=True, timeout=3600, + keep_stdin_open=False, pty=False): + self._connect() + return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error, + timeout, keep_stdin_open, pty) + + def put(self, files, remote_path=b'.', recursive=False): + self._connect() + return super(AutoConnectSSH, self).put(files, remote_path, recursive) + + def put_file(self, local_path, remote_path, mode=None): + self._connect() + return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode) + + def put_file_obj(self, file_obj, remote_path, mode=None): + self._connect() + return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode) + + def get_file_obj(self, remote_path, file_obj): + self._connect() + return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj) + + def provision_tool(self, tool_path, tool_file=None): + self._connect() + return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file) + + @staticmethod + def get_class(): + # must return static class name, anything else refers to the calling class + # i.e. the subclass, not the superclass + return AutoConnectSSH