Merge "[docs] Documentation for BNG PPPoE RFC2544 test cases"
[yardstick.git] / yardstick / ssh.py
index 6ddf327..6bc6010 100644 (file)
@@ -62,15 +62,13 @@ Eventlet:
     sshclient = eventlet.import_patched("yardstick.ssh")
 
 """
-from __future__ import absolute_import
-import os
 import io
+import logging
+import os
+import re
 import select
 import socket
 import time
-import re
-
-import logging
 
 import paramiko
 from chainmap import ChainMap
@@ -78,9 +76,11 @@ from oslo_utils import encodeutils
 from scp import SCPClient
 import six
 
-from yardstick.common.utils import try_int
+from yardstick.common import exceptions
+from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
 from yardstick.network_services.utils import provision_tool
 
+LOG = logging.getLogger(__name__)
 
 def convert_key_to_str(key):
     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
@@ -90,18 +90,11 @@ def convert_key_to_str(key):
     return k.getvalue()
 
 
-class SSHError(Exception):
-    pass
-
-
-class SSHTimeout(SSHError):
-    pass
-
-
 class SSH(object):
     """Represent ssh connection."""
 
     SSH_PORT = paramiko.config.SSH_PORT
+    DEFAULT_WAIT_TIMEOUT = 120
 
     @staticmethod
     def gen_keys(key_filename, bit_count=2048):
@@ -120,6 +113,18 @@ class SSH(object):
         # i.e. the subclass, not the superclass
         return SSH
 
+    @classmethod
+    def get_arg_key_map(cls):
+        return {
+            'user': ('user', NON_NONE_DEFAULT),
+            'host': ('ip', NON_NONE_DEFAULT),
+            'port': ('ssh_port', cls.SSH_PORT),
+            'pkey': ('pkey', None),
+            'key_filename': ('key_filename', None),
+            'password': ('password', None),
+            'name': ('name', None),
+        }
+
     def __init__(self, user, host, port=None, pkey=None,
                  key_filename=None, password=None, name=None):
         """Initialize SSH client.
@@ -137,6 +142,7 @@ class SSH(object):
         else:
             self.log = logging.getLogger(__name__)
 
+        self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
         self.user = user
         self.host = host
         # everybody wants to debug this in the caller, do it here instead
@@ -162,16 +168,9 @@ class SSH(object):
             overrides = {}
         if defaults is None:
             defaults = {}
+
         params = ChainMap(overrides, node, defaults)
-        return {
-            'user': params['user'],
-            'host': params['ip'],
-            'port': params.get('ssh_port', cls.SSH_PORT),
-            'pkey': params.get('pkey'),
-            'key_filename': params.get('key_filename'),
-            'password': params.get('password'),
-            'name': params.get('name'),
-        }
+        return make_dict_from_map(params, cls.get_arg_key_map())
 
     @classmethod
     def from_node(cls, node, overrides=None, defaults=None):
@@ -186,7 +185,7 @@ class SSH(object):
                 return key_class.from_private_key(key)
             except paramiko.SSHException as e:
                 errors.append(e)
-        raise SSHError("Invalid pkey: %s" % (errors))
+        raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
 
     @property
     def is_connected(self):
@@ -207,10 +206,10 @@ class SSH(object):
             return self._client
         except Exception as e:
             message = ("Exception %(exception_type)s was raised "
-                       "during connect. Exception value is: %(exception)r")
+                       "during connect. Exception value is: %(exception)r" %
+                       {"exception": e, "exception_type": type(e)})
             self._client = False
