1 # Copyright 2013: Mirantis Inc.
4 # Licensed under the Apache License, Version 2.0 (the "License"); you may
5 # not use this file except in compliance with the License. You may obtain
6 # a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 # License for the specific language governing permissions and limitations
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
18 """High level ssh library.
22 Execute command and get output:
24 ssh = sshclient.SSH("root", "example.com", port=33)
25 status, stdout, stderr = ssh.execute("ps ax")
27 raise Exception("Command failed with non-zero status.")
28 print(stdout.splitlines())
30 Execute command with huge output:
32 class PseudoFile(io.RawIOBase):
37 ssh = SSH("root", "example.com")
38 with PseudoFile() as p:
39 ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
41 Execute local script on remote side:
43 ssh = sshclient.SSH("user", "example.com")
45 with open("~/myscript.sh", "r") as stdin_file:
46 status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
51 ssh = SSH("user", "example.com")
52 # use rb for binary files
53 with open("/store/file.gz", "rb") as stdin_file:
54 ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
58 eventlet.monkey_patch(select=True, time=True)
60 eventlet.monkey_patch()
62 sshclient = eventlet.import_patched("yardstick.ssh")
74 from chainmap import ChainMap
75 from oslo_utils import encodeutils
76 from scp import SCPClient
79 from yardstick.common import exceptions
80 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
81 from yardstick.network_services.utils import provision_tool
# Module-level logger; SSH.__init__ may switch an instance to a child logger
# named after the connection.
83 LOG = logging.getLogger(__name__)
# Serialize a paramiko RSA/DSS key object to text via write_private_key().
# NOTE(review): objects that are not RSAKey/DSSKey are presumably returned
# unchanged, and `k` is presumably an in-memory text buffer whose contents
# are returned -- confirm against the full function body.
85 def convert_key_to_str(key):
86 if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
89 key.write_private_key(k)
94 """Represent ssh connection."""
# Default ssh port; also the fallback passed to try_int() in __init__ and the
# 'port' default in get_arg_key_map().
96 SSH_PORT = paramiko.config.SSH_PORT
# Initial value of self.wait_timeout, which wait() uses as its retry deadline
# in seconds (presumably when no explicit timeout is passed).
97 DEFAULT_WAIT_TIMEOUT = 120
def gen_keys(key_filename, bit_count=2048):
    """Generate a fresh RSA key pair and store it on disk.

    The private key is written to ``key_filename``; the public key is written
    as a single "<key type> <base64 blob>" line to ``key_filename + ".pub"``.

    :param key_filename: path of the private key file to create
    :param bit_count: RSA key size in bits
    """
    private_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
    private_key.write_private_key_file(key_filename)
    print("Writing %s ..." % key_filename)
    public_line = ' '.join((private_key.get_name(), private_key.get_base64()))
    with open(key_filename + ".pub", "w") as pubkey_file:
        pubkey_file.write(public_line + '\n')
112 # must return static class name, anything else refers to the calling class
113 # i.e. the subclass, not the superclass
# Map each constructor argument to (key looked up in a node mapping,
# default value); consumed by args_from_node() via make_dict_from_map().
# NOTE(review): NON_NONE_DEFAULT presumably marks arguments without a usable
# default (user, host) -- confirm in yardstick.common.utils.
117 def get_arg_key_map(cls):
119 'user': ('user', NON_NONE_DEFAULT),
# the node stores the host address under 'ip'
120 'host': ('ip', NON_NONE_DEFAULT),
121 'port': ('ssh_port', cls.SSH_PORT),
122 'pkey': ('pkey', None),
123 'key_filename': ('key_filename', None),
124 'password': ('password', None),
125 'name': ('name', None),
128 def __init__(self, user, host, port=None, pkey=None,
129 key_filename=None, password=None, name=None):
130 """Initialize SSH client.
132 :param user: ssh username
133 :param host: hostname or ip address of remote ssh server
134 :param port: remote ssh port
135 :param pkey: RSA or DSS private key string or file object
136 :param key_filename: private key filename
137 :param password: password
:param name: optional connection name; when set, logs go to a child logger named after it
141 self.log = logging.getLogger(__name__ + '.' + self.name)
# fallback: logger named after this module only
143 self.log = logging.getLogger(__name__)
# retry deadline (seconds) used by wait()
145 self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
148 # everybody wants to debug this in the caller, do it here instead
149 self.log.debug("user:%s host:%s", user, host)
151 # we may get text port from YAML, convert to int
152 self.port = try_int(port, self.SSH_PORT)
# key text / file objects are parsed into a paramiko key object up front
153 self.pkey = self._get_pkey(pkey) if pkey else None
154 self.password = password
155 self.key_filename = key_filename
157 # paramiko loglevel debug will output ssh protocol debug
158 # we don't ever really want that unless we are debugging paramiko
# NOTE: this sets the process-wide "paramiko" logger level every time an SSH
# object is created; PARAMIKO_DEBUG=true (case-insensitive) selects DEBUG.
160 if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
161 logging.getLogger("paramiko").setLevel(logging.DEBUG)
163 logging.getLogger("paramiko").setLevel(logging.WARN)
# Build constructor kwargs for cls from a node mapping. Lookups search
# `overrides` first, then `node`, then `defaults` (ChainMap order); the result
# is presumably keyed by constructor argument names per get_arg_key_map().
166 def args_from_node(cls, node, overrides=None, defaults=None):
167 if overrides is None:
172 params = ChainMap(overrides, node, defaults)
173 return make_dict_from_map(params, cls.get_arg_key_map())
def from_node(cls, node, overrides=None, defaults=None):
    """Create an instance of ``cls`` described by ``node``.

    Constructor arguments are resolved by ``cls.args_from_node`` from
    ``overrides``, then ``node``, then ``defaults``.

    :param node: mapping describing the remote host
    :param overrides: optional mapping whose values take precedence
    :param defaults: optional mapping of fallback values
    :returns: a new ``cls`` instance
    """
    ctor_kwargs = cls.args_from_node(node, overrides, defaults)
    return cls(**ctor_kwargs)
