Add support for Python 3
[yardstick.git] / yardstick / ssh.py
index d287b4d..1cad8ee 100644 (file)
@@ -25,28 +25,33 @@ Execute command and get output:
     status, stdout, stderr = ssh.execute("ps ax")
     if status:
         raise Exception("Command failed with non-zero status.")
-    print stdout.splitlines()
+    print(stdout.splitlines())
 
 Execute command with huge output:
 
-    class PseudoFile(object):
+    class PseudoFile(io.RawIOBase):
         def write(chunk):
             if "error" in chunk:
                 email_admin(chunk)
 
-    ssh = sshclient.SSH("root", "example.com")
-    ssh.run("tail -f /var/log/syslog", stdout=PseudoFile(), timeout=False)
+    ssh = SSH("root", "example.com")
+    with PseudoFile() as p:
+        ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
 
 Execute local script on remote side:
 
     ssh = sshclient.SSH("user", "example.com")
-    status, out, err = ssh.execute("/bin/sh -s arg1 arg2",
-                                   stdin=open("~/myscript.sh", "r"))
+
+    with open("~/myscript.sh", "r") as stdin_file:
+        status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
+                                       stdin=stdin_file)
 
 Upload file:
 
-    ssh = sshclient.SSH("user", "example.com")
-    ssh.run("cat > ~/upload/file.gz", stdin=open("/store/file.gz", "rb"))
+    ssh = SSH("user", "example.com")
+    # use rb for binary files
+    with open("/store/file.gz", "rb") as stdin_file:
+        ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
 
 Eventlet:
 
@@ -54,16 +59,19 @@ Eventlet:
     or
     eventlet.monkey_patch()
     or
-    sshclient = eventlet.import_patched("opentstack.common.sshclient")
+    sshclient = eventlet.import_patched("yardstick.ssh")
 
 """
+from __future__ import absolute_import
 import os
 import select
 import socket
 import time
+import re
 
 import logging
 import paramiko
+from oslo_utils import encodeutils
 from scp import SCPClient
 import six
 
@@ -151,10 +159,12 @@ class SSH(object):
         self._client = False
 
     def run(self, cmd, stdin=None, stdout=None, stderr=None,
-            raise_on_error=True, timeout=3600):
+            raise_on_error=True, timeout=3600,
+            keep_stdin_open=False, pty=False):
         """Execute specified command on the server.
 
         :param cmd:             Command to be executed.
+        :type cmd:              str
         :param stdin:           Open file or string to pass to stdin.
         :param stdout:          Open file to connect to stdout.
         :param stderr:          Open file to connect to stderr.
@@ -162,6 +172,12 @@ class SSH(object):
                                 then exception will be raized if non-zero code.
         :param timeout:         Timeout in seconds for command execution.
                                 Default 1 hour. No timeout if set to 0.
