X-Git-Url: https://gerrit.opnfv.org/gerrit/gitweb?a=blobdiff_plain;f=yardstick%2Fssh.py;h=6b5e6faf45fbbe2f74af512ec2056ed3e4ecddbb;hb=efd922d81c99e44d6c8b39f4ea4668bf9871d02f;hp=6ddf327f213cd4ee8a850fcdfb200915b0c7daf5;hpb=6d7314e986c6359b59198d155a7baa31891e888f;p=yardstick.git diff --git a/yardstick/ssh.py b/yardstick/ssh.py index 6ddf327f2..6b5e6faf4 100644 --- a/yardstick/ssh.py +++ b/yardstick/ssh.py @@ -62,15 +62,13 @@ Eventlet: sshclient = eventlet.import_patched("yardstick.ssh") """ -from __future__ import absolute_import -import os import io +import logging +import os +import re import select import socket import time -import re - -import logging import paramiko from chainmap import ChainMap @@ -78,7 +76,8 @@ from oslo_utils import encodeutils from scp import SCPClient import six -from yardstick.common.utils import try_int +from yardstick.common import exceptions +from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map from yardstick.network_services.utils import provision_tool @@ -90,18 +89,19 @@ def convert_key_to_str(key): return k.getvalue() -class SSHError(Exception): - pass - - -class SSHTimeout(SSHError): - pass +# class SSHError(Exception): +# pass +# +# +# class SSHTimeout(SSHError): +# pass class SSH(object): """Represent ssh connection.""" SSH_PORT = paramiko.config.SSH_PORT + DEFAULT_WAIT_TIMEOUT = 120 @staticmethod def gen_keys(key_filename, bit_count=2048): @@ -120,6 +120,18 @@ class SSH(object): # i.e. the subclass, not the superclass return SSH + @classmethod + def get_arg_key_map(cls): + return { + 'user': ('user', NON_NONE_DEFAULT), + 'host': ('ip', NON_NONE_DEFAULT), + 'port': ('ssh_port', cls.SSH_PORT), + 'pkey': ('pkey', None), + 'key_filename': ('key_filename', None), + 'password': ('password', None), + 'name': ('name', None), + } + def __init__(self, user, host, port=None, pkey=None, key_filename=None, password=None, name=None): """Initialize SSH client. @@ -137,6 +149,7 @@ class SSH(object): else: self.log = logging.getLogger(__name__) + self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT self.user = user self.host = host # everybody wants to debug this in the caller, do it here instead @@ -162,16 +175,9 @@ class SSH(object): overrides = {} if defaults is None: defaults = {} + params = ChainMap(overrides, node, defaults) - return { - 'user': params['user'], - 'host': params['ip'], - 'port': params.get('ssh_port', cls.SSH_PORT), - 'pkey': params.get('pkey'), - 'key_filename': params.get('key_filename'), - 'password': params.get('password'), - 'name': params.get('name'), - } + return make_dict_from_map(params, cls.get_arg_key_map()) @classmethod def from_node(cls, node, overrides=None, defaults=None): @@ -186,7 +192,7 @@ class SSH(object): return key_class.from_private_key(key) except paramiko.SSHException as e: errors.append(e) - raise SSHError("Invalid pkey: %s" % (errors)) + raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors) @property def is_connected(self): @@ -207,10 +213,10 @@ class SSH(object): return self._client except Exception as e: message = ("Exception %(exception_type)s was raised " - "during connect. Exception value is: %(exception)r") + "during connect. Exception value is: %(exception)r" % + {"exception": e, "exception_type": type(e)}) self._client = False - raise SSHError(message % {"exception": e, - "exception_type": type(e)}) + raise exceptions.SSHError(error_msg=message) def _make_dict(self): return { @@ -287,7 +293,7 @@ class SSH(object): while True: # Block until data can be read/write. - r, w, e = select.select([session], writes, [session], 1) + e = select.select([session], writes, [session], 1)[2] if session.recv_ready(): data = encodeutils.safe_decode(session.recv(4096), 'utf-8') @@ -327,11 +333,11 @@ class SSH(object): break if timeout and (time.time() - timeout) > start_time: - args = {"cmd": cmd, "host": self.host} - raise SSHTimeout("Timeout executing command " - "'%(cmd)s' on host %(host)s" % args) + message = ('Timeout executing command %(cmd)s on host %(host)s' + % {"cmd": cmd, "host": self.host}) + raise exceptions.SSHTimeout(error_msg=message) if e: - raise SSHError("Socket error.") + raise exceptions.SSHError(error_msg='Socket error') exit_status = session.recv_exit_status() if exit_status != 0 and raise_on_error: @@ -339,7 +345,7 @@ class SSH(object): details = fmt % {"cmd": cmd, "status": exit_status} if stderr_data: details += " Last stderr data: '%s'." % stderr_data - raise SSHError(details) + raise exceptions.SSHError(error_msg=details) return exit_status def execute(self, cmd, stdin=None, timeout=3600): @@ -361,17 +367,21 @@ class SSH(object): stderr.seek(0) return exit_status, stdout.read(), stderr.read() - def wait(self, timeout=120, interval=1): + def wait(self, timeout=None, interval=1): """Wait for the host will be available via ssh.""" - start_time = time.time() + if timeout is None: + timeout = self.wait_timeout + + end_time = time.time() + timeout while True: try: return self.execute("uname") - except (socket.error, SSHError) as e: + except (socket.error, exceptions.SSHError) as e: self.log.debug("Ssh is still unavailable: %r", e) time.sleep(interval) - if time.time() > (start_time + timeout): - raise SSHTimeout("Timeout waiting for '%s'", self.host) + if time.time() > end_time: + raise exceptions.SSHTimeout( + error_msg='Timeout waiting for "%s"' % self.host) def put(self, files, remote_path=b'.', recursive=False): client = self._get_client() @@ -447,24 +457,41 @@ class SSH(object): class AutoConnectSSH(SSH): + @classmethod + def get_arg_key_map(cls): + arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map() + arg_key_map['wait'] = ('wait', True) + return arg_key_map + # always wait or we will get OpenStack SSH errors def __init__(self, user, host, port=None, pkey=None, key_filename=None, password=None, name=None, wait=True): super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name) - self._wait = wait + if wait and wait is not True: + self.wait_timeout = int(wait) def _make_dict(self): data = super(AutoConnectSSH, self)._make_dict() data.update({ - 'wait': self._wait + 'wait': self.wait_timeout }) return data def _connect(self): if not self.is_connected: - self._get_client() - if self._wait: - self.wait() + interval = 1 + timeout = self.wait_timeout + + end_time = time.time() + timeout + while True: + try: + return self._get_client() + except (socket.error, exceptions.SSHError) as e: + self.log.debug("Ssh is still unavailable: %r", e) + time.sleep(interval) + if time.time() > end_time: + raise exceptions.SSHTimeout( + error_msg='Timeout waiting for "%s"' % self.host) def drop_connection(self): """ Don't close anything, just force creation of a new client """