Refactor "utils.parse_ini_file" testing
[yardstick.git] / yardstick / ssh.py
index cfbc3ca..6b5e6fa 100644 (file)
@@ -62,35 +62,77 @@ Eventlet:
     sshclient = eventlet.import_patched("yardstick.ssh")
 
 """
-from __future__ import absolute_import
+import io
+import logging
 import os
+import re
 import select
 import socket
 import time
-import re
 
-import logging
 import paramiko
+from chainmap import ChainMap
 from oslo_utils import encodeutils
 from scp import SCPClient
 import six
 
-
-DEFAULT_PORT = 22
+from yardstick.common import exceptions
+from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
+from yardstick.network_services.utils import provision_tool
 
 
-class SSHError(Exception):
-    pass
+def convert_key_to_str(key):
+    if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
+        return key
+    k = io.StringIO()
+    key.write_private_key(k)
+    return k.getvalue()
 
 
-class SSHTimeout(SSHError):
-    pass
+# class SSHError(Exception):
+#     pass
+#
+#
+# class SSHTimeout(SSHError):
+#     pass
 
 
 class SSH(object):
     """Represent ssh connection."""
 
-    def __init__(self, user, host, port=DEFAULT_PORT, pkey=None,
+    SSH_PORT = paramiko.config.SSH_PORT
+    DEFAULT_WAIT_TIMEOUT = 120
+
+    @staticmethod
+    def gen_keys(key_filename, bit_count=2048):
+        rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
+        rsa_key.write_private_key_file(key_filename)
+        print("Writing %s ..." % key_filename)
+        with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
+            pubkey_file.write(rsa_key.get_name())
+            pubkey_file.write(' ')
+            pubkey_file.write(rsa_key.get_base64())
+            pubkey_file.write('\n')
+
+    @staticmethod
+    def get_class():
+        # must return static class name, anything else refers to the calling class
+        # i.e. the subclass, not the superclass
+        return SSH
+
+    @classmethod
+    def get_arg_key_map(cls):
+        return {
+            'user': ('user', NON_NONE_DEFAULT),
+            'host': ('ip', NON_NONE_DEFAULT),
+            'port': ('ssh_port', cls.SSH_PORT),
+            'pkey': ('pkey', None),
+            'key_filename': ('key_filename', None),
+            'password': ('password', None),
+            'name': ('name', None),
+        }
+
+    def __init__(self, user, host, port=None, pkey=None,
                  key_filename=None, password=None, name=None):
         """Initialize SSH client.
 
@@ -107,10 +149,14 @@ class SSH(object):
         else:
             self.log = logging.getLogger(__name__)
 
+        self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
         self.user = user
         self.host = host
+        # everybody wants to debug this in the caller, do it here instead
+        self.log.debug("user:%s host:%s", user, host)
+
         # we may get text port from YAML, convert to int
-        self.port = int(port)
+        self.port = try_int(port, self.SSH_PORT)
         self.pkey = self._get_pkey(pkey) if pkey else None
         self.password = password
         self.key_filename = key_filename
@@ -123,6 +169,20 @@ class SSH(object):
         else:
             logging.getLogger("paramiko").setLevel(logging.WARN)
 
+    @classmethod
+    def args_from_node(cls, node, overrides=None, defaults=None):
+        if overrides is None:
+            overrides = {}
+        if defaults is None:
+            defaults = {}
+
+        params = ChainMap(overrides, node, defaults)
+        return make_dict_from_map(params, cls.get_arg_key_map())
+
+    @classmethod
+    def from_node(cls, node, overrides=None, defaults=None):
+        return cls(**cls.args_from_node(node, overrides, defaults))
+
     def _get_pkey(self, key):
         if isinstance(key, six.string_types):
             key = six.moves.StringIO(key)
@@ -132,10 +192,14 @@ class SSH(object):
                 return key_class.from_private_key(key)
             except paramiko.SSHException as e:
                 errors.append(e)
-        raise SSHError("Invalid pkey: %s" % (errors))
+        raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
+
+    @property
+    def is_connected(self):
+        return bool(self._client)
 
     def _get_client(self):
-        if self._client:
+        if self.is_connected:
             return self._client
         try:
             self._client = paramiko.SSHClient()
@@ -149,14 +213,29 @@ class SSH(object):
             return self._client
         except Exception as e:
             message = ("Exception %(exception_type)s was raised "
-                       "during connect. Exception value is: %(exception)r")
+                       "during connect. Exception value is: %(exception)r" %
+                       {"exception": e, "exception_type": type(e)})
             self._client = False
-            raise SSHError(message % {"exception": e,
-                                      "exception_type": type(e)})
+            raise exceptions.SSHError(error_msg=message)
+
+    def _make_dict(self):
+        return {
+            'user': self.user,
+            'host': self.host,
+            'port': self.port,
+            'pkey': self.pkey,
+            'key_filename': self.key_filename,
+            'password': self.password,
+            'name': self.name,
+        }
+
+    def copy(self):
+        return self.get_class()(**self._make_dict())
 
     def close(self):
