Merge "Add Testcase for Prox Standalone OvS-DPDK." into stable/gambia
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import io
66 import logging
67 import os
68 import re
69 import select
70 import socket
71 import time
72
73 import paramiko
74 from chainmap import ChainMap
75 from oslo_utils import encodeutils
76 from scp import SCPClient
77 import six
78
79 from yardstick.common import exceptions
80 from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
81 from yardstick.network_services.utils import provision_tool
82
83
84 def convert_key_to_str(key):
85     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
86         return key
87     k = io.StringIO()
88     key.write_private_key(k)
89     return k.getvalue()
90
91
92 # class SSHError(Exception):
93 #     pass
94 #
95 #
96 # class SSHTimeout(SSHError):
97 #     pass
98
99
100 class SSH(object):
101     """Represent ssh connection."""
102
103     SSH_PORT = paramiko.config.SSH_PORT
104     DEFAULT_WAIT_TIMEOUT = 120
105
106     @staticmethod
107     def gen_keys(key_filename, bit_count=2048):
108         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
109         rsa_key.write_private_key_file(key_filename)
110         print("Writing %s ..." % key_filename)
111         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
112             pubkey_file.write(rsa_key.get_name())
113             pubkey_file.write(' ')
114             pubkey_file.write(rsa_key.get_base64())
115             pubkey_file.write('\n')
116
117     @staticmethod
118     def get_class():
119         # must return static class name, anything else refers to the calling class
120         # i.e. the subclass, not the superclass
121         return SSH
122
123     @classmethod
124     def get_arg_key_map(cls):
125         return {
126             'user': ('user', NON_NONE_DEFAULT),
127             'host': ('ip', NON_NONE_DEFAULT),
128             'port': ('ssh_port', cls.SSH_PORT),
129             'pkey': ('pkey', None),
130             'key_filename': ('key_filename', None),
131             'password': ('password', None),
132             'name': ('name', None),
133         }
134
135     def __init__(self, user, host, port=None, pkey=None,
136                  key_filename=None, password=None, name=None):
137         """Initialize SSH client.
138
139         :param user: ssh username
140         :param host: hostname or ip address of remote ssh server
141         :param port: remote ssh port
142         :param pkey: RSA or DSS private key string or file object
143         :param key_filename: private key filename
144         :param password: password
145         """
146         self.name = name
147         if name:
148             self.log = logging.getLogger(__name__ + '.' + self.name)
149         else:
150             self.log = logging.getLogger(__name__)
151
152         self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
153         self.user = user
154         self.host = host
155         # everybody wants to debug this in the caller, do it here instead
156         self.log.debug("user:%s host:%s", user, host)
157
158         # we may get text port from YAML, convert to int
159         self.port = try_int(port, self.SSH_PORT)
160         self.pkey = self._get_pkey(pkey) if pkey else None
161         self.password = password
162         self.key_filename = key_filename
163         self._client = False
164         # paramiko loglevel debug will output ssh protocl debug
165         # we don't ever really want that unless we are debugging paramiko
166         # ssh issues
167         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
168             logging.getLogger("paramiko").setLevel(logging.DEBUG)
169         else:
170             logging.getLogger("paramiko").setLevel(logging.WARN)
171
172     @classmethod
173     def args_from_node(cls, node, overrides=None, defaults=None):
174         if overrides is None:
175             overrides = {}
176         if defaults is None:
177             defaults = {}
178
179         params = ChainMap(overrides, node, defaults)
180         return make_dict_from_map(params, cls.get_arg_key_map())
181
182     @classmethod
183     def from_node(cls, node, overrides=None, defaults=None):
184         return cls(**cls.args_from_node(node, overrides, defaults))
185
186     def _get_pkey(self, key):
187         if isinstance(key, six.string_types):
188             key = six.moves.StringIO(key)
189         errors = []
190         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
191             try:
192                 return key_class.from_private_key(key)
193             except paramiko.SSHException as e:
194                 errors.append(e)
195         raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
196
197     @property
198     def is_connected(self):
199         return bool(self._client)
200
201     def _get_client(self):
202         if self.is_connected:
203             return self._client
204         try:
205             self._client = paramiko.SSHClient()
206             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
207             self._client.connect(self.host, username=self.user,
208                                  port=self.port, pkey=self.pkey,
209                                  key_filename=self.key_filename,
210                                  password=self.password,
211                                  allow_agent=False, look_for_keys=False,
212                                  timeout=1)
213             return self._client
214         except Exception as e:
215             message = ("Exception %(exception_type)s was raised "
216                        "during connect. Exception value is: %(exception)r" %
217                        {"exception": e, "exception_type": type(e)})
218             self._client = False
219             raise exceptions.SSHError(error_msg=message)
220
221     def _make_dict(self):
222         return {
223             'user': self.user,
224             'host': self.host,
225             'port': self.port,
226             'pkey': self.pkey,
227             'key_filename': self.key_filename,
228             'password': self.password,
229             'name': self.name,
230         }
231
232     def copy(self):
233         return self.get_class()(**self._make_dict())
234
235     def close(self):
236         if self._client:
237             self._client.close()
238             self._client = False
239
240     def run(self, cmd, stdin=None, stdout=None, stderr=None,
241             raise_on_error=True, timeout=3600,
242             keep_stdin_open=False, pty=False):
243         """Execute specified command on the server.
244
245         :param cmd:             Command to be executed.
246         :type cmd:              str
247         :param stdin:           Open file or string to pass to stdin.
248         :param stdout:          Open file to connect to stdout.
249         :param stderr:          Open file to connect to stderr.
250         :param raise_on_error:  If False then exit code will be return. If True
251                                 then exception will be raized if non-zero code.
252         :param timeout:         Timeout in seconds for command execution.
253                                 Default 1 hour. No timeout if set to 0.
254         :param keep_stdin_open: don't close stdin on empty reads
255         :type keep_stdin_open:  bool
256         :param pty:             Request a pseudo terminal for this connection.
257                                 This allows passing control characters.
258                                 Default False.
259         :type pty:              bool
260         """
261
262         client = self._get_client()
263
264         if isinstance(stdin, six.string_types):
265             stdin = six.moves.StringIO(stdin)
266
267         return self._run(client, cmd, stdin=stdin, stdout=stdout,
268                          stderr=stderr, raise_on_error=raise_on_error,
269                          timeout=timeout,
270                          keep_stdin_open=keep_stdin_open, pty=pty)
271
272     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
273              raise_on_error=True, timeout=3600,
274              keep_stdin_open=False, pty=False):
275
276         transport = client.get_transport()
277         session = transport.open_session()
278         if pty:
279             session.get_pty()
280         session.exec_command(cmd)
281         start_time = time.time()
282
283         # encode on transmit, decode on receive
284         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
285         stderr_data = None
286
287         # If we have data to be sent to stdin then `select' should also
288         # check for stdin availability.
289         if stdin and not stdin.closed:
290             writes = [session]
291         else:
292             writes = []
293
294         while True:
295             # Block until data can be read/write.
296             e = select.select([session], writes, [session], 1)[2]
297
298             if session.recv_ready():
299                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
300                 self.log.debug("stdout: %r", data)
301                 if stdout is not None:
302                     stdout.write(data)
303                 continue
304
305             if session.recv_stderr_ready():
306                 stderr_data = encodeutils.safe_decode(
307                     session.recv_stderr(4096), 'utf-8')
308                 self.log.debug("stderr: %r", stderr_data)
309                 if stderr is not None:
310                     stderr.write(stderr_data)
311                 continue
312
313             if session.send_ready():
314                 if stdin is not None and not stdin.closed:
315                     if not data_to_send:
316                         stdin_txt = stdin.read(4096)
317                         if stdin_txt is None:
318                             stdin_txt = ''
319                         data_to_send = encodeutils.safe_encode(
320                             stdin_txt, incoming='utf-8')
321                         if not data_to_send:
322                             # we may need to keep stdin open
323                             if not keep_stdin_open:
324                                 stdin.close()
325                                 session.shutdown_write()
326                                 writes = []
327                     if data_to_send:
328                         sent_bytes = session.send(data_to_send)
329                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
330                         data_to_send = data_to_send[sent_bytes:]
331
332             if session.exit_status_ready():
333                 break
334
335             if timeout and (time.time() - timeout) > start_time:
336                 message = ('Timeout executing command %(cmd)s on host %(host)s'
337                            % {"cmd": cmd, "host": self.host})
338                 raise exceptions.SSHTimeout(error_msg=message)
339             if e:
340                 raise exceptions.SSHError(error_msg='Socket error')
341
342         exit_status = session.recv_exit_status()
343         if exit_status != 0 and raise_on_error:
344             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
345             details = fmt % {"cmd": cmd, "status": exit_status}
346             if stderr_data:
347                 details += " Last stderr data: '%s'." % stderr_data
348             raise exceptions.SSHError(error_msg=details)
349         return exit_status
350
351     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
352         """Execute the specified command on the server.
353
354         :param cmd: (str)             Command to be executed.
355         :param stdin: (StringIO)      Open file to be sent on process stdin.
356         :param timeout: (int)         Timeout for execution of the command.
357         :param raise_on_error: (bool) If True, then an SSHError will be raised
358                                       when non-zero exit code.
359
360         :returns: tuple (exit_status, stdout, stderr)
361         """
362         stdout = six.moves.StringIO()
363         stderr = six.moves.StringIO()
364
365         exit_status = self.run(cmd, stderr=stderr,
366                                stdout=stdout, stdin=stdin,
367                                timeout=timeout, raise_on_error=raise_on_error)
368         stdout.seek(0)
369         stderr.seek(0)
370         return exit_status, stdout.read(), stderr.read()
371
372     def wait(self, timeout=None, interval=1):
373         """Wait for the host will be available via ssh."""
374         if timeout is None:
375             timeout = self.wait_timeout
376
377         end_time = time.time() + timeout
378         while True:
379             try:
380                 return self.execute("uname")
381             except (socket.error, exceptions.SSHError) as e:
382                 self.log.debug("Ssh is still unavailable: %r", e)
383                 time.sleep(interval)
384             if time.time() > end_time:
385                 raise exceptions.SSHTimeout(
386                     error_msg='Timeout waiting for "%s"' % self.host)
387
388     def put(self, files, remote_path=b'.', recursive=False):
389         client = self._get_client()
390
391         with SCPClient(client.get_transport()) as scp:
392             scp.put(files, remote_path, recursive)
393
394     def get(self, remote_path, local_path='/tmp/', recursive=True):
395         client = self._get_client()
396
397         with SCPClient(client.get_transport()) as scp:
398             scp.get(remote_path, local_path, recursive)
399
400     # keep shell running in the background, e.g. screen
401     def send_command(self, command):
402         client = self._get_client()
403         client.exec_command(command, get_pty=True)
404
405     def _put_file_sftp(self, localpath, remotepath, mode=None):
406         client = self._get_client()
407
408         with client.open_sftp() as sftp:
409             sftp.put(localpath, remotepath)
410             if mode is None:
411                 mode = 0o777 & os.stat(localpath).st_mode
412             sftp.chmod(remotepath, mode)
413
414     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
415
416     def _put_file_shell(self, localpath, remotepath, mode=None):
417         # quote to stop wordpslit
418         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
419         if not tilde:
420             tilde = ''
421         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
422         if mode is not None:
423             # use -- so no options
424             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
425
426         with open(localpath, "rb") as localfile:
427             # only chmod on successful cat
428             self.run("&& ".join(cmd), stdin=localfile)
429
430     def put_file(self, localpath, remotepath, mode=None):
431         """Copy specified local file to the server.
432
433         :param localpath:   Local filename.
434         :param remotepath:  Remote filename.
435         :param mode:        Permissions to set after upload
436         """
437         try:
438             self._put_file_sftp(localpath, remotepath, mode=mode)
439         except (paramiko.SSHException, socket.error):
440             self._put_file_shell(localpath, remotepath, mode=mode)
441
442     def provision_tool(self, tool_path, tool_file=None):
443         return provision_tool(self, tool_path, tool_file)
444
445     def put_file_obj(self, file_obj, remotepath, mode=None):
446         client = self._get_client()
447
448         with client.open_sftp() as sftp:
449             sftp.putfo(file_obj, remotepath)
450             if mode is not None:
451                 sftp.chmod(remotepath, mode)
452
453     def get_file_obj(self, remotepath, file_obj):
454         client = self._get_client()
455
456         with client.open_sftp() as sftp:
457             sftp.getfo(remotepath, file_obj)
458
459
460 class AutoConnectSSH(SSH):
461
462     @classmethod
463     def get_arg_key_map(cls):
464         arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
465         arg_key_map['wait'] = ('wait', True)
466         return arg_key_map
467
468     # always wait or we will get OpenStack SSH errors
469     def __init__(self, user, host, port=None, pkey=None,
470                  key_filename=None, password=None, name=None, wait=True):
471         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
472         if wait and wait is not True:
473             self.wait_timeout = int(wait)
474
475     def _make_dict(self):
476         data = super(AutoConnectSSH, self)._make_dict()
477         data.update({
478             'wait': self.wait_timeout
479         })
480         return data
481
482     def _connect(self):
483         if not self.is_connected:
484             interval = 1
485             timeout = self.wait_timeout
486
487             end_time = time.time() + timeout
488             while True:
489                 try:
490                     return self._get_client()
491                 except (socket.error, exceptions.SSHError) as e:
492                     self.log.debug("Ssh is still unavailable: %r", e)
493                     time.sleep(interval)
494                 if time.time() > end_time:
495                     raise exceptions.SSHTimeout(
496                         error_msg='Timeout waiting for "%s"' % self.host)
497
498     def drop_connection(self):
499         """ Don't close anything, just force creation of a new client """
500         self._client = False
501
502     def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
503         self._connect()
504         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
505                                                    raise_on_error)
506
507     def run(self, cmd, stdin=None, stdout=None, stderr=None,
508             raise_on_error=True, timeout=3600,
509             keep_stdin_open=False, pty=False):
510         self._connect()
511         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
512                                                timeout, keep_stdin_open, pty)
513
514     def put(self, files, remote_path=b'.', recursive=False):
515         self._connect()
516         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
517
518     def put_file(self, local_path, remote_path, mode=None):
519         self._connect()
520         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
521
522     def put_file_obj(self, file_obj, remote_path, mode=None):
523         self._connect()
524         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
525
526     def get_file_obj(self, remote_path, file_obj):
527         self._connect()
528         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
529
530     def provision_tool(self, tool_path, tool_file=None):
531         self._connect()
532         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
533
534     @staticmethod
535     def get_class():
536         # must return static class name, anything else refers to the calling class
537         # i.e. the subclass, not the superclass
538         return AutoConnectSSH