bb715e4b49e0d04c82dedcc094eaaed13673cc80
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 from __future__ import absolute_import
66 import os
67 import select
68 import socket
69 import time
70 import re
71
72 import logging
73
74 import paramiko
75 from chainmap import ChainMap
76 from oslo_utils import encodeutils
77 from scp import SCPClient
78 import six
79
80 from yardstick.common.utils import try_int
81 from yardstick.network_services.utils import provision_tool
82
83
84 class SSHError(Exception):
85     pass
86
87
88 class SSHTimeout(SSHError):
89     pass
90
91
92 class SSH(object):
93     """Represent ssh connection."""
94
95     SSH_PORT = paramiko.config.SSH_PORT
96
97     @staticmethod
98     def gen_keys(key_filename, bit_count=2048):
99         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
100         rsa_key.write_private_key_file(key_filename)
101         print("Writing %s ..." % key_filename)
102         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
103             pubkey_file.write(rsa_key.get_name())
104             pubkey_file.write(' ')
105             pubkey_file.write(rsa_key.get_base64())
106             pubkey_file.write('\n')
107
108     @staticmethod
109     def get_class():
110         # must return static class name, anything else refers to the calling class
111         # i.e. the subclass, not the superclass
112         return SSH
113
114     def __init__(self, user, host, port=None, pkey=None,
115                  key_filename=None, password=None, name=None):
116         """Initialize SSH client.
117
118         :param user: ssh username
119         :param host: hostname or ip address of remote ssh server
120         :param port: remote ssh port
121         :param pkey: RSA or DSS private key string or file object
122         :param key_filename: private key filename
123         :param password: password
124         """
125         self.name = name
126         if name:
127             self.log = logging.getLogger(__name__ + '.' + self.name)
128         else:
129             self.log = logging.getLogger(__name__)
130
131         self.user = user
132         self.host = host
133         # everybody wants to debug this in the caller, do it here instead
134         self.log.debug("user:%s host:%s", user, host)
135
136         # we may get text port from YAML, convert to int
137         self.port = try_int(port, self.SSH_PORT)
138         self.pkey = self._get_pkey(pkey) if pkey else None
139         self.password = password
140         self.key_filename = key_filename
141         self._client = False
142         # paramiko loglevel debug will output ssh protocl debug
143         # we don't ever really want that unless we are debugging paramiko
144         # ssh issues
145         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
146             logging.getLogger("paramiko").setLevel(logging.DEBUG)
147         else:
148             logging.getLogger("paramiko").setLevel(logging.WARN)
149
150     @classmethod
151     def args_from_node(cls, node, overrides=None, defaults=None):
152         if overrides is None:
153             overrides = {}
154         if defaults is None:
155             defaults = {}
156         params = ChainMap(overrides, node, defaults)
157         return {
158             'user': params['user'],
159             'host': params['ip'],
160             'port': params.get('ssh_port', cls.SSH_PORT),
161             'pkey': params.get('pkey'),
162             'key_filename': params.get('key_filename'),
163             'password': params.get('password'),
164             'name': params.get('name'),
165         }
166
167     @classmethod
168     def from_node(cls, node, overrides=None, defaults=None):
169         return cls(**cls.args_from_node(node, overrides, defaults))
170
171     def _get_pkey(self, key):
172         if isinstance(key, six.string_types):
173             key = six.moves.StringIO(key)
174         errors = []
175         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
176             try:
177                 return key_class.from_private_key(key)
178             except paramiko.SSHException as e:
179                 errors.append(e)
180         raise SSHError("Invalid pkey: %s" % (errors))
181
182     @property
183     def is_connected(self):
184         return bool(self._client)
185
186     def _get_client(self):
187         if self.is_connected:
188             return self._client
189         try:
190             self._client = paramiko.SSHClient()
191             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
192             self._client.connect(self.host, username=self.user,
193                                  port=self.port, pkey=self.pkey,
194                                  key_filename=self.key_filename,
195                                  password=self.password,
196                                  allow_agent=False, look_for_keys=False,
197                                  timeout=1)
198             return self._client
199         except Exception as e:
200             message = ("Exception %(exception_type)s was raised "
201                        "during connect. Exception value is: %(exception)r")
202             self._client = False
203             raise SSHError(message % {"exception": e,
204                                       "exception_type": type(e)})
205
206     def _make_dict(self):
207         return {
208             'user': self.user,
209             'host': self.host,
210             'port': self.port,
211             'pkey': self.pkey,
212             'key_filename': self.key_filename,
213             'password': self.password,
214             'name': self.name,
215         }
216
217     def copy(self):
218         return self.get_class()(**self._make_dict())
219
220     def close(self):
221         if self._client:
222             self._client.close()
223             self._client = False
224
225     def run(self, cmd, stdin=None, stdout=None, stderr=None,
226             raise_on_error=True, timeout=3600,
227             keep_stdin_open=False, pty=False):
228         """Execute specified command on the server.
229
230         :param cmd:             Command to be executed.
231         :type cmd:              str
232         :param stdin:           Open file or string to pass to stdin.
233         :param stdout:          Open file to connect to stdout.
234         :param stderr:          Open file to connect to stderr.
235         :param raise_on_error:  If False then exit code will be return. If True
236                                 then exception will be raized if non-zero code.
237         :param timeout:         Timeout in seconds for command execution.
238                                 Default 1 hour. No timeout if set to 0.
239         :param keep_stdin_open: don't close stdin on empty reads
240         :type keep_stdin_open:  bool
241         :param pty:             Request a pseudo terminal for this connection.
242                                 This allows passing control characters.
243                                 Default False.
244         :type pty:              bool
245         """
246
247         client = self._get_client()
248
249         if isinstance(stdin, six.string_types):
250             stdin = six.moves.StringIO(stdin)
251
252         return self._run(client, cmd, stdin=stdin, stdout=stdout,
253                          stderr=stderr, raise_on_error=raise_on_error,
254                          timeout=timeout,
255                          keep_stdin_open=keep_stdin_open, pty=pty)
256
257     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
258              raise_on_error=True, timeout=3600,
259              keep_stdin_open=False, pty=False):
260
261         transport = client.get_transport()
262         session = transport.open_session()
263         if pty:
264             session.get_pty()
265         session.exec_command(cmd)
266         start_time = time.time()
267
268         # encode on transmit, decode on receive
269         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
270         stderr_data = None
271
272         # If we have data to be sent to stdin then `select' should also
273         # check for stdin availability.
274         if stdin and not stdin.closed:
275             writes = [session]
276         else:
277             writes = []
278
279         while True:
280             # Block until data can be read/write.
281             r, w, e = select.select([session], writes, [session], 1)
282
283             if session.recv_ready():
284                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
285                 self.log.debug("stdout: %r", data)
286                 if stdout is not None:
287                     stdout.write(data)
288                 continue
289
290             if session.recv_stderr_ready():
291                 stderr_data = encodeutils.safe_decode(
292                     session.recv_stderr(4096), 'utf-8')
293                 self.log.debug("stderr: %r", stderr_data)
294                 if stderr is not None:
295                     stderr.write(stderr_data)
296                 continue
297
298             if session.send_ready():
299                 if stdin is not None and not stdin.closed:
300                     if not data_to_send:
301                         stdin_txt = stdin.read(4096)
302                         if stdin_txt is None:
303                             stdin_txt = ''
304                         data_to_send = encodeutils.safe_encode(
305                             stdin_txt, incoming='utf-8')
306                         if not data_to_send:
307                             # we may need to keep stdin open
308                             if not keep_stdin_open:
309                                 stdin.close()
310                                 session.shutdown_write()
311                                 writes = []
312                     if data_to_send:
313                         sent_bytes = session.send(data_to_send)
314                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
315                         data_to_send = data_to_send[sent_bytes:]
316
317             if session.exit_status_ready():
318                 break
319
320             if timeout and (time.time() - timeout) > start_time:
321                 args = {"cmd": cmd, "host": self.host}
322                 raise SSHTimeout("Timeout executing command "
323                                  "'%(cmd)s' on host %(host)s" % args)
324             if e:
325                 raise SSHError("Socket error.")
326
327         exit_status = session.recv_exit_status()
328         if exit_status != 0 and raise_on_error:
329             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
330             details = fmt % {"cmd": cmd, "status": exit_status}
331             if stderr_data:
332                 details += " Last stderr data: '%s'." % stderr_data
333             raise SSHError(details)
334         return exit_status
335
336     def execute(self, cmd, stdin=None, timeout=3600):
337         """Execute the specified command on the server.
338
339         :param cmd:     Command to be executed.
340         :param stdin:   Open file to be sent on process stdin.
341         :param timeout: Timeout for execution of the command.
342
343         :returns: tuple (exit_status, stdout, stderr)
344         """
345         stdout = six.moves.StringIO()
346         stderr = six.moves.StringIO()
347
348         exit_status = self.run(cmd, stderr=stderr,
349                                stdout=stdout, stdin=stdin,
350                                timeout=timeout, raise_on_error=False)
351         stdout.seek(0)
352         stderr.seek(0)
353         return exit_status, stdout.read(), stderr.read()
354
355     def wait(self, timeout=120, interval=1):
356         """Wait for the host will be available via ssh."""
357         start_time = time.time()
358         while True:
359             try:
360                 return self.execute("uname")
361             except (socket.error, SSHError) as e:
362                 self.log.debug("Ssh is still unavailable: %r", e)
363                 time.sleep(interval)
364             if time.time() > (start_time + timeout):
365                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
366
367     def put(self, files, remote_path=b'.', recursive=False):
368         client = self._get_client()
369
370         with SCPClient(client.get_transport()) as scp:
371             scp.put(files, remote_path, recursive)
372
373     # keep shell running in the background, e.g. screen
374     def send_command(self, command):
375         client = self._get_client()
376         client.exec_command(command, get_pty=True)
377
378     def _put_file_sftp(self, localpath, remotepath, mode=None):
379         client = self._get_client()
380
381         with client.open_sftp() as sftp:
382             sftp.put(localpath, remotepath)
383             if mode is None:
384                 mode = 0o777 & os.stat(localpath).st_mode
385             sftp.chmod(remotepath, mode)
386
387     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
388
389     def _put_file_shell(self, localpath, remotepath, mode=None):
390         # quote to stop wordpslit
391         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
392         if not tilde:
393             tilde = ''
394         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
395         if mode is not None:
396             # use -- so no options
397             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
398
399         with open(localpath, "rb") as localfile:
400             # only chmod on successful cat
401             self.run("&& ".join(cmd), stdin=localfile)
402
403     def put_file(self, localpath, remotepath, mode=None):
404         """Copy specified local file to the server.
405
406         :param localpath:   Local filename.
407         :param remotepath:  Remote filename.
408         :param mode:        Permissions to set after upload
409         """
410         try:
411             self._put_file_sftp(localpath, remotepath, mode=mode)
412         except (paramiko.SSHException, socket.error):
413             self._put_file_shell(localpath, remotepath, mode=mode)
414
415     def provision_tool(self, tool_path, tool_file=None):
416         return provision_tool(self, tool_path, tool_file)
417
418     def put_file_obj(self, file_obj, remotepath, mode=None):
419         client = self._get_client()
420
421         with client.open_sftp() as sftp:
422             sftp.putfo(file_obj, remotepath)
423             if mode is not None:
424                 sftp.chmod(remotepath, mode)
425
426     def get_file_obj(self, remotepath, file_obj):
427         client = self._get_client()
428
429         with client.open_sftp() as sftp:
430             sftp.getfo(remotepath, file_obj)
431
432
433 class AutoConnectSSH(SSH):
434
435     # always wait or we will get OpenStack SSH errors
436     def __init__(self, user, host, port=None, pkey=None,
437                  key_filename=None, password=None, name=None, wait=True):
438         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
439         self._wait = wait
440
441     def _make_dict(self):
442         data = super(AutoConnectSSH, self)._make_dict()
443         data.update({
444             'wait': self._wait
445         })
446         return data
447
448     def _connect(self):
449         if not self.is_connected:
450             self._get_client()
451             if self._wait:
452                 self.wait()
453
454     def drop_connection(self):
455         """ Don't close anything, just force creation of a new client """
456         self._client = False
457
458     def execute(self, cmd, stdin=None, timeout=3600):
459         self._connect()
460         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
461
462     def run(self, cmd, stdin=None, stdout=None, stderr=None,
463             raise_on_error=True, timeout=3600,
464             keep_stdin_open=False, pty=False):
465         self._connect()
466         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
467                                                timeout, keep_stdin_open, pty)
468
469     def put(self, files, remote_path=b'.', recursive=False):
470         self._connect()
471         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
472
473     def put_file(self, local_path, remote_path, mode=None):
474         self._connect()
475         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
476
477     def put_file_obj(self, file_obj, remote_path, mode=None):
478         self._connect()
479         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
480
481     def get_file_obj(self, remote_path, file_obj):
482         self._connect()
483         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
484
485     def provision_tool(self, tool_path, tool_file=None):
486         self._connect()
487         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
488
489     @staticmethod
490     def get_class():
491         # must return static class name, anything else refers to the calling class
492         # i.e. the subclass, not the superclass
493         return AutoConnectSSH