-        self._client.close()
-        self._client = False
+        if self._client:
+            self._client.close()
+            self._client = False
 
     def run(self, cmd, stdin=None, stdout=None, stderr=None,
             raise_on_error=True, timeout=3600,
@@ -214,7 +293,7 @@ class SSH(object):
 
         while True:
             # Block until data can be read/write.
-            r, w, e = select.select([session], writes, [session], 1)
+            e = select.select([session], writes, [session], 1)[2]
 
             if session.recv_ready():
                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
@@ -254,11 +333,11 @@ class SSH(object):
                 break
 
             if timeout and (time.time() - timeout) > start_time:
-                args = {"cmd": cmd, "host": self.host}
-                raise SSHTimeout("Timeout executing command "
-                                 "'%(cmd)s' on host %(host)s" % args)
+                message = ('Timeout executing command %(cmd)s on host %(host)s'
+                           % {"cmd": cmd, "host": self.host})
+                raise exceptions.SSHTimeout(error_msg=message)
             if e:
-                raise SSHError("Socket error.")
+                raise exceptions.SSHError(error_msg='Socket error')
 
         exit_status = session.recv_exit_status()
         if exit_status != 0 and raise_on_error:
@@ -266,7 +345,7 @@ class SSH(object):
             details = fmt % {"cmd": cmd, "status": exit_status}
             if stderr_data:
                 details += " Last stderr data: '%s'." % stderr_data
-            raise SSHError(details)
+            raise exceptions.SSHError(error_msg=details)
         return exit_status
 
     def execute(self, cmd, stdin=None, timeout=3600):
@@ -286,19 +365,23 @@ class SSH(object):
                                timeout=timeout, raise_on_error=False)
         stdout.seek(0)
         stderr.seek(0)
-        return (exit_status, stdout.read(), stderr.read())
+        return exit_status, stdout.read(), stderr.read()
 
-    def wait(self, timeout=120, interval=1):
+    def wait(self, timeout=None, interval=1):
         """Wait for the host will be available via ssh."""
-        start_time = time.time()
+        if timeout is None:
+            timeout = self.wait_timeout
+
+        end_time = time.time() + timeout
         while True:
             try:
                 return self.execute("uname")
-            except (socket.error, SSHError) as e:
+            except (socket.error, exceptions.SSHError) as e:
                 self.log.debug("Ssh is still unavailable: %r", e)
                 time.sleep(interval)
-            if time.time() > (start_time + timeout):
-                raise SSHTimeout("Timeout waiting for '%s'", self.host)
+            if time.time() > end_time:
+                raise exceptions.SSHTimeout(
+                    error_msg='Timeout waiting for "%s"' % self.host)
 
     def put(self, files, remote_path=b'.', recursive=False):
         client = self._get_client()
@@ -306,6 +389,12 @@ class SSH(object):
         with SCPClient(client.get_transport()) as scp:
             scp.put(files, remote_path, recursive)
 
+    def get(self, remote_path, local_path='/tmp/', recursive=True):
+        client = self._get_client()
+
+        with SCPClient(client.get_transport()) as scp:
+            scp.get(remote_path, local_path, recursive)
+
     # keep shell running in the background, e.g. screen
     def send_command(self, command):
         client = self._get_client()
@@ -347,3 +436,100 @@ class SSH(object):
             self._put_file_sftp(localpath, remotepath, mode=mode)
         except (paramiko.SSHException, socket.error):
             self._put_file_shell(localpath, remotepath, mode=mode)
+
+    def provision_tool(self, tool_path, tool_file=None):
+        return provision_tool(self, tool_path, tool_file)
+
+    def put_file_obj(self, file_obj, remotepath, mode=None):
+        client = self._get_client()
+
+        with client.open_sftp() as sftp:
+            sftp.putfo(file_obj, remotepath)
+            if mode is not None:
+                sftp.chmod(remotepath, mode)
+
+    def get_file_obj(self, remotepath, file_obj):
+        client = self._get_client()
+
+        with client.open_sftp() as sftp:
+            sftp.getfo(remotepath, file_obj)
+
+
+class AutoConnectSSH(SSH):
+
+    @classmethod
+    def get_arg_key_map(cls):
+        arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
+        arg_key_map['wait'] = ('wait', True)
+        return arg_key_map
+
+    # always wait or we will get OpenStack SSH errors
+    def __init__(self, user, host, port=None, pkey=None,
+                 key_filename=None, password=None, name=None, wait=True):
+        super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
+        if wait and wait is not True:
+            self.wait_timeout = int(wait)
+
+    def _make_dict(self):
+        data = super(AutoConnectSSH, self)._make_dict()
+        data.update({
+            'wait': self.wait_timeout
+        })
+        return data
+
+    def _connect(self):
+        if not self.is_connected:
+            interval = 1
+            timeout = self.wait_timeout
+
+            end_time = time.time() + timeout
+            while True:
+                try:
+                    return self._get_client()
+                except (socket.error, exceptions.SSHError) as e:
+                    self.log.debug("Ssh is still unavailable: %r", e)
+                    time.sleep(interval)
+                if time.time() > end_time:
+                    raise exceptions.SSHTimeout(
+                        error_msg='Timeout waiting for "%s"' % self.host)
+
+    def drop_connection(self):
+        """ Don't close anything, just force creation of a new client """
+        self._client = False
+
+    def execute(self, cmd, stdin=None, timeout=3600):
+        self._connect()
+        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+
+    def run(self, cmd, stdin=None, stdout=None, stderr=None,
+            raise_on_error=True, timeout=3600,
+            keep_stdin_open=False, pty=False):
+        self._connect()
+        return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
+                                               timeout, keep_stdin_open, pty)
+
+    def put(self, files, remote_path=b'.', recursive=False):
+        self._connect()
+        return super(AutoConnectSSH, self).put(files, remote_path, recursive)
+
+    def put_file(self, local_path, remote_path, mode=None):
+        self._connect()
+        return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
+
+    def put_file_obj(self, file_obj, remote_path, mode=None):
+        self._connect()
+        return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
+
+    def get_file_obj(self, remote_path, file_obj):
+        self._connect()
+        return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
+
+    def provision_tool(self, tool_path, tool_file=None):
+        self._connect()
+        return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
+
+    @staticmethod
+    def get_class():
+        # must return static class name, anything else refers to the calling class
+        # i.e. the subclass, not the superclass
+        return AutoConnectSSH