Fix remote command execution in common.utils
[yardstick.git] / yardstick / ssh.py
index d7adc0d..69428f3 100644 (file)
@@ -62,15 +62,13 @@ Eventlet:
     sshclient = eventlet.import_patched("yardstick.ssh")
 
 """
-from __future__ import absolute_import
-import os
 import io
+import logging
+import os
+import re
 import select
 import socket
 import time
-import re
-
-import logging
 
 import paramiko
 from chainmap import ChainMap
@@ -78,6 +76,7 @@ from oslo_utils import encodeutils
 from scp import SCPClient
 import six
 
+from yardstick.common import exceptions
 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
 from yardstick.network_services.utils import provision_tool
 
@@ -90,12 +89,12 @@ def convert_key_to_str(key):
     return k.getvalue()
 
 
-class SSHError(Exception):
-    pass
-
-
-class SSHTimeout(SSHError):
-    pass
+class SSHError(Exception):
+    pass
+#
+#
+class SSHTimeout(SSHError):
+    pass
 
 
 class SSH(object):
@@ -193,7 +192,7 @@ class SSH(object):
                 return key_class.from_private_key(key)
             except paramiko.SSHException as e:
                 errors.append(e)
-        raise SSHError("Invalid pkey: %s" % errors)
+        raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
 
     @property
     def is_connected(self):
@@ -214,10 +213,10 @@ class SSH(object):
             return self._client
         except Exception as e:
             message = ("Exception %(exception_type)s was raised "
-                       "during connect. Exception value is: %(exception)r")
+                       "during connect. Exception value is: %(exception)r" %
+                       {"exception": e, "exception_type": type(e)})
             self._client = False
-            raise SSHError(message % {"exception": e,
-                                      "exception_type": type(e)})
+            raise exceptions.SSHError(error_msg=message)
 
     def _make_dict(self):
         return {
@@ -334,11 +333,11 @@ class SSH(object):
                 break
 
             if timeout and (time.time() - timeout) > start_time:
-                args = {"cmd": cmd, "host": self.host}
-                raise SSHTimeout("Timeout executing command "
-                                 "'%(cmd)s' on host %(host)s" % args)
+                message = ('Timeout executing command %(cmd)s on host %(host)s'
+                           % {"cmd": cmd, "host": self.host})
+                raise exceptions.SSHTimeout(error_msg=message)
             if e:
-                raise SSHError("Socket error.")
+                raise exceptions.SSHError(error_msg='Socket error')
 
         exit_status = session.recv_exit_status()
         if exit_status != 0 and raise_on_error:
@@ -346,15 +345,17 @@ class SSH(object):
             details = fmt % {"cmd": cmd, "status": exit_status}
             if stderr_data:
                 details += " Last stderr data: '%s'." % stderr_data
-            raise SSHError(details)
+            raise exceptions.SSHError(error_msg=details)
         return exit_status
 
-    def execute(self, cmd, stdin=None, timeout=3600):
+    def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
         """Execute the specified command on the server.
 
-        :param cmd:     Command to be executed.
-        :param stdin:   Open file to be sent on process stdin.
-        :param timeout: Timeout for execution of the command.
+        :param cmd: (str)             Command to be executed.
+        :param stdin: (StringIO)      Open file to be sent on process stdin.
+        :param timeout: (int)         Timeout for execution of the command.
+        :param raise_on_error: (bool) If True, then an SSHError will be raised
+                                      when non-zero exit code.
 
         :returns: tuple (exit_status, stdout, stderr)
         """
@@ -363,7 +364,7 @@ class SSH(object):
 
         exit_status = self.run(cmd, stderr=stderr,
                                stdout=stdout, stdin=stdin,
-                               timeout=timeout, raise_on_error=False)
+                               timeout=timeout, raise_on_error=raise_on_error)
         stdout.seek(0)
         stderr.seek(0)
         return exit_status, stdout.read(), stderr.read()
@@ -377,11 +378,12 @@ class SSH(object):
         while True:
             try:
                 return self.execute("uname")
-            except (socket.error, SSHError) as e:
+            except (socket.error, exceptions.SSHError) as e:
                 self.log.debug("Ssh is still unavailable: %r", e)
                 time.sleep(interval)
             if time.time() > end_time:
-                raise SSHTimeout("Timeout waiting for '%s'" % self.host)
+                raise exceptions.SSHTimeout(
+                    error_msg='Timeout waiting for "%s"' % self.host)
 
     def put(self, files, remote_path=b'.', recursive=False):
         client = self._get_client()
@@ -486,19 +488,21 @@ class AutoConnectSSH(SSH):
             while True:
                 try:
                     return self._get_client()
-                except (socket.error, SSHError) as e:
+                except (socket.error, exceptions.SSHError) as e:
                     self.log.debug("Ssh is still unavailable: %r", e)
                     time.sleep(interval)
                 if time.time() > end_time:
-                    raise SSHTimeout("Timeout waiting for '%s'" % self.host)
+                    raise exceptions.SSHTimeout(
+                        error_msg='Timeout waiting for "%s"' % self.host)
 
     def drop_connection(self):
         """ Don't close anything, just force creation of a new client """
         self._client = False
 
-    def execute(self, cmd, stdin=None, timeout=3600):
+    def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=True):
         self._connect()
-        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
+                                                   raise_on_error)
 
     def run(self, cmd, stdin=None, stdout=None, stderr=None,
             raise_on_error=True, timeout=3600,