Merge "Concurrency testcases to be configured over cli"
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import io
66 import logging
67 import os
68 import re
69 import select
70 import socket
71 import time
72
73 import paramiko
74 from chainmap import ChainMap
75 from oslo_utils import encodeutils
76 from scp import SCPClient
77 import six
78
79 from yardstick.common import exceptions
80 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
81 from yardstick.network_services.utils import provision_tool
82
83
84 def convert_key_to_str(key):
85     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
86         return key
87     k = io.StringIO()
88     key.write_private_key(k)
89     return k.getvalue()
90
91
92 class SSH(object):
93     """Represent ssh connection."""
94
95     SSH_PORT = paramiko.config.SSH_PORT
96     DEFAULT_WAIT_TIMEOUT = 120
97
98     @staticmethod
99     def gen_keys(key_filename, bit_count=2048):
100         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
101         rsa_key.write_private_key_file(key_filename)
102         print("Writing %s ..." % key_filename)
103         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
104             pubkey_file.write(rsa_key.get_name())
105             pubkey_file.write(' ')
106             pubkey_file.write(rsa_key.get_base64())
107             pubkey_file.write('\n')
108
109     @staticmethod
110     def get_class():
111         # must return static class name, anything else refers to the calling class
112         # i.e. the subclass, not the superclass
113         return SSH
114
115     @classmethod
116     def get_arg_key_map(cls):
117         return {
118             'user': ('user', NON_NONE_DEFAULT),
119             'host': ('ip', NON_NONE_DEFAULT),
120             'port': ('ssh_port', cls.SSH_PORT),
121             'pkey': ('pkey', None),
122             'key_filename': ('key_filename', None),
123             'password': ('password', None),
124             'name': ('name', None),
125         }
126
127     def __init__(self, user, host, port=None, pkey=None,
128                  key_filename=None, password=None, name=None):
129         """Initialize SSH client.
130
131         :param user: ssh username
132         :param host: hostname or ip address of remote ssh server
133         :param port: remote ssh port
134         :param pkey: RSA or DSS private key string or file object
135         :param key_filename: private key filename
136         :param password: password
137         """
138         self.name = name
139         if name:
140             self.log = logging.getLogger(__name__ + '.' + self.name)
141         else:
142             self.log = logging.getLogger(__name__)
143
144         self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
145         self.user = user
146         self.host = host
147         # everybody wants to debug this in the caller, do it here instead
148         self.log.debug("user:%s host:%s", user, host)
149
150         # we may get text port from YAML, convert to int
151         self.port = try_int(port, self.SSH_PORT)
152         self.pkey = self._get_pkey(pkey) if pkey else None
153         self.password = password
154         self.key_filename = key_filename
155         self._client = False
156         # paramiko loglevel debug will output ssh protocl debug
157         # we don't ever really want that unless we are debugging paramiko
158         # ssh issues
159         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
160             logging.getLogger("paramiko").setLevel(logging.DEBUG)
161         else:
162             logging.getLogger("paramiko").setLevel(logging.WARN)
163
164     @classmethod
165     def args_from_node(cls, node, overrides=None, defaults=None):
166         if overrides is None:
167             overrides = {}
168         if defaults is None:
169             defaults = {}
170
171         params = ChainMap(overrides, node, defaults)
172         return make_dict_from_map(params, cls.get_arg_key_map())
173
174     @classmethod
175     def from_node(cls, node, overrides=None, defaults=None):
176         return cls(**cls.args_from_node(node, overrides, defaults))
177
178     def _get_pkey(self, key):
179         if isinstance(key, six.string_types):
180             key = six.moves.StringIO(key)
181         errors = []
182         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
183             try:
184                 return key_class.from_private_key(key)
185             except paramiko.SSHException as e:
186                 errors.append(e)
187         raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
188
189     @property
190     def is_connected(self):
191         return bool(self._client)
192
193     def _get_client(self):
194         if self.is_connected:
195             return self._client
196         try:
197             self._client = paramiko.SSHClient()
198             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
199             self._client.connect(self.host, username=self.user,
200                                  port=self.port, pkey=self.pkey,
201                                  key_filename=self.key_filename,
202                                  password=self.password,
203                                  allow_agent=False, look_for_keys=False,
204                                  timeout=1)
205             return self._client
206         except Exception as e:
207             message = ("Exception %(exception_type)s was raised "
208                        "during connect. Exception value is: %(exception)r" %
209                        {"exception": e, "exception_type": type(e)})
210             self._client = False
211             raise exceptions.SSHError(error_msg=message)
212
213     def _make_dict(self):
214         return {
215             'user': self.user,
216             'host': self.host,
217             'port': self.port,
218             'pkey': self.pkey,
219             'key_filename': self.key_filename,
220             'password': self.password,
221             'name': self.name,
222         }
223
224     def copy(self):
225         return self.get_class()(**self._make_dict())
226
227     def close(self):
228         if self._client:
229             self._client.close()
230             self._client = False
231
232     def run(self, cmd, stdin=None, stdout=None, stderr=None,
233             raise_on_error=True, timeout=3600,
234             keep_stdin_open=False, pty=False):
235         """Execute specified command on the server.
236
237         :param cmd:             Command to be executed.
238         :type cmd:              str
239         :param stdin:           Open file or string to pass to stdin.
240         :param stdout:          Open file to connect to stdout.
241         :param stderr:          Open file to connect to stderr.
242         :param raise_on_error:  If False then exit code will be return. If True
243                                 then exception will be raized if non-zero code.
244         :param timeout:         Timeout in seconds for command execution.
245                                 Default 1 hour. No timeout if set to 0.
246         :param keep_stdin_open: don't close stdin on empty reads
247         :type keep_stdin_open:  bool
248         :param pty:             Request a pseudo terminal for this connection.
249                                 This allows passing control characters.
250                                 Default False.
251         :type pty:              bool
252         """
253
254         client = self._get_client()
255
256         if isinstance(stdin, six.string_types):
257             stdin = six.moves.StringIO(stdin)
258
259         return self._run(client, cmd, stdin=stdin, stdout=stdout,
260                          stderr=stderr, raise_on_error=raise_on_error,
261                          timeout=timeout,
262                          keep_stdin_open=keep_stdin_open, pty=pty)
263
264     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
265              raise_on_error=True, timeout=3600,
266              keep_stdin_open=False, pty=False):
267
268         transport = client.get_transport()
269         session = transport.open_session()
270         if pty:
271             session.get_pty()
272         session.exec_command(cmd)
273         start_time = time.time()
274
275         # encode on transmit, decode on receive
276         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
277         stderr_data = None
278
279         # If we have data to be sent to stdin then `select' should also
280         # check for stdin availability.
281         if stdin and not stdin.closed:
282             writes = [session]
283         else:
284             writes = []
285
286         while True:
287             # Block until data can be read/write.
288             e = select.select([session], writes, [session], 1)[2]
289
290             if session.recv_ready():
291                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
292                 self.log.debug("stdout: %r", data)
293                 if stdout is not None:
294                     stdout.write(data)
295                 continue
296
297             if session.recv_stderr_ready():
298                 stderr_data = encodeutils.safe_decode(
299                     session.recv_stderr(4096), 'utf-8')
300                 self.log.debug("stderr: %r", stderr_data)
301                 if stderr is not None:
302                     stderr.write(stderr_data)
303                 continue
304
305             if session.send_ready():
306                 if stdin is not None and not stdin.closed:
307                     if not data_to_send:
308                         stdin_txt = stdin.read(4096)
309                         if stdin_txt is None:
310                             stdin_txt = ''
311                         data_to_send = encodeutils.safe_encode(
312                             stdin_txt, incoming='utf-8')
313                         if not data_to_send:
314                             # we may need to keep stdin open
315                             if not keep_stdin_open:
316                                 stdin.close()
317                                 session.shutdown_write()
318                                 writes = []
319                     if data_to_send:
320                         sent_bytes = session.send(data_to_send)
321                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
322                         data_to_send = data_to_send[sent_bytes:]
323
324             if session.exit_status_ready():
325                 break
326
327             if timeout and (time.time() - timeout) > start_time:
328                 message = ('Timeout executing command %(cmd)s on host %(host)s'
329                            % {"cmd": cmd, "host": self.host})
330                 raise exceptions.SSHTimeout(error_msg=message)
331             if e:
332                 raise exceptions.SSHError(error_msg='Socket error')
333
334         exit_status = session.recv_exit_status()
335         if exit_status != 0 and raise_on_error:
336             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
337             details = fmt % {"cmd": cmd, "status": exit_status}
338             if stderr_data:
339                 details += " Last stderr data: '%s'." % stderr_data
340             raise exceptions.SSHError(error_msg=details)
341         return exit_status
342
343     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
344         """Execute the specified command on the server.
345
346         :param cmd: (str)             Command to be executed.
347         :param stdin: (StringIO)      Open file to be sent on process stdin.
348         :param timeout: (int)         Timeout for execution of the command.
349         :param raise_on_error: (bool) If True, then an SSHError will be raised
350                                       when non-zero exit code.
351
352         :returns: tuple (exit_status, stdout, stderr)
353         """
354         stdout = six.moves.StringIO()
355         stderr = six.moves.StringIO()
356
357         exit_status = self.run(cmd, stderr=stderr,
358                                stdout=stdout, stdin=stdin,
359                                timeout=timeout, raise_on_error=raise_on_error)
360         stdout.seek(0)
361         stderr.seek(0)
362         return exit_status, stdout.read(), stderr.read()
363
364     def wait(self, timeout=None, interval=1):
365         """Wait for the host will be available via ssh."""
366         if timeout is None:
367             timeout = self.wait_timeout
368
369         end_time = time.time() + timeout
370         while True:
371             try:
372                 return self.execute("uname")
373             except (socket.error, exceptions.SSHError) as e:
374                 self.log.debug("Ssh is still unavailable: %r", e)
375                 time.sleep(interval)
376             if time.time() > end_time:
377                 raise exceptions.SSHTimeout(
378                     error_msg='Timeout waiting for "%s"' % self.host)
379
380     def put(self, files, remote_path=b'.', recursive=False):
381         client = self._get_client()
382
383         with SCPClient(client.get_transport()) as scp:
384             scp.put(files, remote_path, recursive)
385
386     def get(self, remote_path, local_path='/tmp/', recursive=True):
387         client = self._get_client()
388
389         with SCPClient(client.get_transport()) as scp:
390             scp.get(remote_path, local_path, recursive)
391
392     # keep shell running in the background, e.g. screen
393     def send_command(self, command):
394         client = self._get_client()
395         client.exec_command(command, get_pty=True)
396
397     def _put_file_sftp(self, localpath, remotepath, mode=None):
398         client = self._get_client()
399
400         with client.open_sftp() as sftp:
401             sftp.put(localpath, remotepath)
402             if mode is None:
403                 mode = 0o777 & os.stat(localpath).st_mode
404             sftp.chmod(remotepath, mode)
405
406     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
407
408     def _put_file_shell(self, localpath, remotepath, mode=None):
409         # quote to stop wordpslit
410         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
411         if not tilde:
412             tilde = ''
413         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
414         if mode is not None:
415             # use -- so no options
416             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
417
418         with open(localpath, "rb") as localfile:
419             # only chmod on successful cat
420             self.run("&& ".join(cmd), stdin=localfile)
421
422     def put_file(self, localpath, remotepath, mode=None):
423         """Copy specified local file to the server.
424
425         :param localpath:   Local filename.
426         :param remotepath:  Remote filename.
427         :param mode:        Permissions to set after upload
428         """
429         try:
430             self._put_file_sftp(localpath, remotepath, mode=mode)
431         except (paramiko.SSHException, socket.error):
432             self._put_file_shell(localpath, remotepath, mode=mode)
433
434     def provision_tool(self, tool_path, tool_file=None):
435         return provision_tool(self, tool_path, tool_file)
436
437     def put_file_obj(self, file_obj, remotepath, mode=None):
438         client = self._get_client()
439
440         with client.open_sftp() as sftp:
441             sftp.putfo(file_obj, remotepath)
442             if mode is not None:
443                 sftp.chmod(remotepath, mode)
444
445     def get_file_obj(self, remotepath, file_obj):
446         client = self._get_client()
447
448         with client.open_sftp() as sftp:
449             sftp.getfo(remotepath, file_obj)
450
451
452 class AutoConnectSSH(SSH):
453
454     @classmethod
455     def get_arg_key_map(cls):
456         arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
457         arg_key_map['wait'] = ('wait', True)
458         return arg_key_map
459
460     # always wait or we will get OpenStack SSH errors
461     def __init__(self, user, host, port=None, pkey=None,
462                  key_filename=None, password=None, name=None, wait=True):
463         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
464         if wait and wait is not True:
465             self.wait_timeout = int(wait)
466
467     def _make_dict(self):
468         data = super(AutoConnectSSH, self)._make_dict()
469         data.update({
470             'wait': self.wait_timeout
471         })
472         return data
473
474     def _connect(self):
475         if not self.is_connected:
476             interval = 1
477             timeout = self.wait_timeout
478
479             end_time = time.time() + timeout
480             while True:
481                 try:
482                     return self._get_client()
483                 except (socket.error, exceptions.SSHError) as e:
484                     self.log.debug("Ssh is still unavailable: %r", e)
485                     time.sleep(interval)
486                 if time.time() > end_time:
487                     raise exceptions.SSHTimeout(
488                         error_msg='Timeout waiting for "%s"' % self.host)
489
490     def drop_connection(self):
491         """ Don't close anything, just force creation of a new client """
492         self._client = False
493
494     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
495         self._connect()
496         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
497                                                    raise_on_error)
498
499     def run(self, cmd, stdin=None, stdout=None, stderr=None,
500             raise_on_error=True, timeout=3600,
501             keep_stdin_open=False, pty=False):
502         self._connect()
503         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
504                                                timeout, keep_stdin_open, pty)
505
506     def put(self, files, remote_path=b'.', recursive=False):
507         self._connect()
508         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
509
510     def put_file(self, local_path, remote_path, mode=None):
511         self._connect()
512         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
513
514     def put_file_obj(self, file_obj, remote_path, mode=None):
515         self._connect()
516         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
517
518     def get_file_obj(self, remote_path, file_obj):
519         self._connect()
520         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
521
522     def provision_tool(self, tool_path, tool_file=None):
523         self._connect()
524         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
525
526     @staticmethod
527     def get_class():
528         # must return static class name, anything else refers to the calling class
529         # i.e. the subclass, not the superclass
530         return AutoConnectSSH