-            raise SSHError(message % {"exception": e,
-                                      "exception_type": type(e)})
+            raise exceptions.SSHError(error_msg=message)
 
     def _make_dict(self):
         return {
@@ -287,7 +286,7 @@ class SSH(object):
 
         while True:
             # Block until data can be read/write.
-            r, w, e = select.select([session], writes, [session], 1)
+            e = select.select([session], writes, [session], 1)[2]
 
             if session.recv_ready():
                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
@@ -327,11 +326,11 @@ class SSH(object):
                 break
 
             if timeout and (time.time() - timeout) > start_time:
-                args = {"cmd": cmd, "host": self.host}
-                raise SSHTimeout("Timeout executing command "
-                                 "'%(cmd)s' on host %(host)s" % args)
+                message = ('Timeout executing command %(cmd)s on host %(host)s'
+                           % {"cmd": cmd, "host": self.host})
+                raise exceptions.SSHTimeout(error_msg=message)
             if e:
-                raise SSHError("Socket error.")
+                raise exceptions.SSHError(error_msg='Socket error')
 
         exit_status = session.recv_exit_status()
         if exit_status != 0 and raise_on_error:
@@ -339,15 +338,18 @@ class SSH(object):
             details = fmt % {"cmd": cmd, "status": exit_status}
             if stderr_data:
                 details += " Last stderr data: '%s'." % stderr_data
-            raise SSHError(details)
+            LOG.critical("PROX ERROR: %s", details)
+            raise exceptions.SSHError(error_msg=details)
         return exit_status
 
-    def execute(self, cmd, stdin=None, timeout=3600):
+    def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
         """Execute the specified command on the server.
 
-        :param cmd:     Command to be executed.
-        :param stdin:   Open file to be sent on process stdin.
-        :param timeout: Timeout for execution of the command.
+        :param cmd: (str)             Command to be executed.
+        :param stdin: (StringIO)      Open file to be sent on process stdin.
+        :param timeout: (int)         Timeout for execution of the command.
+        :param raise_on_error: (bool) If True, then an SSHError will be raised
+                                      when non-zero exit code.
 
         :returns: tuple (exit_status, stdout, stderr)
         """
@@ -356,22 +358,26 @@ class SSH(object):
 
         exit_status = self.run(cmd, stderr=stderr,
                                stdout=stdout, stdin=stdin,
-                               timeout=timeout, raise_on_error=False)
+                               timeout=timeout, raise_on_error=raise_on_error)
         stdout.seek(0)
         stderr.seek(0)
         return exit_status, stdout.read(), stderr.read()
 
-    def wait(self, timeout=120, interval=1):
+    def wait(self, timeout=None, interval=1):
         """Wait for the host will be available via ssh."""
-        start_time = time.time()
+        if timeout is None:
+            timeout = self.wait_timeout
+
+        end_time = time.time() + timeout
         while True:
             try:
                 return self.execute("uname")
-            except (socket.error, SSHError) as e:
+            except (socket.error, exceptions.SSHError) as e:
                 self.log.debug("Ssh is still unavailable: %r", e)
                 time.sleep(interval)
-            if time.time() > (start_time + timeout):
-                raise SSHTimeout("Timeout waiting for '%s'", self.host)
+            if time.time() > end_time:
+                raise exceptions.SSHTimeout(
+                    error_msg='Timeout waiting for "%s"' % self.host)
 
     def put(self, files, remote_path=b'.', recursive=False):
         client = self._get_client()
@@ -444,35 +450,133 @@ class SSH(object):
         with client.open_sftp() as sftp:
             sftp.getfo(remotepath, file_obj)
 
