Prohibit the importation of a list of libraries
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 from __future__ import absolute_import
66 import os
67 import io
68 import select
69 import socket
70 import time
71 import re
72
73 import logging
74
75 import paramiko
76 from chainmap import ChainMap
77 from oslo_utils import encodeutils
78 from scp import SCPClient
79 import six
80
81 from yardstick.common.utils import try_int
82 from yardstick.network_services.utils import provision_tool
83
84
85 def convert_key_to_str(key):
86     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
87         return key
88     k = io.StringIO()
89     key.write_private_key(k)
90     return k.getvalue()
91
92
93 class SSHError(Exception):
94     pass
95
96
97 class SSHTimeout(SSHError):
98     pass
99
100
101 class SSH(object):
102     """Represent ssh connection."""
103
104     SSH_PORT = paramiko.config.SSH_PORT
105
106     @staticmethod
107     def gen_keys(key_filename, bit_count=2048):
108         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
109         rsa_key.write_private_key_file(key_filename)
110         print("Writing %s ..." % key_filename)
111         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
112             pubkey_file.write(rsa_key.get_name())
113             pubkey_file.write(' ')
114             pubkey_file.write(rsa_key.get_base64())
115             pubkey_file.write('\n')
116
117     @staticmethod
118     def get_class():
119         # must return static class name, anything else refers to the calling class
120         # i.e. the subclass, not the superclass
121         return SSH
122
123     def __init__(self, user, host, port=None, pkey=None,
124                  key_filename=None, password=None, name=None):
125         """Initialize SSH client.
126
127         :param user: ssh username
128         :param host: hostname or ip address of remote ssh server
129         :param port: remote ssh port
130         :param pkey: RSA or DSS private key string or file object
131         :param key_filename: private key filename
132         :param password: password
133         """
134         self.name = name
135         if name:
136             self.log = logging.getLogger(__name__ + '.' + self.name)
137         else:
138             self.log = logging.getLogger(__name__)
139
140         self.user = user
141         self.host = host
142         # everybody wants to debug this in the caller, do it here instead
143         self.log.debug("user:%s host:%s", user, host)
144
145         # we may get text port from YAML, convert to int
146         self.port = try_int(port, self.SSH_PORT)
147         self.pkey = self._get_pkey(pkey) if pkey else None
148         self.password = password
149         self.key_filename = key_filename
150         self._client = False
151         # paramiko loglevel debug will output ssh protocl debug
152         # we don't ever really want that unless we are debugging paramiko
153         # ssh issues
154         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
155             logging.getLogger("paramiko").setLevel(logging.DEBUG)
156         else:
157             logging.getLogger("paramiko").setLevel(logging.WARN)
158
159     @classmethod
160     def args_from_node(cls, node, overrides=None, defaults=None):
161         if overrides is None:
162             overrides = {}
163         if defaults is None:
164             defaults = {}
165         params = ChainMap(overrides, node, defaults)
166         return {
167             'user': params['user'],
168             'host': params['ip'],
169             'port': params.get('ssh_port', cls.SSH_PORT),
170             'pkey': params.get('pkey'),
171             'key_filename': params.get('key_filename'),
172             'password': params.get('password'),
173             'name': params.get('name'),
174         }
175
176     @classmethod
177     def from_node(cls, node, overrides=None, defaults=None):
178         return cls(**cls.args_from_node(node, overrides, defaults))
179
180     def _get_pkey(self, key):
181         if isinstance(key, six.string_types):
182             key = six.moves.StringIO(key)
183         errors = []
184         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
185             try:
186                 return key_class.from_private_key(key)
187             except paramiko.SSHException as e:
188                 errors.append(e)
189         raise SSHError("Invalid pkey: %s" % (errors))
190
191     @property
192     def is_connected(self):
193         return bool(self._client)
194
195     def _get_client(self):
196         if self.is_connected:
197             return self._client
198         try:
199             self._client = paramiko.SSHClient()
200             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
201             self._client.connect(self.host, username=self.user,
202                                  port=self.port, pkey=self.pkey,
203                                  key_filename=self.key_filename,
204                                  password=self.password,
205                                  allow_agent=False, look_for_keys=False,
206                                  timeout=1)
207             return self._client
208         except Exception as e:
209             message = ("Exception %(exception_type)s was raised "
210                        "during connect. Exception value is: %(exception)r")
211             self._client = False
212             raise SSHError(message % {"exception": e,
213                                       "exception_type": type(e)})
214
215     def _make_dict(self):
216         return {
217             'user': self.user,
218             'host': self.host,
219             'port': self.port,
220             'pkey': self.pkey,
221             'key_filename': self.key_filename,
222             'password': self.password,
223             'name': self.name,
224         }
225
226     def copy(self):
227         return self.get_class()(**self._make_dict())
228
229     def close(self):
230         if self._client:
231             self._client.close()
232             self._client = False
233
234     def run(self, cmd, stdin=None, stdout=None, stderr=None,
235             raise_on_error=True, timeout=3600,
236             keep_stdin_open=False, pty=False):
237         """Execute specified command on the server.
238
239         :param cmd:             Command to be executed.
240         :type cmd:              str
241         :param stdin:           Open file or string to pass to stdin.
242         :param stdout:          Open file to connect to stdout.
243         :param stderr:          Open file to connect to stderr.
244         :param raise_on_error:  If False then exit code will be return. If True
245                                 then exception will be raized if non-zero code.
246         :param timeout:         Timeout in seconds for command execution.
247                                 Default 1 hour. No timeout if set to 0.
248         :param keep_stdin_open: don't close stdin on empty reads
249         :type keep_stdin_open:  bool
250         :param pty:             Request a pseudo terminal for this connection.
251                                 This allows passing control characters.
252                                 Default False.
253         :type pty:              bool
254         """
255
256         client = self._get_client()
257
258         if isinstance(stdin, six.string_types):
259             stdin = six.moves.StringIO(stdin)
260
261         return self._run(client, cmd, stdin=stdin, stdout=stdout,
262                          stderr=stderr, raise_on_error=raise_on_error,
263                          timeout=timeout,
264                          keep_stdin_open=keep_stdin_open, pty=pty)
265
266     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
267              raise_on_error=True, timeout=3600,
268              keep_stdin_open=False, pty=False):
269
270         transport = client.get_transport()
271         session = transport.open_session()
272         if pty:
273             session.get_pty()
274         session.exec_command(cmd)
275         start_time = time.time()
276
277         # encode on transmit, decode on receive
278         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
279         stderr_data = None
280
281         # If we have data to be sent to stdin then `select' should also
282         # check for stdin availability.
283         if stdin and not stdin.closed:
284             writes = [session]
285         else:
286             writes = []
287
288         while True:
289             # Block until data can be read/write.
290             r, w, e = select.select([session], writes, [session], 1)
291
292             if session.recv_ready():
293                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
294                 self.log.debug("stdout: %r", data)
295                 if stdout is not None:
296                     stdout.write(data)
297                 continue
298
299             if session.recv_stderr_ready():
300                 stderr_data = encodeutils.safe_decode(
301                     session.recv_stderr(4096), 'utf-8')
302                 self.log.debug("stderr: %r", stderr_data)
303                 if stderr is not None:
304                     stderr.write(stderr_data)
305                 continue
306
307             if session.send_ready():
308                 if stdin is not None and not stdin.closed:
309                     if not data_to_send:
310                         stdin_txt = stdin.read(4096)
311                         if stdin_txt is None:
312                             stdin_txt = ''
313                         data_to_send = encodeutils.safe_encode(
314                             stdin_txt, incoming='utf-8')
315                         if not data_to_send:
316                             # we may need to keep stdin open
317                             if not keep_stdin_open:
318                                 stdin.close()
319                                 session.shutdown_write()
320                                 writes = []
321                     if data_to_send:
322                         sent_bytes = session.send(data_to_send)
323                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
324                         data_to_send = data_to_send[sent_bytes:]
325
326             if session.exit_status_ready():
327                 break
328
329             if timeout and (time.time() - timeout) > start_time:
330                 args = {"cmd": cmd, "host": self.host}
331                 raise SSHTimeout("Timeout executing command "
332                                  "'%(cmd)s' on host %(host)s" % args)
333             if e:
334                 raise SSHError("Socket error.")
335
336         exit_status = session.recv_exit_status()
337         if exit_status != 0 and raise_on_error:
338             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
339             details = fmt % {"cmd": cmd, "status": exit_status}
340             if stderr_data:
341                 details += " Last stderr data: '%s'." % stderr_data
342             raise SSHError(details)
343         return exit_status
344
345     def execute(self, cmd, stdin=None, timeout=3600):
346         """Execute the specified command on the server.
347
348         :param cmd:     Command to be executed.
349         :param stdin:   Open file to be sent on process stdin.
350         :param timeout: Timeout for execution of the command.
351
352         :returns: tuple (exit_status, stdout, stderr)
353         """
354         stdout = six.moves.StringIO()
355         stderr = six.moves.StringIO()
356
357         exit_status = self.run(cmd, stderr=stderr,
358                                stdout=stdout, stdin=stdin,
359                                timeout=timeout, raise_on_error=False)
360         stdout.seek(0)
361         stderr.seek(0)
362         return exit_status, stdout.read(), stderr.read()
363
364     def wait(self, timeout=120, interval=1):
365         """Wait for the host will be available via ssh."""
366         start_time = time.time()
367         while True:
368             try:
369                 return self.execute("uname")
370             except (socket.error, SSHError) as e:
371                 self.log.debug("Ssh is still unavailable: %r", e)
372                 time.sleep(interval)
373             if time.time() > (start_time + timeout):
374                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
375
376     def put(self, files, remote_path=b'.', recursive=False):
377         client = self._get_client()
378
379         with SCPClient(client.get_transport()) as scp:
380             scp.put(files, remote_path, recursive)
381
382     def get(self, remote_path, local_path='/tmp/', recursive=True):
383         client = self._get_client()
384
385         with SCPClient(client.get_transport()) as scp:
386             scp.get(remote_path, local_path, recursive)
387
388     # keep shell running in the background, e.g. screen
389     def send_command(self, command):
390         client = self._get_client()
391         client.exec_command(command, get_pty=True)
392
393     def _put_file_sftp(self, localpath, remotepath, mode=None):
394         client = self._get_client()
395
396         with client.open_sftp() as sftp:
397             sftp.put(localpath, remotepath)
398             if mode is None:
399                 mode = 0o777 & os.stat(localpath).st_mode
400             sftp.chmod(remotepath, mode)
401
402     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
403
404     def _put_file_shell(self, localpath, remotepath, mode=None):
405         # quote to stop wordpslit
406         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
407         if not tilde:
408             tilde = ''
409         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
410         if mode is not None:
411             # use -- so no options
412             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
413
414         with open(localpath, "rb") as localfile:
415             # only chmod on successful cat
416             self.run("&& ".join(cmd), stdin=localfile)
417
418     def put_file(self, localpath, remotepath, mode=None):
419         """Copy specified local file to the server.
420
421         :param localpath:   Local filename.
422         :param remotepath:  Remote filename.
423         :param mode:        Permissions to set after upload
424         """
425         try:
426             self._put_file_sftp(localpath, remotepath, mode=mode)
427         except (paramiko.SSHException, socket.error):
428             self._put_file_shell(localpath, remotepath, mode=mode)
429
430     def provision_tool(self, tool_path, tool_file=None):
431         return provision_tool(self, tool_path, tool_file)
432
433     def put_file_obj(self, file_obj, remotepath, mode=None):
434         client = self._get_client()
435
436         with client.open_sftp() as sftp:
437             sftp.putfo(file_obj, remotepath)
438             if mode is not None:
439                 sftp.chmod(remotepath, mode)
440
441     def get_file_obj(self, remotepath, file_obj):
442         client = self._get_client()
443
444         with client.open_sftp() as sftp:
445             sftp.getfo(remotepath, file_obj)
446
447
448 class AutoConnectSSH(SSH):
449
450     # always wait or we will get OpenStack SSH errors
451     def __init__(self, user, host, port=None, pkey=None,
452                  key_filename=None, password=None, name=None, wait=True):
453         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
454         self._wait = wait
455
456     def _make_dict(self):
457         data = super(AutoConnectSSH, self)._make_dict()
458         data.update({
459             'wait': self._wait
460         })
461         return data
462
463     def _connect(self):
464         if not self.is_connected:
465             self._get_client()
466             if self._wait:
467                 self.wait()
468
469     def drop_connection(self):
470         """ Don't close anything, just force creation of a new client """
471         self._client = False
472
473     def execute(self, cmd, stdin=None, timeout=3600):
474         self._connect()
475         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
476
477     def run(self, cmd, stdin=None, stdout=None, stderr=None,
478             raise_on_error=True, timeout=3600,
479             keep_stdin_open=False, pty=False):
480         self._connect()
481         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
482                                                timeout, keep_stdin_open, pty)
483
484     def put(self, files, remote_path=b'.', recursive=False):
485         self._connect()
486         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
487
488     def put_file(self, local_path, remote_path, mode=None):
489         self._connect()
490         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
491
492     def put_file_obj(self, file_obj, remote_path, mode=None):
493         self._connect()
494         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
495
496     def get_file_obj(self, remote_path, file_obj):
497         self._connect()
498         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
499
500     def provision_tool(self, tool_path, tool_file=None):
501         self._connect()
502         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
503
504     @staticmethod
505     def get_class():
506         # must return static class name, anything else refers to the calling class
507         # i.e. the subclass, not the superclass
508         return AutoConnectSSH