Merge "Add rpm,image directories and SLA options to Livemigration"
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 from __future__ import absolute_import
66 import os
67 import select
68 import socket
69 import time
70 import re
71
72 import logging
73
74 import paramiko
75 from chainmap import ChainMap
76 from oslo_utils import encodeutils
77 from scp import SCPClient
78 import six
79
80 from yardstick.common.utils import try_int
81 from yardstick.network_services.utils import provision_tool
82
83
84 class SSHError(Exception):
85     pass
86
87
88 class SSHTimeout(SSHError):
89     pass
90
91
92 class SSH(object):
93     """Represent ssh connection."""
94
95     SSH_PORT = paramiko.config.SSH_PORT
96
97     @staticmethod
98     def gen_keys(key_filename, bit_count=2048):
99         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
100         rsa_key.write_private_key_file(key_filename)
101         print("Writing %s ..." % key_filename)
102         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
103             pubkey_file.write(rsa_key.get_name())
104             pubkey_file.write(' ')
105             pubkey_file.write(rsa_key.get_base64())
106             pubkey_file.write('\n')
107
108     @staticmethod
109     def get_class():
110         # must return static class name, anything else refers to the calling class
111         # i.e. the subclass, not the superclass
112         return SSH
113
114     def __init__(self, user, host, port=None, pkey=None,
115                  key_filename=None, password=None, name=None):
116         """Initialize SSH client.
117
118         :param user: ssh username
119         :param host: hostname or ip address of remote ssh server
120         :param port: remote ssh port
121         :param pkey: RSA or DSS private key string or file object
122         :param key_filename: private key filename
123         :param password: password
124         """
125         self.name = name
126         if name:
127             self.log = logging.getLogger(__name__ + '.' + self.name)
128         else:
129             self.log = logging.getLogger(__name__)
130
131         self.user = user
132         self.host = host
133         # everybody wants to debug this in the caller, do it here instead
134         self.log.debug("user:%s host:%s", user, host)
135
136         # we may get text port from YAML, convert to int
137         self.port = try_int(port, self.SSH_PORT)
138         self.pkey = self._get_pkey(pkey) if pkey else None
139         self.password = password
140         self.key_filename = key_filename
141         self._client = False
142         # paramiko loglevel debug will output ssh protocl debug
143         # we don't ever really want that unless we are debugging paramiko
144         # ssh issues
145         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
146             logging.getLogger("paramiko").setLevel(logging.DEBUG)
147         else:
148             logging.getLogger("paramiko").setLevel(logging.WARN)
149
150     @classmethod
151     def args_from_node(cls, node, overrides=None, defaults=None):
152         if overrides is None:
153             overrides = {}
154         if defaults is None:
155             defaults = {}
156         params = ChainMap(overrides, node, defaults)
157         return {
158             'user': params['user'],
159             'host': params['ip'],
160             'port': params.get('ssh_port', cls.SSH_PORT),
161             'pkey': params.get('pkey'),
162             'key_filename': params.get('key_filename'),
163             'password': params.get('password'),
164             'name': params.get('name'),
165         }
166
167     @classmethod
168     def from_node(cls, node, overrides=None, defaults=None):
169         return cls(**cls.args_from_node(node, overrides, defaults))
170
171     def _get_pkey(self, key):
172         if isinstance(key, six.string_types):
173             key = six.moves.StringIO(key)
174         errors = []
175         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
176             try:
177                 return key_class.from_private_key(key)
178             except paramiko.SSHException as e:
179                 errors.append(e)
180         raise SSHError("Invalid pkey: %s" % (errors))
181
182     @property
183     def is_connected(self):
184         return bool(self._client)
185
186     def _get_client(self):
187         if self.is_connected:
188             return self._client
189         try:
190             self._client = paramiko.SSHClient()
191             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
192             self._client.connect(self.host, username=self.user,
193                                  port=self.port, pkey=self.pkey,
194                                  key_filename=self.key_filename,
195                                  password=self.password,
196                                  allow_agent=False, look_for_keys=False,
197                                  timeout=1)
198             return self._client
199         except Exception as e:
200             message = ("Exception %(exception_type)s was raised "
201                        "during connect. Exception value is: %(exception)r")
202             self._client = False
203             raise SSHError(message % {"exception": e,
204                                       "exception_type": type(e)})
205
206     def _make_dict(self):
207         return {
208             'user': self.user,
209             'host': self.host,
210             'port': self.port,
211             'pkey': self.pkey,
212             'key_filename': self.key_filename,
213             'password': self.password,
214             'name': self.name,
215         }
216
217     def copy(self):
218         return self.get_class()(**self._make_dict())
219
220     def close(self):
221         if self._client:
222             self._client.close()
223             self._client = False
224
225     def run(self, cmd, stdin=None, stdout=None, stderr=None,
226             raise_on_error=True, timeout=3600,
227             keep_stdin_open=False, pty=False):
228         """Execute specified command on the server.
229
230         :param cmd:             Command to be executed.
231         :type cmd:              str
232         :param stdin:           Open file or string to pass to stdin.
233         :param stdout:          Open file to connect to stdout.
234         :param stderr:          Open file to connect to stderr.
235         :param raise_on_error:  If False then exit code will be return. If True
236                                 then exception will be raized if non-zero code.
237         :param timeout:         Timeout in seconds for command execution.
238                                 Default 1 hour. No timeout if set to 0.
239         :param keep_stdin_open: don't close stdin on empty reads
240         :type keep_stdin_open:  bool
241         :param pty:             Request a pseudo terminal for this connection.
242                                 This allows passing control characters.
243                                 Default False.
244         :type pty:              bool
245         """
246
247         client = self._get_client()
248
249         if isinstance(stdin, six.string_types):
250             stdin = six.moves.StringIO(stdin)
251
252         return self._run(client, cmd, stdin=stdin, stdout=stdout,
253                          stderr=stderr, raise_on_error=raise_on_error,
254                          timeout=timeout,
255                          keep_stdin_open=keep_stdin_open, pty=pty)
256
257     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
258              raise_on_error=True, timeout=3600,
259              keep_stdin_open=False, pty=False):
260
261         transport = client.get_transport()
262         session = transport.open_session()
263         if pty:
264             session.get_pty()
265         session.exec_command(cmd)
266         start_time = time.time()
267
268         # encode on transmit, decode on receive
269         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
270         stderr_data = None
271
272         # If we have data to be sent to stdin then `select' should also
273         # check for stdin availability.
274         if stdin and not stdin.closed:
275             writes = [session]
276         else:
277             writes = []
278
279         while True:
280             # Block until data can be read/write.
281             r, w, e = select.select([session], writes, [session], 1)
282
283             if session.recv_ready():
284                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
285                 self.log.debug("stdout: %r", data)
286                 if stdout is not None:
287                     stdout.write(data)
288                 continue
289
290             if session.recv_stderr_ready():
291                 stderr_data = encodeutils.safe_decode(
292                     session.recv_stderr(4096), 'utf-8')
293                 self.log.debug("stderr: %r", stderr_data)
294                 if stderr is not None:
295                     stderr.write(stderr_data)
296                 continue
297
298             if session.send_ready():
299                 if stdin is not None and not stdin.closed:
300                     if not data_to_send:
301                         stdin_txt = stdin.read(4096)
302                         if stdin_txt is None:
303                             stdin_txt = ''
304                         data_to_send = encodeutils.safe_encode(
305                             stdin_txt, incoming='utf-8')
306                         if not data_to_send:
307                             # we may need to keep stdin open
308                             if not keep_stdin_open:
309                                 stdin.close()
310                                 session.shutdown_write()
311                                 writes = []
312                     if data_to_send:
313                         sent_bytes = session.send(data_to_send)
314                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
315                         data_to_send = data_to_send[sent_bytes:]
316
317             if session.exit_status_ready():
318                 break
319
320             if timeout and (time.time() - timeout) > start_time:
321                 args = {"cmd": cmd, "host": self.host}
322                 raise SSHTimeout("Timeout executing command "
323                                  "'%(cmd)s' on host %(host)s" % args)
324             if e:
325                 raise SSHError("Socket error.")
326
327         exit_status = session.recv_exit_status()
328         if exit_status != 0 and raise_on_error:
329             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
330             details = fmt % {"cmd": cmd, "status": exit_status}
331             if stderr_data:
332                 details += " Last stderr data: '%s'." % stderr_data
333             raise SSHError(details)
334         return exit_status
335
336     def execute(self, cmd, stdin=None, timeout=3600):
337         """Execute the specified command on the server.
338
339         :param cmd:     Command to be executed.
340         :param stdin:   Open file to be sent on process stdin.
341         :param timeout: Timeout for execution of the command.
342
343         :returns: tuple (exit_status, stdout, stderr)
344         """
345         stdout = six.moves.StringIO()
346         stderr = six.moves.StringIO()
347
348         exit_status = self.run(cmd, stderr=stderr,
349                                stdout=stdout, stdin=stdin,
350                                timeout=timeout, raise_on_error=False)
351         stdout.seek(0)
352         stderr.seek(0)
353         return exit_status, stdout.read(), stderr.read()
354
355     def wait(self, timeout=120, interval=1):
356         """Wait for the host will be available via ssh."""
357         start_time = time.time()
358         while True:
359             try:
360                 return self.execute("uname")
361             except (socket.error, SSHError) as e:
362                 self.log.debug("Ssh is still unavailable: %r", e)
363                 time.sleep(interval)
364             if time.time() > (start_time + timeout):
365                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
366
367     def put(self, files, remote_path=b'.', recursive=False):
368         client = self._get_client()
369
370         with SCPClient(client.get_transport()) as scp:
371             scp.put(files, remote_path, recursive)
372
373     # keep shell running in the background, e.g. screen
374     def send_command(self, command):
375         client = self._get_client()
376         client.exec_command(command, get_pty=True)
377
378     def _put_file_sftp(self, localpath, remotepath, mode=None):
379         client = self._get_client()
380
381         with client.open_sftp() as sftp:
382             sftp.put(localpath, remotepath)
383             if mode is None:
384                 mode = 0o777 & os.stat(localpath).st_mode
385             sftp.chmod(remotepath, mode)
386
387     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
388
389     def _put_file_shell(self, localpath, remotepath, mode=None):
390         # quote to stop wordpslit
391         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
392         if not tilde:
393             tilde = ''
394         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
395         if mode is not None:
396             # use -- so no options
397             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
398
399         with open(localpath, "rb") as localfile:
400             # only chmod on successful cat
401             self.run("&& ".join(cmd), stdin=localfile)
402
403     def put_file(self, localpath, remotepath, mode=None):
404         """Copy specified local file to the server.
405
406         :param localpath:   Local filename.
407         :param remotepath:  Remote filename.
408         :param mode:        Permissions to set after upload
409         """
410         try:
411             self._put_file_sftp(localpath, remotepath, mode=mode)
412         except (paramiko.SSHException, socket.error):
413             self._put_file_shell(localpath, remotepath, mode=mode)
414
415     def provision_tool(self, tool_path, tool_file=None):
416         return provision_tool(self, tool_path, tool_file)
417
418     def put_file_obj(self, file_obj, remotepath, mode=None):
419         client = self._get_client()
420
421         with client.open_sftp() as sftp:
422             sftp.putfo(file_obj, remotepath)
423             if mode is not None:
424                 sftp.chmod(remotepath, mode)
425
426     def get_file_obj(self, remotepath, file_obj):
427         client = self._get_client()
428
429         with client.open_sftp() as sftp:
430             sftp.getfo(remotepath, file_obj)
431
432
433 class AutoConnectSSH(SSH):
434
435     def __init__(self, user, host, port=None, pkey=None,
436                  key_filename=None, password=None, name=None, wait=False):
437         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
438         self._wait = wait
439
440     def _make_dict(self):
441         data = super(AutoConnectSSH, self)._make_dict()
442         data.update({
443             'wait': self._wait
444         })
445         return data
446
447     def _connect(self):
448         if not self.is_connected:
449             self._get_client()
450             if self._wait:
451                 self.wait()
452
453     def drop_connection(self):
454         """ Don't close anything, just force creation of a new client """
455         self._client = False
456
457     def execute(self, cmd, stdin=None, timeout=3600):
458         self._connect()
459         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
460
461     def run(self, cmd, stdin=None, stdout=None, stderr=None,
462             raise_on_error=True, timeout=3600,
463             keep_stdin_open=False, pty=False):
464         self._connect()
465         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
466                                                timeout, keep_stdin_open, pty)
467
468     def put(self, files, remote_path=b'.', recursive=False):
469         self._connect()
470         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
471
472     def put_file(self, local_path, remote_path, mode=None):
473         self._connect()
474         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
475
476     def put_file_obj(self, file_obj, remote_path, mode=None):
477         self._connect()
478         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
479
480     def get_file_obj(self, remote_path, file_obj):
481         self._connect()
482         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
483
484     def provision_tool(self, tool_path, tool_file=None):
485         self._connect()
486         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
487
488     @staticmethod
489     def get_class():
490         # must return static class name, anything else refers to the calling class
491         # i.e. the subclass, not the superclass
492         return AutoConnectSSH