+    def interactive_terminal_open(self, time_out=45):
+        """Open interactive terminal on a SSH channel.
+
+        :param time_out: Timeout in seconds.
+        :returns: SSH channel with opened terminal.
+
+        .. warning:: Interruptingcow is used here, and it uses
+           signal(SIGALRM) to let the operating system interrupt program
+           execution. This has the following limitations: Python signal
+           handlers only apply to the main thread, so you cannot use this
+           from other threads. You must not use this in a program that
+           uses SIGALRM itself (this includes certain profilers)
+        """
+        chan = self._get_client().get_transport().open_session()
+        chan.get_pty()
+        chan.invoke_shell()
+        chan.settimeout(int(time_out))
+        chan.set_combine_stderr(True)
+
+        buf = ''
+        while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
+            try:
+                chunk = chan.recv(10 * 1024 * 1024)
+                if not chunk:
+                    break
+                buf += chunk
+                if chan.exit_status_ready():
+                    self.log.error('Channel exit status ready')
+                    break
+            except socket.timeout:
+                raise exceptions.SSHTimeout(error_msg='Socket timeout: %s' % buf)
+        return chan
+
+    def interactive_terminal_exec_command(self, chan, cmd, prompt):
+        """Execute command on interactive terminal.
+
+        interactive_terminal_open() method has to be called first!
+
+        :param chan: SSH channel with opened terminal.
+        :param cmd: Command to be executed.
+        :param prompt: Command prompt, sequence of characters used to
+        indicate readiness to accept commands.
+        :returns: Command output.
+
+        .. warning:: Interruptingcow is used here, and it uses
+           signal(SIGALRM) to let the operating system interrupt program
+           execution. This has the following limitations: Python signal
+           handlers only apply to the main thread, so you cannot use this
+           from other threads. You must not use this in a program that
+           uses SIGALRM itself (this includes certain profilers)
+        """
+        chan.sendall('{c}\n'.format(c=cmd))
+        buf = ''
+        while not buf.endswith(prompt):
+            try:
+                chunk = chan.recv(10 * 1024 * 1024)
+                if not chunk:
+                    break
+                buf += chunk
+                if chan.exit_status_ready():
+                    self.log.error('Channel exit status ready')
+                    break
+            except socket.timeout:
+                message = ("Socket timeout during execution of command: "
+                           "%(cmd)s\nBuffer content:\n%(buf)s" % {"cmd": cmd,
+                                                                  "buf": buf})
+                raise exceptions.SSHTimeout(error_msg=message)
+        tmp = buf.replace(cmd.replace('\n', ''), '')
+        for item in prompt:
+            tmp.replace(item, '')
+        return tmp
+
+    @staticmethod
+    def interactive_terminal_close(chan):
+        """Close interactive terminal SSH channel.
+
+        :param: chan: SSH channel to be closed.
+        """
+        chan.close()
+
 
 class AutoConnectSSH(SSH):
 
+    @classmethod
+    def get_arg_key_map(cls):
+        arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
+        arg_key_map['wait'] = ('wait', True)
+        return arg_key_map
+
     # always wait or we will get OpenStack SSH errors
     def __init__(self, user, host, port=None, pkey=None,
                  key_filename=None, password=None, name=None, wait=True):
         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
-        self._wait = wait
+        if wait and wait is not True:
+            self.wait_timeout = int(wait)
 
     def _make_dict(self):
         data = super(AutoConnectSSH, self)._make_dict()
         data.update({
-            'wait': self._wait
+            'wait': self.wait_timeout
         })
         return data
 
     def _connect(self):
         if not self.is_connected:
-            self._get_client()
-            if self._wait:
-                self.wait()
+            interval = 1
+            timeout = self.wait_timeout
+
+            end_time = time.time() + timeout
+            while True:
+                try:
+                    return self._get_client()
+                except (socket.error, exceptions.SSHError) as e:
+                    self.log.debug("Ssh is still unavailable: %r", e)
+                    time.sleep(interval)
+                if time.time() > end_time:
+                    raise exceptions.SSHTimeout(
+                        error_msg='Timeout waiting for "%s"' % self.host)
 
     def drop_connection(self):
         """ Don't close anything, just force creation of a new client """
         self._client = False
 
-    def execute(self, cmd, stdin=None, timeout=3600):
+    def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
         self._connect()
-        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
+                                                   raise_on_error)
 
     def run(self, cmd, stdin=None, stdout=None, stderr=None,
             raise_on_error=True, timeout=3600,