+
+ def _put_file_sftp(self, localpath, remotepath, mode=None):
+ client = self._get_client()
+
+ with client.open_sftp() as sftp:
+ sftp.put(localpath, remotepath)
+ if mode is None:
+ mode = 0o777 & os.stat(localpath).st_mode
+ sftp.chmod(remotepath, mode)
+
+ TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
+
+ def _put_file_shell(self, localpath, remotepath, mode=None):
+ # quote to stop wordpslit
+ tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
+ if not tilde:
+ tilde = ''
+ cmd = ['cat > %s"%s"' % (tilde, remotepath)]
+ if mode is not None:
+ # use -- so no options
+ cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
+
+ with open(localpath, "rb") as localfile:
+ # only chmod on successful cat
+ self.run("&& ".join(cmd), stdin=localfile)
+
+ def put_file(self, localpath, remotepath, mode=None):
+ """Copy specified local file to the server.
+
+ :param localpath: Local filename.
+ :param remotepath: Remote filename.
+ :param mode: Permissions to set after upload
+ """
+ try:
+ self._put_file_sftp(localpath, remotepath, mode=mode)
+ except (paramiko.SSHException, socket.error):
+ self._put_file_shell(localpath, remotepath, mode=mode)
+
+ def provision_tool(self, tool_path, tool_file=None):
+ return provision_tool(self, tool_path, tool_file)
+
+ def put_file_obj(self, file_obj, remotepath, mode=None):
+ client = self._get_client()
+
+ with client.open_sftp() as sftp:
+ sftp.putfo(file_obj, remotepath)
+ if mode is not None:
+ sftp.chmod(remotepath, mode)
+
+ def get_file_obj(self, remotepath, file_obj):
+ client = self._get_client()
+
+ with client.open_sftp() as sftp:
+ sftp.getfo(remotepath, file_obj)
+
+
+class AutoConnectSSH(SSH):
+
+ # always wait or we will get OpenStack SSH errors
+ def __init__(self, user, host, port=None, pkey=None,
+ key_filename=None, password=None, name=None, wait=True):
+ super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
+ self._wait = wait
+
+ def _make_dict(self):
+ data = super(AutoConnectSSH, self)._make_dict()
+ data.update({
+ 'wait': self._wait
+ })
+ return data
+
+ def _connect(self):
+ if not self.is_connected:
+ self._get_client()
+ if self._wait:
+ self.wait()
+
+ def drop_connection(self):
+ """ Don't close anything, just force creation of a new client """
+ self._client = False
+
+ def execute(self, cmd, stdin=None, timeout=3600):
+ self._connect()
+ return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+
+ def run(self, cmd, stdin=None, stdout=None, stderr=None,
+ raise_on_error=True, timeout=3600,
+ keep_stdin_open=False, pty=False):
+ self._connect()
+ return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
+ timeout, keep_stdin_open, pty)
+
+ def put(self, files, remote_path=b'.', recursive=False):
+ self._connect()
+ return super(AutoConnectSSH, self).put(files, remote_path, recursive)
+
+ def put_file(self, local_path, remote_path, mode=None):
+ self._connect()
+ return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
+
+ def put_file_obj(self, file_obj, remote_path, mode=None):
+ self._connect()
+ return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
+
+ def get_file_obj(self, remote_path, file_obj):
+ self._connect()
+ return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
+
+ def provision_tool(self, tool_path, tool_file=None):
+ self._connect()
+ return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
+
+ @staticmethod
+ def get_class():
+ # must return static class name, anything else refers to the calling class
+ # i.e. the subclass, not the superclass
+ return AutoConnectSSH