X-Git-Url: https://gerrit.opnfv.org/gerrit/gitweb?a=blobdiff_plain;f=yardstick%2Fssh.py;h=d7adc0d05580b69823d5c74d7cbc94692581e408;hb=7411096e94f7e30f03c73cc0b8d66dbd076fc37e;hp=a024cf64ae58ca35cf1506ec57530cf1c01d8f5c;hpb=6b762e9edb2f7913348d99c1c240fe94b9f7cbf9;p=yardstick.git diff --git a/yardstick/ssh.py b/yardstick/ssh.py index a024cf64a..d7adc0d05 100644 --- a/yardstick/ssh.py +++ b/yardstick/ssh.py @@ -64,6 +64,7 @@ Eventlet: """ from __future__ import absolute_import import os +import io import select import socket import time @@ -77,10 +78,18 @@ from oslo_utils import encodeutils from scp import SCPClient import six -from yardstick.common.utils import try_int +from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map from yardstick.network_services.utils import provision_tool +def convert_key_to_str(key): + if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)): + return key + k = io.StringIO() + key.write_private_key(k) + return k.getvalue() + + class SSHError(Exception): pass @@ -93,6 +102,7 @@ class SSH(object): """Represent ssh connection.""" SSH_PORT = paramiko.config.SSH_PORT + DEFAULT_WAIT_TIMEOUT = 120 @staticmethod def gen_keys(key_filename, bit_count=2048): @@ -111,6 +121,18 @@ class SSH(object): # i.e. the subclass, not the superclass return SSH + @classmethod + def get_arg_key_map(cls): + return { + 'user': ('user', NON_NONE_DEFAULT), + 'host': ('ip', NON_NONE_DEFAULT), + 'port': ('ssh_port', cls.SSH_PORT), + 'pkey': ('pkey', None), + 'key_filename': ('key_filename', None), + 'password': ('password', None), + 'name': ('name', None), + } + def __init__(self, user, host, port=None, pkey=None, key_filename=None, password=None, name=None): """Initialize SSH client. @@ -128,6 +150,7 @@ class SSH(object): else: self.log = logging.getLogger(__name__) + self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT self.user = user self.host = host # everybody wants to debug this in the caller, do it here instead @@ -153,16 +176,9 @@ class SSH(object): overrides = {} if defaults is None: defaults = {} + params = ChainMap(overrides, node, defaults) - return { - 'user': params['user'], - 'host': params['ip'], - 'port': params.get('ssh_port', cls.SSH_PORT), - 'pkey': params.get('pkey'), - 'key_filename': params.get('key_filename'), - 'password': params.get('password'), - 'name': params.get('name'), - } + return make_dict_from_map(params, cls.get_arg_key_map()) @classmethod def from_node(cls, node, overrides=None, defaults=None): @@ -177,7 +193,7 @@ class SSH(object): return key_class.from_private_key(key) except paramiko.SSHException as e: errors.append(e) - raise SSHError("Invalid pkey: %s" % (errors)) + raise SSHError("Invalid pkey: %s" % errors) @property def is_connected(self): @@ -278,7 +294,7 @@ class SSH(object): while True: # Block until data can be read/write. - r, w, e = select.select([session], writes, [session], 1) + e = select.select([session], writes, [session], 1)[2] if session.recv_ready(): data = encodeutils.safe_decode(session.recv(4096), 'utf-8') @@ -352,17 +368,20 @@ class SSH(object): stderr.seek(0) return exit_status, stdout.read(), stderr.read() - def wait(self, timeout=120, interval=1): + def wait(self, timeout=None, interval=1): """Wait for the host will be available via ssh.""" - start_time = time.time() + if timeout is None: + timeout = self.wait_timeout + + end_time = time.time() + timeout while True: try: return self.execute("uname") except (socket.error, SSHError) as e: self.log.debug("Ssh is still unavailable: %r", e) time.sleep(interval) - if time.time() > (start_time + timeout): - raise SSHTimeout("Timeout waiting for '%s'", self.host) + if time.time() > end_time: + raise SSHTimeout("Timeout waiting for '%s'" % self.host) def put(self, files, remote_path=b'.', recursive=False): client = self._get_client() @@ -370,6 +389,12 @@ class SSH(object): with SCPClient(client.get_transport()) as scp: scp.put(files, remote_path, recursive) + def get(self, remote_path, local_path='/tmp/', recursive=True): + client = self._get_client() + + with SCPClient(client.get_transport()) as scp: + scp.get(remote_path, local_path, recursive) + # keep shell running in the background, e.g. screen def send_command(self, command): client = self._get_client() @@ -432,23 +457,40 @@ class SSH(object): class AutoConnectSSH(SSH): + @classmethod + def get_arg_key_map(cls): + arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map() + arg_key_map['wait'] = ('wait', True) + return arg_key_map + + # always wait or we will get OpenStack SSH errors def __init__(self, user, host, port=None, pkey=None, - key_filename=None, password=None, name=None, wait=False): + key_filename=None, password=None, name=None, wait=True): super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name) - self._wait = wait + if wait and wait is not True: + self.wait_timeout = int(wait) def _make_dict(self): data = super(AutoConnectSSH, self)._make_dict() data.update({ - 'wait': self._wait + 'wait': self.wait_timeout }) return data def _connect(self): if not self.is_connected: - self._get_client() - if self._wait: - self.wait() + interval = 1 + timeout = self.wait_timeout + + end_time = time.time() + timeout + while True: + try: + return self._get_client() + except (socket.error, SSHError) as e: + self.log.debug("Ssh is still unavailable: %r", e) + time.sleep(interval) + if time.time() > end_time: + raise SSHTimeout("Timeout waiting for '%s'" % self.host) def drop_connection(self): """ Don't close anything, just force creation of a new client """