# Parse a private key given as text or a file-like object, trying RSA first
# and then DSS; raises SSHError('Invalid pkey: ...') with `errors` (presumably
# the exceptions collected from each attempt) if neither class accepts it.
179 def _get_pkey(self, key):
180 if isinstance(key, six.string_types):
# wrap key text so paramiko can read it like a file
181 key = six.moves.StringIO(key)
# NOTE(review): the same file object is reused for every key class; if a
# failed RSA attempt leaves it at EOF, the DSS attempt cannot see the key.
# Consider key.seek(0) before each attempt -- verify with a DSS key.
183 for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
185 return key_class.from_private_key(key)
186 except paramiko.SSHException as e:
188 raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
def is_connected(self):
    """Report whether a client object is currently held.

    :returns: ``True`` when ``self._client`` is truthy, otherwise ``False``.
    """
    client = self._client
    return True if client else False
# Return a paramiko SSHClient for this host. When is_connected is true the
# existing self._client is presumably returned; otherwise a new client is
# created and connected. Any connection failure is re-raised as SSHError
# carrying the original exception's type and repr.
194 def _get_client(self):
195 if self.is_connected:
198 self._client = paramiko.SSHClient()
# NOTE(review): AutoAddPolicy accepts unknown host keys without verification,
# so a first connection is open to man-in-the-middle attacks.
199 self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# only the explicitly configured credentials are used: ssh-agent keys and
# key files discovered under ~/.ssh are disabled
200 self._client.connect(self.host, username=self.user,
201 port=self.port, pkey=self.pkey,
202 key_filename=self.key_filename,
203 password=self.password,
204 allow_agent=False, look_for_keys=False,
207 except Exception as e:
208 message = ("Exception %(exception_type)s was raised "
209 "during connect. Exception value is: %(exception)r" %
210 {"exception": e, "exception_type": type(e)})
212 raise exceptions.SSHError(error_msg=message)
# Constructor keyword arguments describing this connection, so that an
# equivalent instance can be built (see the get_class()(...) call below).
214 def _make_dict(self):
220 'key_filename': self.key_filename,
221 'password': self.password,
# Build a new, independent instance of get_class() with the same arguments.
226 return self.get_class()(**self._make_dict())
233 def run(self, cmd, stdin=None, stdout=None, stderr=None,
234 raise_on_error=True, timeout=3600,
235 keep_stdin_open=False, pty=False):
236 """Execute specified command on the server.
238 :param cmd: Command to be executed.
240 :param stdin: Open file or string to pass to stdin.
241 :param stdout: Open file to connect to stdout.
242 :param stderr: Open file to connect to stderr.
243 :param raise_on_error: If False then exit code will be returned. If True
244 then exception will be raised if non-zero code.
245 :param timeout: Timeout in seconds for command execution.
246 Default 1 hour. No timeout if set to 0.
247 :param keep_stdin_open: don't close stdin on empty reads
248 :type keep_stdin_open: bool
249 :param pty: Request a pseudo terminal for this connection.
250 This allows passing control characters.
:returns: exit status of the command.
:raises SSHError: on socket errors, or on a non-zero exit status when
    raise_on_error is True.
:raises SSHTimeout: when the command runs longer than ``timeout`` seconds.
255 client = self._get_client()
# strings are wrapped so _run() can read stdin like a file
257 if isinstance(stdin, six.string_types):
258 stdin = six.moves.StringIO(stdin)
260 return self._run(client, cmd, stdin=stdin, stdout=stdout,
261 stderr=stderr, raise_on_error=raise_on_error,
263 keep_stdin_open=keep_stdin_open, pty=pty)
# Core of run(): start `cmd` on a new session of client's transport, then
# pump data. Remote stdout/stderr chunks are decoded and written to the given
# file objects, and stdin is read in chunks of up to 4096 and sent. Polling
# ends when the remote reports an exit status. Raises SSHTimeout after
# `timeout` seconds (0 disables), and SSHError on socket errors or, when
# raise_on_error is set, on a non-zero exit status.
265 def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
266 raise_on_error=True, timeout=3600,
267 keep_stdin_open=False, pty=False):
269 transport = client.get_transport()
270 session = transport.open_session()
273 session.exec_command(cmd)
274 start_time = time.time()
276 # encode on transmit, decode on receive
277 data_to_send = encodeutils.safe_encode("", incoming='utf-8')
280 # If we have data to be sent to stdin then `select' should also
281 # check for stdin availability.
282 if stdin and not stdin.closed:
288 # Block until data can be read/write.
# 1 second select() timeout; only the exceptional-condition list ([2]) is kept
289 e = select.select([session], writes, [session], 1)[2]
291 if session.recv_ready():
# NOTE(review): each 4096-byte chunk is decoded on its own, so a multi-byte
# UTF-8 character split across two recv() calls can fail to decode (verify
# the errors= mode of safe_decode); an incremental decoder would avoid this.
292 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
293 self.log.debug("stdout: %r", data)
294 if stdout is not None:
298 if session.recv_stderr_ready():
299 stderr_data = encodeutils.safe_decode(
300 session.recv_stderr(4096), 'utf-8')
301 self.log.debug("stderr: %r", stderr_data)
302 if stderr is not None:
303 stderr.write(stderr_data)
306 if session.send_ready():
307 if stdin is not None and not stdin.closed:
309 stdin_txt = stdin.read(4096)
310 if stdin_txt is None:
312 data_to_send = encodeutils.safe_encode(
313 stdin_txt, incoming='utf-8')
315 # we may need to keep stdin open
316 if not keep_stdin_open:
318 session.shutdown_write()
# send() may accept only part of the buffer; the remainder is kept for the
# next pass
321 sent_bytes = session.send(data_to_send)
322 # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
323 data_to_send = data_to_send[sent_bytes:]
325 if session.exit_status_ready():
# i.e. elapsed time exceeds `timeout`; timeout=0 disables the check
328 if timeout and (time.time() - timeout) > start_time:
329 message = ('Timeout executing command %(cmd)s on host %(host)s'
330 % {"cmd": cmd, "host": self.host})
331 raise exceptions.SSHTimeout(error_msg=message)
# presumably reached when select() flagged the session in `e`
333 raise exceptions.SSHError(error_msg='Socket error')
335 exit_status = session.recv_exit_status()
336 if exit_status != 0 and raise_on_error:
337 fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
338 details = fmt % {"cmd": cmd, "status": exit_status}
# NOTE(review): stderr_data is assigned in the loop only when stderr output
# arrives; confirm it is initialised beforehand so this cannot raise NameError.
340 details += " Last stderr data: '%s'." % stderr_data
# NOTE(review): "PROX ERROR" is an odd prefix for the generic SSH client
# and may mislead log readers; consider a neutral message.
341 LOG.critical("PROX ERROR: %s", details)
342 raise exceptions.SSHError(error_msg=details)
345 def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
346 """Execute the specified command on the server.
348 :param cmd: (str) Command to be executed.
349 :param stdin: (str or file-like) Data or open file to be sent on process stdin.
350 :param timeout: (int) Timeout in seconds for execution of the command; 0 disables it.
351 :param raise_on_error: (bool) If True, then an SSHError will be raised
352 when the exit code is non-zero.
354 :returns: tuple (exit_status, stdout, stderr)
# collect remote output in memory so it can be returned as strings
356 stdout = six.moves.StringIO()
357 stderr = six.moves.StringIO()
359 exit_status = self.run(cmd, stderr=stderr,
360 stdout=stdout, stdin=stdin,
361 timeout=timeout, raise_on_error=raise_on_error)
# NOTE(review): read() starts at the buffers' current position, which is at
# the end after run() has written to them; they must be rewound with seek(0)
# before this point or empty strings are returned -- confirm.
364 return exit_status, stdout.read(), stderr.read()
366 def wait(self, timeout=None, interval=1):
367 """Wait until the host is available via ssh."""
# Retry a trivial command ("uname") until it succeeds or the deadline passes.
# `timeout` falls back to self.wait_timeout (seconds); `interval` is
# presumably the pause in seconds between attempts.
369 timeout = self.wait_timeout
371 end_time = time.time() + timeout
# success: return execute()'s (exit_status, stdout, stderr) tuple
374 return self.execute("uname")
# connection-level failures are treated as "host not up yet"
375 except (socket.error, exceptions.SSHError) as e:
376 self.log.debug("Ssh is still unavailable: %r", e)
378 if time.time() > end_time:
379 raise exceptions.SSHTimeout(
380 error_msg='Timeout waiting for "%s"' % self.host)
# Copy local `files` to `remote_path` on the server using SCP over this
# connection's transport; recursive=True also copies directories.
382 def put(self, files, remote_path=b'.', recursive=False):
383 client = self._get_client()
385 with SCPClient(client.get_transport()) as scp:
386 scp.put(files, remote_path, recursive)
# Copy `remote_path` from the server to `local_path` (default '/tmp/') using
# SCP; unlike put(), this defaults to recursive=True.
388 def get(self, remote_path, local_path='/tmp/', recursive=True):
389 client = self._get_client()
391 with SCPClient(client.get_transport()) as scp:
392 scp.get(remote_path, local_path, recursive)
# keep shell running in the background, e.g. screen
def send_command(self, command):
    """Start ``command`` on the server without waiting for it to finish.

    A pseudo terminal is requested, and the stdin/stdout/stderr objects that
    paramiko returns are discarded, so no output is collected.
    """
    ssh_client = self._get_client()
    ssh_client.exec_command(command, get_pty=True)
# Upload `localpath` to `remotepath` over SFTP and chmod the remote file.
399 def _put_file_sftp(self, localpath, remotepath, mode=None):
400 client = self._get_client()
402 with client.open_sftp() as sftp:
403 sftp.put(localpath, remotepath)
# use the local file's permission bits (masked to 0o777), presumably only
# when no explicit mode was passed
405 mode = 0o777 & os.stat(localpath).st_mode
406 sftp.chmod(remotepath, mode)
# Split a remote path into an optional leading "~/" or "~user/" prefix
# (group 1, None when absent) and the remainder (group 2), so the prefix can
# be left unquoted for the remote shell to expand.
408 TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
# Upload `localpath` by streaming it into `cat` on the remote shell, then
# chmod the result; put_file() uses this when the SFTP upload fails.
410 def _put_file_shell(self, localpath, remotepath, mode=None):
411 # quote to stop word splitting
412 tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
# NOTE(review): `tilde` is None when there is no ~ prefix; confirm it is
# normalised to '' before the formatting below. The quoted path is not
# escaped, so a `"`, `$` or backtick in remotepath is still interpreted by
# the remote shell.
415 cmd = ['cat > %s"%s"' % (tilde, remotepath)]
417 # use -- so the path cannot be taken as an option
418 cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
420 with open(localpath, "rb") as localfile:
421 # only chmod on successful cat
422 self.run("&& ".join(cmd), stdin=localfile)
424 def put_file(self, localpath, remotepath, mode=None):
425 """Copy specified local file to the server.
427 :param localpath: Local filename.
428 :param remotepath: Remote filename.
429 :param mode: Permissions to set after upload
SFTP is tried first; if it fails with paramiko.SSHException or socket.error
the file is streamed through the remote shell (``cat``) instead.
432 self._put_file_sftp(localpath, remotepath, mode=mode)
# fall back to the shell upload (presumably for servers without SFTP)
433 except (paramiko.SSHException, socket.error):
434 self._put_file_shell(localpath, remotepath, mode=mode)
def provision_tool(self, tool_path, tool_file=None):
    """Delegate to the ``provision_tool`` helper.

    The helper is the one imported from
    ``yardstick.network_services.utils``. This connection is passed as its
    first argument, and the helper's result is returned unchanged.
    """
    provisioned = provision_tool(self, tool_path, tool_file)
    return provisioned
# Upload the contents of an open file object to `remotepath` over SFTP, then
# apply `mode` to the remote file (presumably only when mode is given).
439 def put_file_obj(self, file_obj, remotepath, mode=None):
440 client = self._get_client()
442 with client.open_sftp() as sftp:
443 sftp.putfo(file_obj, remotepath)
445 sftp.chmod(remotepath, mode)
# Download `remotepath` over SFTP, writing its contents into the open,
# writable file object `file_obj`.
447 def get_file_obj(self, remotepath, file_obj):
448 client = self._get_client()
450 with client.open_sftp() as sftp:
451 sftp.getfo(remotepath, file_obj)
# NOTE(review): timeouts in this method are implemented with
# chan.settimeout() and socket.timeout; confirm whether the Interruptingcow
# warning in the docstring still applies.
453 def interactive_terminal_open(self, time_out=45):
454 """Open interactive terminal on a SSH channel.
456 :param time_out: Timeout in seconds.
457 :returns: SSH channel with opened terminal.
459 .. warning:: Interruptingcow is used here, and it uses
460 signal(SIGALRM) to let the operating system interrupt program
461 execution. This has the following limitations: Python signal
462 handlers only apply to the main thread, so you cannot use this
463 from other threads. You must not use this in a program that
464 uses SIGALRM itself (this includes certain profilers)
466 chan = self._get_client().get_transport().open_session()
# recv() raises socket.timeout after time_out seconds without data
469 chan.settimeout(int(time_out))
# merge stderr into the stream read by recv()
470 chan.set_combine_stderr(True)
# read until the output ends with a shell prompt; only prompts whose working
# directory is ~ (e.g. "user@host:~$ ", "[user@host ~]# ") are recognised
473 while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
475 chunk = chan.recv(10 * 1024 * 1024)
479 if chan.exit_status_ready():
480 self.log.error('Channel exit status ready')
482 except socket.timeout:
483 raise exceptions.SSHTimeout(error_msg='Socket timeout: %s' % buf)
486 def interactive_terminal_exec_command(self, chan, cmd, prompt):
487 """Execute command on interactive terminal.
489 interactive_terminal_open() method has to be called first!
491 :param chan: SSH channel with opened terminal.
492 :param cmd: Command to be executed.
493 :param prompt: Command prompt, sequence of characters used to
494 indicate readiness to accept commands.
495 :returns: Command output.
497 .. warning:: Interruptingcow is used here, and it uses
498 signal(SIGALRM) to let the operating system interrupt program
499 execution. This has the following limitations: Python signal
500 handlers only apply to the main thread, so you cannot use this
501 from other threads. You must not use this in a program that
502 uses SIGALRM itself (this includes certain profilers)
504 chan.sendall('{c}\n'.format(c=cmd))
# read until the output ends with `prompt`
506 while not buf.endswith(prompt):
508 chunk = chan.recv(10 * 1024 * 1024)
512 if chan.exit_status_ready():
513 self.log.error('Channel exit status ready')
515 except socket.timeout:
516 message = ("Socket timeout during execution of command: "
517 "%(cmd)s\nBuffer content:\n%(buf)s" % {"cmd": cmd,
519 raise exceptions.SSHTimeout(error_msg=message)
# strip the echoed command line from the output
520 tmp = buf.replace(cmd.replace('\n', ''), '')
# NOTE(review): str.replace() returns a new string and this result is
# discarded, so nothing is removed here. `item` presumably iterates over
# `prompt`; if prompt is a plain string that yields single characters, and
# simply assigning the result back would strip those characters everywhere
# in the output -- normalise prompt to a tuple of strings before fixing.
522 tmp.replace(item, '')
# Takes only the channel (no self), so it is presumably a staticmethod.
526 def interactive_terminal_close(chan):
527 """Close interactive terminal SSH channel.
529 :param chan: SSH channel to be closed.
# SSH client that waits for its host. A retry loop below keeps calling
# _get_client() until self.wait_timeout seconds have passed. The wrapper
# methods presumably run that loop before delegating to SSH.
534 class AutoConnectSSH(SSH):
# adds a 'wait' argument (node key 'wait', default True) for from_node()
537 def get_arg_key_map(cls):
538 arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
539 arg_key_map['wait'] = ('wait', True)
542 # always wait or we will get OpenStack SSH errors
# `wait`: True or any falsy value keeps DEFAULT_WAIT_TIMEOUT (so wait=False
# or 0 does not disable waiting); any other value is converted with int() and
# used as the timeout in seconds.
543 def __init__(self, user, host, port=None, pkey=None,
544 key_filename=None, password=None, name=None, wait=True):
545 super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
546 if wait and wait is not True:
547 self.wait_timeout = int(wait)
# include the wait timeout so copies keep the same retry deadline
549 def _make_dict(self):
550 data = super(AutoConnectSSH, self)._make_dict()
552 'wait': self.wait_timeout
# connection loop: retry _get_client() on socket/SSH errors until the
# deadline, then raise SSHTimeout
557 if not self.is_connected:
559 timeout = self.wait_timeout
561 end_time = time.time() + timeout
564 return self._get_client()
565 except (socket.error, exceptions.SSHError) as e:
566 self.log.debug("Ssh is still unavailable: %r", e)
568 if time.time() > end_time:
569 raise exceptions.SSHTimeout(
570 error_msg='Timeout waiting for "%s"' % self.host)
# NOTE(review): the previous client is not closed, so its transport may stay
# open after the next connection is made.
572 def drop_connection(self):
573 """ Don't close anything, just force creation of a new client """
# The methods below re-declare SSH methods with the same signatures and
# delegate to them via super().
576 def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
578 return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
581 def run(self, cmd, stdin=None, stdout=None, stderr=None,
582 raise_on_error=True, timeout=3600,
583 keep_stdin_open=False, pty=False):
585 return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
586 timeout, keep_stdin_open, pty)
588 def put(self, files, remote_path=b'.', recursive=False):
590 return super(AutoConnectSSH, self).put(files, remote_path, recursive)
592 def put_file(self, local_path, remote_path, mode=None):
594 return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
596 def put_file_obj(self, file_obj, remote_path, mode=None):
598 return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
600 def get_file_obj(self, remote_path, file_obj):
602 return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
604 def provision_tool(self, tool_path, tool_file=None):
606 return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
610 # must return static class name, anything else refers to the calling class
611 # i.e. the subclass, not the superclass
612 return AutoConnectSSH