Merge "testcase: add rate parameter for spec cpu 2006"
[yardstick.git] / yardstick / ssh.py
index 46d53b7..a024cf6 100644 (file)
@@ -25,7 +25,7 @@ Execute command and get output:
     status, stdout, stderr = ssh.execute("ps ax")
     if status:
         raise Exception("Command failed with non-zero status.")
-    print stdout.splitlines()
+    print(stdout.splitlines())
 
 Execute command with huge output:
 
@@ -62,18 +62,23 @@ Eventlet:
     sshclient = eventlet.import_patched("yardstick.ssh")
 
 """
+from __future__ import absolute_import
 import os
 import select
 import socket
 import time
+import re
 
 import logging
+
 import paramiko
+from chainmap import ChainMap
+from oslo_utils import encodeutils
 from scp import SCPClient
 import six
 
-
-DEFAULT_PORT = 22
+from yardstick.common.utils import try_int
+from yardstick.network_services.utils import provision_tool
 
 
 class SSHError(Exception):
@@ -87,7 +92,26 @@ class SSHTimeout(SSHError):
 class SSH(object):
     """Represent ssh connection."""
 
-    def __init__(self, user, host, port=DEFAULT_PORT, pkey=None,
+    SSH_PORT = paramiko.config.SSH_PORT
+
+    @staticmethod
+    def gen_keys(key_filename, bit_count=2048):
+        rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
+        rsa_key.write_private_key_file(key_filename)
+        print("Writing %s ..." % key_filename)
+        with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
+            pubkey_file.write(rsa_key.get_name())
+            pubkey_file.write(' ')
+            pubkey_file.write(rsa_key.get_base64())
+            pubkey_file.write('\n')
+
+    @staticmethod
+    def get_class():
+        # must return static class name, anything else refers to the calling class
+        # i.e. the subclass, not the superclass
+        return SSH
+
+    def __init__(self, user, host, port=None, pkey=None,
                  key_filename=None, password=None, name=None):
         """Initialize SSH client.
 
@@ -106,8 +130,11 @@ class SSH(object):
 
         self.user = user
         self.host = host
+        # everybody wants to debug this in the caller, do it here instead
+        self.log.debug("user:%s host:%s", user, host)
+
         # we may get text port from YAML, convert to int
-        self.port = int(port)
+        self.port = try_int(port, self.SSH_PORT)
         self.pkey = self._get_pkey(pkey) if pkey else None
         self.password = password
         self.key_filename = key_filename
@@ -120,6 +147,27 @@ class SSH(object):
         else:
             logging.getLogger("paramiko").setLevel(logging.WARN)
 
+    @classmethod
+    def args_from_node(cls, node, overrides=None, defaults=None):
+        if overrides is None:
+            overrides = {}
+        if defaults is None:
+            defaults = {}
+        params = ChainMap(overrides, node, defaults)
+        return {
+            'user': params['user'],
+            'host': params['ip'],
+            'port': params.get('ssh_port', cls.SSH_PORT),
+            'pkey': params.get('pkey'),
+            'key_filename': params.get('key_filename'),
+            'password': params.get('password'),
+            'name': params.get('name'),
+        }
+
+    @classmethod
+    def from_node(cls, node, overrides=None, defaults=None):
+        return cls(**cls.args_from_node(node, overrides, defaults))
+
     def _get_pkey(self, key):
         if isinstance(key, six.string_types):
             key = six.moves.StringIO(key)
@@ -131,8 +179,12 @@ class SSH(object):
                 errors.append(e)
         raise SSHError("Invalid pkey: %s" % (errors))
 
+    @property
+    def is_connected(self):
+        return bool(self._client)
+
     def _get_client(self):
-        if self._client:
+        if self.is_connected:
             return self._client
         try:
             self._client = paramiko.SSHClient()
@@ -151,13 +203,28 @@ class SSH(object):
             raise SSHError(message % {"exception": e,
                                       "exception_type": type(e)})
 
+    def _make_dict(self):
+        return {
+            'user': self.user,
+            'host': self.host,
+            'port': self.port,
+            'pkey': self.pkey,
+            'key_filename': self.key_filename,
+            'password': self.password,
+            'name': self.name,
+        }
+
+    def copy(self):
+        return self.get_class()(**self._make_dict())
+
     def close(self):
-        self._client.close()
-        self._client = False
+        if self._client:
+            self._client.close()
+            self._client = False
 
     def run(self, cmd, stdin=None, stdout=None, stderr=None,
             raise_on_error=True, timeout=3600,
-            keep_stdin_open=False):
+            keep_stdin_open=False, pty=False):
         """Execute specified command on the server.
 
         :param cmd:             Command to be executed.
@@ -171,6 +238,10 @@ class SSH(object):
                                 Default 1 hour. No timeout if set to 0.
         :param keep_stdin_open: don't close stdin on empty reads
         :type keep_stdin_open:  bool
+        :param pty:             Request a pseudo terminal for this connection.
+                                This allows passing control characters.
+                                Default False.
+        :type pty:              bool
         """
 
         client = self._get_client()
@@ -181,18 +252,21 @@ class SSH(object):
         return self._run(client, cmd, stdin=stdin, stdout=stdout,
                          stderr=stderr, raise_on_error=raise_on_error,
                          timeout=timeout,
-                         keep_stdin_open=keep_stdin_open)
+                         keep_stdin_open=keep_stdin_open, pty=pty)
 
     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
              raise_on_error=True, timeout=3600,
-             keep_stdin_open=False):
+             keep_stdin_open=False, pty=False):
 
         transport = client.get_transport()
         session = transport.open_session()
+        if pty:
+            session.get_pty()
         session.exec_command(cmd)
         start_time = time.time()
 
