+
+ def provision_tool(self, tool_path, tool_file=None):
+ return provision_tool(self, tool_path, tool_file)
+
+ def put_file_obj(self, file_obj, remotepath, mode=None):
+ client = self._get_client()
+
+ with client.open_sftp() as sftp:
+ sftp.putfo(file_obj, remotepath)
+ if mode is not None:
+ sftp.chmod(remotepath, mode)
+
+ def get_file_obj(self, remotepath, file_obj):
+ client = self._get_client()
+
+ with client.open_sftp() as sftp:
+ sftp.getfo(remotepath, file_obj)
+
+
+class AutoConnectSSH(SSH):
+
+ @classmethod
+ def get_arg_key_map(cls):
+ arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
+ arg_key_map['wait'] = ('wait', True)
+ return arg_key_map
+
+ # always wait or we will get OpenStack SSH errors
+ def __init__(self, user, host, port=None, pkey=None,
+ key_filename=None, password=None, name=None, wait=True):
+ super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
+ if wait and wait is not True:
+ self.wait_timeout = int(wait)
+
+ def _make_dict(self):
+ data = super(AutoConnectSSH, self)._make_dict()
+ data.update({
+ 'wait': self.wait_timeout
+ })
+ return data
+
+ def _connect(self):
+ if not self.is_connected:
+ interval = 1
+ timeout = self.wait_timeout
+
+ end_time = time.time() + timeout
+ while True:
+ try:
+ return self._get_client()
+ except (socket.error, exceptions.SSHError) as e:
+ self.log.debug("Ssh is still unavailable: %r", e)
+ time.sleep(interval)
+ if time.time() > end_time:
+ raise exceptions.SSHTimeout(
+ error_msg='Timeout waiting for "%s"' % self.host)
+
+ def drop_connection(self):
+ """ Don't close anything, just force creation of a new client """
+ self._client = False
+
+ def execute(self, cmd, stdin=None, timeout=3600):
+ self._connect()
+ return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+
+ def run(self, cmd, stdin=None, stdout=None, stderr=None,
+ raise_on_error=True, timeout=3600,
+ keep_stdin_open=False, pty=False):
+ self._connect()
+ return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
+ timeout, keep_stdin_open, pty)
+
+ def put(self, files, remote_path=b'.', recursive=False):
+ self._connect()
+ return super(AutoConnectSSH, self).put(files, remote_path, recursive)
+
+ def put_file(self, local_path, remote_path, mode=None):
+ self._connect()
+ return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
+
+ def put_file_obj(self, file_obj, remote_path, mode=None):
+ self._connect()
+ return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
+
+ def get_file_obj(self, remote_path, file_obj):
+ self._connect()
+ return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
+
+ def provision_tool(self, tool_path, tool_file=None):
+ self._connect()
+ return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
+
+ @staticmethod
+ def get_class():
+ # must return static class name, anything else refers to the calling class
+ # i.e. the subclass, not the superclass
+ return AutoConnectSSH