add yardstick iruya 9.0.0 release notes
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import io
66 import logging
67 import os
68 import re
69 import select
70 import socket
71 import time
72
73 import paramiko
74 from chainmap import ChainMap
75 from oslo_utils import encodeutils
76 from scp import SCPClient
77 import six
78
79 from yardstick.common import exceptions
80 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
81 from yardstick.network_services.utils import provision_tool
82
83 LOG = logging.getLogger(__name__)
84
85 def convert_key_to_str(key):
86     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
87         return key
88     k = io.StringIO()
89     key.write_private_key(k)
90     return k.getvalue()
91
92
93 class SSH(object):
94     """Represent ssh connection."""
95
96     SSH_PORT = paramiko.config.SSH_PORT
97     DEFAULT_WAIT_TIMEOUT = 120
98
99     @staticmethod
100     def gen_keys(key_filename, bit_count=2048):
101         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
102         rsa_key.write_private_key_file(key_filename)
103         print("Writing %s ..." % key_filename)
104         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
105             pubkey_file.write(rsa_key.get_name())
106             pubkey_file.write(' ')
107             pubkey_file.write(rsa_key.get_base64())
108             pubkey_file.write('\n')
109
110     @staticmethod
111     def get_class():
112         # must return static class name, anything else refers to the calling class
113         # i.e. the subclass, not the superclass
114         return SSH
115
116     @classmethod
117     def get_arg_key_map(cls):
118         return {
119             'user': ('user', NON_NONE_DEFAULT),
120             'host': ('ip', NON_NONE_DEFAULT),
121             'port': ('ssh_port', cls.SSH_PORT),
122             'pkey': ('pkey', None),
123             'key_filename': ('key_filename', None),
124             'password': ('password', None),
125             'name': ('name', None),
126         }
127
128     def __init__(self, user, host, port=None, pkey=None,
129                  key_filename=None, password=None, name=None):
130         """Initialize SSH client.
131
132         :param user: ssh username
133         :param host: hostname or ip address of remote ssh server
134         :param port: remote ssh port
135         :param pkey: RSA or DSS private key string or file object
136         :param key_filename: private key filename
137         :param password: password
138         """
139         self.name = name
140         if name:
141             self.log = logging.getLogger(__name__ + '.' + self.name)
142         else:
143             self.log = logging.getLogger(__name__)
144
145         self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
146         self.user = user
147         self.host = host
148         # everybody wants to debug this in the caller, do it here instead
149         self.log.debug("user:%s host:%s", user, host)
150
151         # we may get text port from YAML, convert to int
152         self.port = try_int(port, self.SSH_PORT)
153         self.pkey = self._get_pkey(pkey) if pkey else None
154         self.password = password
155         self.key_filename = key_filename
156         self._client = False
157         # paramiko loglevel debug will output ssh protocl debug
158         # we don't ever really want that unless we are debugging paramiko
159         # ssh issues
160         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
161             logging.getLogger("paramiko").setLevel(logging.DEBUG)
162         else:
163             logging.getLogger("paramiko").setLevel(logging.WARN)
164
165     @classmethod
166     def args_from_node(cls, node, overrides=None, defaults=None):
167         if overrides is None:
168             overrides = {}
169         if defaults is None:
170             defaults = {}
171
172         params = ChainMap(overrides, node, defaults)
173         return make_dict_from_map(params, cls.get_arg_key_map())
174
175     @classmethod
176     def from_node(cls, node, overrides=None, defaults=None):
177         return cls(**cls.args_from_node(node, overrides, defaults))
178
179     def _get_pkey(self, key):
180         if isinstance(key, six.string_types):
181             key = six.moves.StringIO(key)
182         errors = []
183         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
184             try:
185                 return key_class.from_private_key(key)
186             except paramiko.SSHException as e:
187                 errors.append(e)
188         raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
189
190     @property
191     def is_connected(self):
192         return bool(self._client)
193
194     def _get_client(self):
195         if self.is_connected:
196             return self._client
197         try:
198             self._client = paramiko.SSHClient()
199             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
200             self._client.connect(self.host, username=self.user,
201                                  port=self.port, pkey=self.pkey,
202                                  key_filename=self.key_filename,
203                                  password=self.password,
204                                  allow_agent=False, look_for_keys=False,
205                                  timeout=1)
206             return self._client
207         except Exception as e:
208             message = ("Exception %(exception_type)s was raised "
209                        "during connect. Exception value is: %(exception)r" %
210                        {"exception": e, "exception_type": type(e)})
211             self._client = False
212             raise exceptions.SSHError(error_msg=message)
213
214     def _make_dict(self):
215         return {
216             'user': self.user,
217             'host': self.host,
218             'port': self.port,
219             'pkey': self.pkey,
220             'key_filename': self.key_filename,
221             'password': self.password,
222             'name': self.name,
223         }
224
225     def copy(self):
226         return self.get_class()(**self._make_dict())
227
228     def close(self):
229         if self._client:
230             self._client.close()
231             self._client = False
232
233     def run(self, cmd, stdin=None, stdout=None, stderr=None,
234             raise_on_error=True, timeout=3600,
235             keep_stdin_open=False, pty=False):
236         """Execute specified command on the server.
237
238         :param cmd:             Command to be executed.
239         :type cmd:              str
240         :param stdin:           Open file or string to pass to stdin.
241         :param stdout:          Open file to connect to stdout.
242         :param stderr:          Open file to connect to stderr.
243         :param raise_on_error:  If False then exit code will be return. If True
244                                 then exception will be raized if non-zero code.
245         :param timeout:         Timeout in seconds for command execution.
246                                 Default 1 hour. No timeout if set to 0.
247         :param keep_stdin_open: don't close stdin on empty reads
248         :type keep_stdin_open:  bool
249         :param pty:             Request a pseudo terminal for this connection.
250                                 This allows passing control characters.
251                                 Default False.
252         :type pty:              bool
253         """
254
255         client = self._get_client()
256
257         if isinstance(stdin, six.string_types):
258             stdin = six.moves.StringIO(stdin)
259
260         return self._run(client, cmd, stdin=stdin, stdout=stdout,
261                          stderr=stderr, raise_on_error=raise_on_error,
262                          timeout=timeout,
263                          keep_stdin_open=keep_stdin_open, pty=pty)
264
265     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
266              raise_on_error=True, timeout=3600,
267              keep_stdin_open=False, pty=False):
268
269         transport = client.get_transport()
270         session = transport.open_session()
271         if pty:
272             session.get_pty()
273         session.exec_command(cmd)
274         start_time = time.time()
275
276         # encode on transmit, decode on receive
277         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
278         stderr_data = None
279
280         # If we have data to be sent to stdin then `select' should also
281         # check for stdin availability.
282         if stdin and not stdin.closed:
283             writes = [session]
284         else:
285             writes = []
286
287         while True:
288             # Block until data can be read/write.
289             e = select.select([session], writes, [session], 1)[2]
290
291             if session.recv_ready():
292                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
293                 self.log.debug("stdout: %r", data)
294                 if stdout is not None:
295                     stdout.write(data)
296                 continue
297
298             if session.recv_stderr_ready():
299                 stderr_data = encodeutils.safe_decode(
300                     session.recv_stderr(4096), 'utf-8')
301                 self.log.debug("stderr: %r", stderr_data)
302                 if stderr is not None:
303                     stderr.write(stderr_data)
304                 continue
305
306             if session.send_ready():
307                 if stdin is not None and not stdin.closed:
308                     if not data_to_send:
309                         stdin_txt = stdin.read(4096)
310                         if stdin_txt is None:
311                             stdin_txt = ''
312                         data_to_send = encodeutils.safe_encode(
313                             stdin_txt, incoming='utf-8')
314                         if not data_to_send:
315                             # we may need to keep stdin open
316                             if not keep_stdin_open:
317                                 stdin.close()
318                                 session.shutdown_write()
319                                 writes = []
320                     if data_to_send:
321                         sent_bytes = session.send(data_to_send)
322                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
323                         data_to_send = data_to_send[sent_bytes:]
324
325             if session.exit_status_ready():
326                 break
327
328             if timeout and (time.time() - timeout) > start_time:
329                 message = ('Timeout executing command %(cmd)s on host %(host)s'
330                            % {"cmd": cmd, "host": self.host})
331                 raise exceptions.SSHTimeout(error_msg=message)
332             if e:
333                 raise exceptions.SSHError(error_msg='Socket error')
334
335         exit_status = session.recv_exit_status()
336         if exit_status != 0 and raise_on_error:
337             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
338             details = fmt % {"cmd": cmd, "status": exit_status}
339             if stderr_data:
340                 details += " Last stderr data: '%s'." % stderr_data
341             LOG.critical("PROX ERROR: %s", details)
342             raise exceptions.SSHError(error_msg=details)
343         return exit_status
344
345     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
346         """Execute the specified command on the server.
347
348         :param cmd: (str)             Command to be executed.
349         :param stdin: (StringIO)      Open file to be sent on process stdin.
350         :param timeout: (int)         Timeout for execution of the command.
351         :param raise_on_error: (bool) If True, then an SSHError will be raised
352                                       when non-zero exit code.
353
354         :returns: tuple (exit_status, stdout, stderr)
355         """
356         stdout = six.moves.StringIO()
357         stderr = six.moves.StringIO()
358
359         exit_status = self.run(cmd, stderr=stderr,
360                                stdout=stdout, stdin=stdin,
361                                timeout=timeout, raise_on_error=raise_on_error)
362         stdout.seek(0)
363         stderr.seek(0)
364         return exit_status, stdout.read(), stderr.read()
365
366     def wait(self, timeout=None, interval=1):
367         """Wait for the host will be available via ssh."""
368         if timeout is None:
369             timeout = self.wait_timeout
370
371         end_time = time.time() + timeout
372         while True:
373             try:
374                 return self.execute("uname")
375             except (socket.error, exceptions.SSHError) as e:
376                 self.log.debug("Ssh is still unavailable: %r", e)
377                 time.sleep(interval)
378             if time.time() > end_time:
379                 raise exceptions.SSHTimeout(
380                     error_msg='Timeout waiting for "%s"' % self.host)
381
382     def put(self, files, remote_path=b'.', recursive=False):
383         client = self._get_client()
384
385         with SCPClient(client.get_transport()) as scp:
386             scp.put(files, remote_path, recursive)
387
388     def get(self, remote_path, local_path='/tmp/', recursive=True):
389         client = self._get_client()
390
391         with SCPClient(client.get_transport()) as scp:
392             scp.get(remote_path, local_path, recursive)
393
394     # keep shell running in the background, e.g. screen
395     def send_command(self, command):
396         client = self._get_client()
397         client.exec_command(command, get_pty=True)
398
399     def _put_file_sftp(self, localpath, remotepath, mode=None):
400         client = self._get_client()
401
402         with client.open_sftp() as sftp:
403             sftp.put(localpath, remotepath)
404             if mode is None:
405                 mode = 0o777 & os.stat(localpath).st_mode
406             sftp.chmod(remotepath, mode)
407
408     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
409
410     def _put_file_shell(self, localpath, remotepath, mode=None):
411         # quote to stop wordpslit
412         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
413         if not tilde:
414             tilde = ''
415         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
416         if mode is not None:
417             # use -- so no options
418             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
419
420         with open(localpath, "rb") as localfile:
421             # only chmod on successful cat
422             self.run("&& ".join(cmd), stdin=localfile)
423
424     def put_file(self, localpath, remotepath, mode=None):
425         """Copy specified local file to the server.
426
427         :param localpath:   Local filename.
428         :param remotepath:  Remote filename.
429         :param mode:        Permissions to set after upload
430         """
431         try:
432             self._put_file_sftp(localpath, remotepath, mode=mode)
433         except (paramiko.SSHException, socket.error):
434             self._put_file_shell(localpath, remotepath, mode=mode)
435
436     def provision_tool(self, tool_path, tool_file=None):
437         return provision_tool(self, tool_path, tool_file)
438
439     def put_file_obj(self, file_obj, remotepath, mode=None):
440         client = self._get_client()
441
442         with client.open_sftp() as sftp:
443             sftp.putfo(file_obj, remotepath)
444             if mode is not None:
445                 sftp.chmod(remotepath, mode)
446
447     def get_file_obj(self, remotepath, file_obj):
448         client = self._get_client()
449
450         with client.open_sftp() as sftp:
451             sftp.getfo(remotepath, file_obj)
452
453     def interactive_terminal_open(self, time_out=45):
454         """Open interactive terminal on a SSH channel.
455
456         :param time_out: Timeout in seconds.
457         :returns: SSH channel with opened terminal.
458
459         .. warning:: Interruptingcow is used here, and it uses
460            signal(SIGALRM) to let the operating system interrupt program
461            execution. This has the following limitations: Python signal
462            handlers only apply to the main thread, so you cannot use this
463            from other threads. You must not use this in a program that
464            uses SIGALRM itself (this includes certain profilers)
465         """
466         chan = self._get_client().get_transport().open_session()
467         chan.get_pty()
468         chan.invoke_shell()
469         chan.settimeout(int(time_out))
470         chan.set_combine_stderr(True)
471
472         buf = ''
473         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
474             try:
475                 chunk = chan.recv(10 * 1024 * 1024)
476                 if not chunk:
477                     break
478                 buf += chunk
479                 if chan.exit_status_ready():
480                     self.log.error('Channel exit status ready')
481                     break
482             except socket.timeout:
483                 raise exceptions.SSHTimeout(error_msg='Socket timeout: %s' % buf)
484         return chan
485
486     def interactive_terminal_exec_command(self, chan, cmd, prompt):
487         """Execute command on interactive terminal.
488
489         interactive_terminal_open() method has to be called first!
490
491         :param chan: SSH channel with opened terminal.
492         :param cmd: Command to be executed.
493         :param prompt: Command prompt, sequence of characters used to
494         indicate readiness to accept commands.
495         :returns: Command output.
496
497         .. warning:: Interruptingcow is used here, and it uses
498            signal(SIGALRM) to let the operating system interrupt program
499            execution. This has the following limitations: Python signal
500            handlers only apply to the main thread, so you cannot use this
501            from other threads. You must not use this in a program that
502            uses SIGALRM itself (this includes certain profilers)
503         """
504         chan.sendall('{c}\n'.format(c=cmd))
505         buf = ''
506         while not buf.endswith(prompt):
507             try:
508                 chunk = chan.recv(10 * 1024 * 1024)
509                 if not chunk:
510                     break
511                 buf += chunk
512                 if chan.exit_status_ready():
513                     self.log.error('Channel exit status ready')
514                     break
515             except socket.timeout:
516                 message = ("Socket timeout during execution of command: "
517                            "%(cmd)s\nBuffer content:\n%(buf)s" % {"cmd": cmd,
518                                                                   "buf": buf})
519                 raise exceptions.SSHTimeout(error_msg=message)
520         tmp = buf.replace(cmd.replace('\n', ''), '')
521         for item in prompt:
522             tmp.replace(item, '')
523         return tmp
524
525     @staticmethod
526     def interactive_terminal_close(chan):
527         """Close interactive terminal SSH channel.
528
529         :param: chan: SSH channel to be closed.
530         """
531         chan.close()
532
533
534 class AutoConnectSSH(SSH):
535
536     @classmethod
537     def get_arg_key_map(cls):
538         arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
539         arg_key_map['wait'] = ('wait', True)
540         return arg_key_map
541
542     # always wait or we will get OpenStack SSH errors
543     def __init__(self, user, host, port=None, pkey=None,
544                  key_filename=None, password=None, name=None, wait=True):
545         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
546         if wait and wait is not True:
547             self.wait_timeout = int(wait)
548
549     def _make_dict(self):
550         data = super(AutoConnectSSH, self)._make_dict()
551         data.update({
552             'wait': self.wait_timeout
553         })
554         return data
555
556     def _connect(self):
557         if not self.is_connected:
558             interval = 1
559             timeout = self.wait_timeout
560
561             end_time = time.time() + timeout
562             while True:
563                 try:
564                     return self._get_client()
565                 except (socket.error, exceptions.SSHError) as e:
566                     self.log.debug("Ssh is still unavailable: %r", e)
567                     time.sleep(interval)
568                 if time.time() > end_time:
569                     raise exceptions.SSHTimeout(
570                         error_msg='Timeout waiting for "%s"' % self.host)
571
572     def drop_connection(self):
573         """ Don't close anything, just force creation of a new client """
574         self._client = False
575
576     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
577         self._connect()
578         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
579                                                    raise_on_error)
580
581     def run(self, cmd, stdin=None, stdout=None, stderr=None,
582             raise_on_error=True, timeout=3600,
583             keep_stdin_open=False, pty=False):
584         self._connect()
585         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
586                                                timeout, keep_stdin_open, pty)
587
588     def put(self, files, remote_path=b'.', recursive=False):
589         self._connect()
590         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
591
592     def put_file(self, local_path, remote_path, mode=None):
593         self._connect()
594         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
595
596     def put_file_obj(self, file_obj, remote_path, mode=None):
597         self._connect()
598         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
599
600     def get_file_obj(self, remote_path, file_obj):
601         self._connect()
602         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
603
604     def provision_tool(self, tool_path, tool_file=None):
605         self._connect()
606         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
607
608     @staticmethod
609     def get_class():
610         # must return static class name, anything else refers to the calling class
611         # i.e. the subclass, not the superclass
612         return AutoConnectSSH