-        data_to_send = ""
+        # encode on transmit, decode on receive
+        data_to_send = encodeutils.safe_encode("", incoming='utf-8')
         stderr_data = None
 
         # If we have data to be sent to stdin then `select' should also
@@ -207,14 +281,15 @@ class SSH(object):
             r, w, e = select.select([session], writes, [session], 1)
 
             if session.recv_ready():
-                data = session.recv(4096)
+                data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
                 self.log.debug("stdout: %r", data)
                 if stdout is not None:
                     stdout.write(data)
                 continue
 
             if session.recv_stderr_ready():
-                stderr_data = session.recv_stderr(4096)
+                stderr_data = encodeutils.safe_decode(
+                    session.recv_stderr(4096), 'utf-8')
                 self.log.debug("stderr: %r", stderr_data)
                 if stderr is not None:
                     stderr.write(stderr_data)
@@ -223,7 +298,11 @@ class SSH(object):
             if session.send_ready():
                 if stdin is not None and not stdin.closed:
                     if not data_to_send:
-                        data_to_send = stdin.read(4096)
+                        stdin_txt = stdin.read(4096)
+                        if stdin_txt is None:
+                            stdin_txt = ''
+                        data_to_send = encodeutils.safe_encode(
+                            stdin_txt, incoming='utf-8')
                         if not data_to_send:
                             # we may need to keep stdin open
                             if not keep_stdin_open:
@@ -246,7 +325,7 @@ class SSH(object):
                 raise SSHError("Socket error.")
 
         exit_status = session.recv_exit_status()
-        if 0 != exit_status and raise_on_error:
+        if exit_status != 0 and raise_on_error:
             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
             details = fmt % {"cmd": cmd, "status": exit_status}
             if stderr_data:
@@ -271,7 +350,7 @@ class SSH(object):
                                timeout=timeout, raise_on_error=False)
         stdout.seek(0)
         stderr.seek(0)
-        return (exit_status, stdout.read(), stderr.read())
+        return exit_status, stdout.read(), stderr.read()
 
     def wait(self, timeout=120, interval=1):
         """Wait for the host will be available via ssh."""
@@ -305,17 +384,21 @@ class SSH(object):
                 mode = 0o777 & os.stat(localpath).st_mode
             sftp.chmod(remotepath, mode)
 
+    TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
+
     def _put_file_shell(self, localpath, remotepath, mode=None):
         # quote to stop wordpslit
-        cmd = ['cat > "%s"' % remotepath]
+        tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
+        if not tilde:
+            tilde = ''
+        cmd = ['cat > %s"%s"' % (tilde, remotepath)]
         if mode is not None:
             # use -- so no options
-            cmd.append('chmod -- 0%o "%s"' % (mode, remotepath))
+            cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
 
         with open(localpath, "rb") as localfile:
             # only chmod on successful cat
-            cmd = "&& ".join(cmd)
-            self.run(cmd, stdin=localfile)
+            self.run("&& ".join(cmd), stdin=localfile)
 
     def put_file(self, localpath, remotepath, mode=None):
         """Copy specified local file to the server.
@@ -324,8 +407,86 @@ class SSH(object):
         :param remotepath:  Remote filename.
         :param mode:        Permissions to set after upload
         """
-        import socket
         try:
             self._put_file_sftp(localpath, remotepath, mode=mode)
         except (paramiko.SSHException, socket.error):
             self._put_file_shell(localpath, remotepath, mode=mode)
+
+    def provision_tool(self, tool_path, tool_file=None):
+        return provision_tool(self, tool_path, tool_file)
+
+    def put_file_obj(self, file_obj, remotepath, mode=None):
+        client = self._get_client()
+
+        with client.open_sftp() as sftp:
+            sftp.putfo(file_obj, remotepath)
+            if mode is not None:
+                sftp.chmod(remotepath, mode)
+
+    def get_file_obj(self, remotepath, file_obj):
+        client = self._get_client()
+
+        with client.open_sftp() as sftp:
+            sftp.getfo(remotepath, file_obj)
+
+
+class AutoConnectSSH(SSH):
+
+    def __init__(self, user, host, port=None, pkey=None,
+                 key_filename=None, password=None, name=None, wait=False):
+        super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
+        self._wait = wait
+
+    def _make_dict(self):
+        data = super(AutoConnectSSH, self)._make_dict()
+        data.update({
+            'wait': self._wait
+        })
+        return data
+
+    def _connect(self):
+        if not self.is_connected:
+            self._get_client()
+            if self._wait:
+                self.wait()
+
+    def drop_connection(self):
+        """ Don't close anything, just force creation of a new client """
+        self._client = False
+
+    def execute(self, cmd, stdin=None, timeout=3600):
+        self._connect()
+        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+
+    def run(self, cmd, stdin=None, stdout=None, stderr=None,
+            raise_on_error=True, timeout=3600,
+            keep_stdin_open=False, pty=False):
+        self._connect()
+        return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
+                                               timeout, keep_stdin_open, pty)
+
+    def put(self, files, remote_path=b'.', recursive=False):
+        self._connect()
+        return super(AutoConnectSSH, self).put(files, remote_path, recursive)
+
+    def put_file(self, local_path, remote_path, mode=None):
+        self._connect()
+        return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
+
+    def put_file_obj(self, file_obj, remote_path, mode=None):
+        self._connect()
+        return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
+
+    def get_file_obj(self, remote_path, file_obj):
+        self._connect()
+        return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
+
+    def provision_tool(self, tool_path, tool_file=None):
+        self._connect()
+        return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
+
+    @staticmethod
+    def get_class():
+        # must return static class name, anything else refers to the calling class
+        # i.e. the subclass, not the superclass
+        return AutoConnectSSH