+        :param keep_stdin_open: don't close stdin on empty reads
+        :type keep_stdin_open:  bool
+        :param pty:             Request a pseudo terminal for this connection.
+                                This allows passing control characters.
+                                Default False.
+        :type pty:              bool
         """
 
         client = self._get_client()
@@ -171,17 +187,22 @@ class SSH(object):
 
         return self._run(client, cmd, stdin=stdin, stdout=stdout,
                          stderr=stderr, raise_on_error=raise_on_error,
-                         timeout=timeout)
+                         timeout=timeout,
+                         keep_stdin_open=keep_stdin_open, pty=pty)
 
     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
-             raise_on_error=True, timeout=3600):
+             raise_on_error=True, timeout=3600,
+             keep_stdin_open=False, pty=False):
 
         transport = client.get_transport()
         session = transport.open_session()
+        if pty:
+            session.get_pty()
         session.exec_command(cmd)
         start_time = time.time()
 
-        data_to_send = ""
+        # encode on transmit, decode on receive
+        data_to_send = encodeutils.safe_encode("")
         stderr_data = None
 
         # If we have data to be sent to stdin then `select' should also
@@ -196,15 +217,16 @@ class SSH(object):
             r, w, e = select.select([session], writes, [session], 1)
 
             if session.recv_ready():
-                data = session.recv(4096)
-                self.log.debug("stdout: %r" % data)
+                data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
+                self.log.debug("stdout: %r", data)
                 if stdout is not None:
                     stdout.write(data)
                 continue
 
             if session.recv_stderr_ready():
-                stderr_data = session.recv_stderr(4096)
-                self.log.debug("stderr: %r" % stderr_data)
+                stderr_data = encodeutils.safe_decode(
+                    session.recv_stderr(4096), 'utf-8')
+                self.log.debug("stderr: %r", stderr_data)
                 if stderr is not None:
                     stderr.write(stderr_data)
                 continue
@@ -212,15 +234,18 @@ class SSH(object):
             if session.send_ready():
                 if stdin is not None and not stdin.closed:
                     if not data_to_send:
-                        data_to_send = stdin.read(4096)
+                        data_to_send = encodeutils.safe_encode(
+                            stdin.read(4096), incoming='utf-8')
                         if not data_to_send:
-                            stdin.close()
-                            session.shutdown_write()
-                            writes = []
-                            continue
-                    sent_bytes = session.send(data_to_send)
-                    # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
-                    data_to_send = data_to_send[sent_bytes:]
+                            # we may need to keep stdin open
+                            if not keep_stdin_open:
+                                stdin.close()
+                                session.shutdown_write()
+                                writes = []
+                    if data_to_send:
+                        sent_bytes = session.send(data_to_send)
+                        # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
+                        data_to_send = data_to_send[sent_bytes:]
 
             if session.exit_status_ready():
                 break
@@ -233,7 +258,7 @@ class SSH(object):
                 raise SSHError("Socket error.")
 
         exit_status = session.recv_exit_status()
-        if 0 != exit_status and raise_on_error:
+        if exit_status != 0 and raise_on_error:
             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
             details = fmt % {"cmd": cmd, "status": exit_status}
             if stderr_data:
@@ -267,10 +292,10 @@ class SSH(object):
             try:
                 return self.execute("uname")
             except (socket.error, SSHError) as e:
-                self.log.debug("Ssh is still unavailable: %r" % e)
+                self.log.debug("Ssh is still unavailable: %r", e)
                 time.sleep(interval)
             if time.time() > (start_time + timeout):
-                raise SSHTimeout("Timeout waiting for '%s'" % self.host)
+                raise SSHTimeout("Timeout waiting for '%s'", self.host)
 
     def put(self, files, remote_path=b'.', recursive=False):
         client = self._get_client()
@@ -282,3 +307,40 @@ class SSH(object):
     def send_command(self, command):
         client = self._get_client()
         client.exec_command(command, get_pty=True)
+
+    def _put_file_sftp(self, localpath, remotepath, mode=None):
+        client = self._get_client()
+
+        with client.open_sftp() as sftp:
+            sftp.put(localpath, remotepath)
+            if mode is None:
+                mode = 0o777 & os.stat(localpath).st_mode
+            sftp.chmod(remotepath, mode)
+
+    TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
+
+    def _put_file_shell(self, localpath, remotepath, mode=None):
+        # quote to stop wordpslit
+        tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
+        if not tilde:
+            tilde = ''
+        cmd = ['cat > %s"%s"' % (tilde, remotepath)]
+        if mode is not None:
+            # use -- so no options
+            cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
+
+        with open(localpath, "rb") as localfile:
+            # only chmod on successful cat
+            self.run("&& ".join(cmd), stdin=localfile)
+
+    def put_file(self, localpath, remotepath, mode=None):
+        """Copy specified local file to the server.
+
+        :param localpath:   Local filename.
+        :param remotepath:  Remote filename.
+        :param mode:        Permissions to set after upload
+        """
+        try:
+            self._put_file_sftp(localpath, remotepath, mode=mode)
+        except (paramiko.SSHException, socket.error):
+            self._put_file_shell(localpath, remotepath, mode=mode)