X-Git-Url: https://gerrit.opnfv.org/gerrit/gitweb?a=blobdiff_plain;f=yardstick%2Fssh.py;h=a024cf64ae58ca35cf1506ec57530cf1c01d8f5c;hb=7f46ed48c51dde7b11502ebac482536048b70177;hp=46d53b7d240ca9887929f9d84d00bce2ac065814;hpb=ef6ef8ed8f81d950a2c3a1a6f95c1e83879a9310;p=yardstick.git diff --git a/yardstick/ssh.py b/yardstick/ssh.py index 46d53b7d2..a024cf64a 100644 --- a/yardstick/ssh.py +++ b/yardstick/ssh.py @@ -25,7 +25,7 @@ Execute command and get output: status, stdout, stderr = ssh.execute("ps ax") if status: raise Exception("Command failed with non-zero status.") - print stdout.splitlines() + print(stdout.splitlines()) Execute command with huge output: @@ -62,18 +62,23 @@ Eventlet: sshclient = eventlet.import_patched("yardstick.ssh") """ +from __future__ import absolute_import import os import select import socket import time +import re import logging + import paramiko +from chainmap import ChainMap +from oslo_utils import encodeutils from scp import SCPClient import six - -DEFAULT_PORT = 22 +from yardstick.common.utils import try_int +from yardstick.network_services.utils import provision_tool class SSHError(Exception): @@ -87,7 +92,26 @@ class SSHTimeout(SSHError): class SSH(object): """Represent ssh connection.""" - def __init__(self, user, host, port=DEFAULT_PORT, pkey=None, + SSH_PORT = paramiko.config.SSH_PORT + + @staticmethod + def gen_keys(key_filename, bit_count=2048): + rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None) + rsa_key.write_private_key_file(key_filename) + print("Writing %s ..." % key_filename) + with open('.'.join([key_filename, "pub"]), "w") as pubkey_file: + pubkey_file.write(rsa_key.get_name()) + pubkey_file.write(' ') + pubkey_file.write(rsa_key.get_base64()) + pubkey_file.write('\n') + + @staticmethod + def get_class(): + # must return static class name, anything else refers to the calling class + # i.e. the subclass, not the superclass + return SSH + + def __init__(self, user, host, port=None, pkey=None, key_filename=None, password=None, name=None): """Initialize SSH client. @@ -106,8 +130,11 @@ class SSH(object): self.user = user self.host = host + # everybody wants to debug this in the caller, do it here instead + self.log.debug("user:%s host:%s", user, host) + # we may get text port from YAML, convert to int - self.port = int(port) + self.port = try_int(port, self.SSH_PORT) self.pkey = self._get_pkey(pkey) if pkey else None self.password = password self.key_filename = key_filename @@ -120,6 +147,27 @@ class SSH(object): else: logging.getLogger("paramiko").setLevel(logging.WARN) + @classmethod + def args_from_node(cls, node, overrides=None, defaults=None): + if overrides is None: + overrides = {} + if defaults is None: + defaults = {} + params = ChainMap(overrides, node, defaults) + return { + 'user': params['user'], + 'host': params['ip'], + 'port': params.get('ssh_port', cls.SSH_PORT), + 'pkey': params.get('pkey'), + 'key_filename': params.get('key_filename'), + 'password': params.get('password'), + 'name': params.get('name'), + } + + @classmethod + def from_node(cls, node, overrides=None, defaults=None): + return cls(**cls.args_from_node(node, overrides, defaults)) + def _get_pkey(self, key): if isinstance(key, six.string_types): key = six.moves.StringIO(key) @@ -131,8 +179,12 @@ class SSH(object): errors.append(e) raise SSHError("Invalid pkey: %s" % (errors)) + @property + def is_connected(self): + return bool(self._client) + def _get_client(self): - if self._client: + if self.is_connected: return self._client try: self._client = paramiko.SSHClient() @@ -151,13 +203,28 @@ class SSH(object): raise SSHError(message % {"exception": e, "exception_type": type(e)}) + def _make_dict(self): + return { + 'user': self.user, + 'host': self.host, + 'port': self.port, + 'pkey': self.pkey, + 'key_filename': self.key_filename, + 'password': self.password, + 'name': self.name, + } + + def copy(self): + return self.get_class()(**self._make_dict()) + def close(self): - self._client.close() - self._client = False + if self._client: + self._client.close() + self._client = False def run(self, cmd, stdin=None, stdout=None, stderr=None, raise_on_error=True, timeout=3600, - keep_stdin_open=False): + keep_stdin_open=False, pty=False): """Execute specified command on the server. :param cmd: Command to be executed. @@ -171,6 +238,10 @@ class SSH(object): Default 1 hour. No timeout if set to 0. :param keep_stdin_open: don't close stdin on empty reads :type keep_stdin_open: bool + :param pty: Request a pseudo terminal for this connection. + This allows passing control characters. + Default False. + :type pty: bool """ client = self._get_client() @@ -181,18 +252,21 @@ class SSH(object): return self._run(client, cmd, stdin=stdin, stdout=stdout, stderr=stderr, raise_on_error=raise_on_error, timeout=timeout, - keep_stdin_open=keep_stdin_open) + keep_stdin_open=keep_stdin_open, pty=pty) def _run(self, client, cmd, stdin=None, stdout=None, stderr=None, raise_on_error=True, timeout=3600, - keep_stdin_open=False): + keep_stdin_open=False, pty=False): transport = client.get_transport() session = transport.open_session() + if pty: + session.get_pty() session.exec_command(cmd) start_time = time.time() - data_to_send = "" + # encode on transmit, decode on receive + data_to_send = encodeutils.safe_encode("", incoming='utf-8') stderr_data = None # If we have data to be sent to stdin then `select' should also @@ -207,14 +281,15 @@ class SSH(object): r, w, e = select.select([session], writes, [session], 1) if session.recv_ready(): - data = session.recv(4096) + data = encodeutils.safe_decode(session.recv(4096), 'utf-8') self.log.debug("stdout: %r", data) if stdout is not None: stdout.write(data) continue if session.recv_stderr_ready(): - stderr_data = session.recv_stderr(4096) + stderr_data = encodeutils.safe_decode( + session.recv_stderr(4096), 'utf-8') self.log.debug("stderr: %r", stderr_data) if stderr is not None: stderr.write(stderr_data) @@ -223,7 +298,11 @@ class SSH(object): if session.send_ready(): if stdin is not None and not stdin.closed: if not data_to_send: - data_to_send = stdin.read(4096) + stdin_txt = stdin.read(4096) + if stdin_txt is None: + stdin_txt = '' + data_to_send = encodeutils.safe_encode( + stdin_txt, incoming='utf-8') if not data_to_send: # we may need to keep stdin open if not keep_stdin_open: @@ -246,7 +325,7 @@ class SSH(object): raise SSHError("Socket error.") exit_status = session.recv_exit_status() - if 0 != exit_status and raise_on_error: + if exit_status != 0 and raise_on_error: fmt = "Command '%(cmd)s' failed with exit_status %(status)d." details = fmt % {"cmd": cmd, "status": exit_status} if stderr_data: @@ -271,7 +350,7 @@ class SSH(object): timeout=timeout, raise_on_error=False) stdout.seek(0) stderr.seek(0) - return (exit_status, stdout.read(), stderr.read()) + return exit_status, stdout.read(), stderr.read() def wait(self, timeout=120, interval=1): """Wait for the host will be available via ssh.""" @@ -305,17 +384,21 @@ class SSH(object): mode = 0o777 & os.stat(localpath).st_mode sftp.chmod(remotepath, mode) + TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)") + def _put_file_shell(self, localpath, remotepath, mode=None): # quote to stop wordpslit - cmd = ['cat > "%s"' % remotepath] + tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups() + if not tilde: + tilde = '' + cmd = ['cat > %s"%s"' % (tilde, remotepath)] if mode is not None: # use -- so no options - cmd.append('chmod -- 0%o "%s"' % (mode, remotepath)) + cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath)) with open(localpath, "rb") as localfile: # only chmod on successful cat - cmd = "&& ".join(cmd) - self.run(cmd, stdin=localfile) + self.run("&& ".join(cmd), stdin=localfile) def put_file(self, localpath, remotepath, mode=None): """Copy specified local file to the server. @@ -324,8 +407,86 @@ class SSH(object): :param remotepath: Remote filename. :param mode: Permissions to set after upload """ - import socket try: self._put_file_sftp(localpath, remotepath, mode=mode) except (paramiko.SSHException, socket.error): self._put_file_shell(localpath, remotepath, mode=mode) + + def provision_tool(self, tool_path, tool_file=None): + return provision_tool(self, tool_path, tool_file) + + def put_file_obj(self, file_obj, remotepath, mode=None): + client = self._get_client() + + with client.open_sftp() as sftp: + sftp.putfo(file_obj, remotepath) + if mode is not None: + sftp.chmod(remotepath, mode) + + def get_file_obj(self, remotepath, file_obj): + client = self._get_client() + + with client.open_sftp() as sftp: + sftp.getfo(remotepath, file_obj) + + +class AutoConnectSSH(SSH): + + def __init__(self, user, host, port=None, pkey=None, + key_filename=None, password=None, name=None, wait=False): + super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name) + self._wait = wait + + def _make_dict(self): + data = super(AutoConnectSSH, self)._make_dict() + data.update({ + 'wait': self._wait + }) + return data + + def _connect(self): + if not self.is_connected: + self._get_client() + if self._wait: + self.wait() + + def drop_connection(self): + """ Don't close anything, just force creation of a new client """ + self._client = False + + def execute(self, cmd, stdin=None, timeout=3600): + self._connect() + return super(AutoConnectSSH, self).execute(cmd, stdin, timeout) + + def run(self, cmd, stdin=None, stdout=None, stderr=None, + raise_on_error=True, timeout=3600, + keep_stdin_open=False, pty=False): + self._connect() + return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error, + timeout, keep_stdin_open, pty) + + def put(self, files, remote_path=b'.', recursive=False): + self._connect() + return super(AutoConnectSSH, self).put(files, remote_path, recursive) + + def put_file(self, local_path, remote_path, mode=None): + self._connect() + return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode) + + def put_file_obj(self, file_obj, remote_path, mode=None): + self._connect() + return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode) + + def get_file_obj(self, remote_path, file_obj): + self._connect() + return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj) + + def provision_tool(self, tool_path, tool_file=None): + self._connect() + return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file) + + @staticmethod + def get_class(): + # must return static class name, anything else refers to the calling class + # i.e. the subclass, not the superclass + return AutoConnectSSH