Merge "Added required ubuntu packages to run IxLoad client"
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 from __future__ import absolute_import
66 import os
67 import io
68 import select
69 import socket
70 import time
71 import re
72
73 import logging
74
75 import paramiko
76 from chainmap import ChainMap
77 from oslo_utils import encodeutils
78 from scp import SCPClient
79 import six
80
81 from yardstick.common.utils import try_int
82 from yardstick.network_services.utils import provision_tool
83
84
85 def convert_key_to_str(key):
86     if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
87         return key
88     k = io.StringIO()
89     key.write_private_key(k)
90     return k.getvalue()
91
92
93 class SSHError(Exception):
94     pass
95
96
97 class SSHTimeout(SSHError):
98     pass
99
100
101 class SSH(object):
102     """Represent ssh connection."""
103
104     SSH_PORT = paramiko.config.SSH_PORT
105
106     @staticmethod
107     def gen_keys(key_filename, bit_count=2048):
108         rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
109         rsa_key.write_private_key_file(key_filename)
110         print("Writing %s ..." % key_filename)
111         with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
112             pubkey_file.write(rsa_key.get_name())
113             pubkey_file.write(' ')
114             pubkey_file.write(rsa_key.get_base64())
115             pubkey_file.write('\n')
116
117     @staticmethod
118     def get_class():
119         # must return static class name, anything else refers to the calling class
120         # i.e. the subclass, not the superclass
121         return SSH
122
123     def __init__(self, user, host, port=None, pkey=None,
124                  key_filename=None, password=None, name=None):
125         """Initialize SSH client.
126
127         :param user: ssh username
128         :param host: hostname or ip address of remote ssh server
129         :param port: remote ssh port
130         :param pkey: RSA or DSS private key string or file object
131         :param key_filename: private key filename
132         :param password: password
133         """
134         self.name = name
135         if name:
136             self.log = logging.getLogger(__name__ + '.' + self.name)
137         else:
138             self.log = logging.getLogger(__name__)
139
140         self.user = user
141         self.host = host
142         # everybody wants to debug this in the caller, do it here instead
143         self.log.debug("user:%s host:%s", user, host)
144
145         # we may get text port from YAML, convert to int
146         self.port = try_int(port, self.SSH_PORT)
147         self.pkey = self._get_pkey(pkey) if pkey else None
148         self.password = password
149         self.key_filename = key_filename
150         self._client = False
151         # paramiko loglevel debug will output ssh protocl debug
152         # we don't ever really want that unless we are debugging paramiko
153         # ssh issues
154         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
155             logging.getLogger("paramiko").setLevel(logging.DEBUG)
156         else:
157             logging.getLogger("paramiko").setLevel(logging.WARN)
158
159     @classmethod
160     def args_from_node(cls, node, overrides=None, defaults=None):
161         if overrides is None:
162             overrides = {}
163         if defaults is None:
164             defaults = {}
165         params = ChainMap(overrides, node, defaults)
166         return {
167             'user': params['user'],
168             'host': params['ip'],
169             'port': params.get('ssh_port', cls.SSH_PORT),
170             'pkey': params.get('pkey'),
171             'key_filename': params.get('key_filename'),
172             'password': params.get('password'),
173             'name': params.get('name'),
174         }
175
176     @classmethod
177     def from_node(cls, node, overrides=None, defaults=None):
178         return cls(**cls.args_from_node(node, overrides, defaults))
179
180     def _get_pkey(self, key):
181         if isinstance(key, six.string_types):
182             key = six.moves.StringIO(key)
183         errors = []
184         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
185             try:
186                 return key_class.from_private_key(key)
187             except paramiko.SSHException as e:
188                 errors.append(e)
189         raise SSHError("Invalid pkey: %s" % (errors))
190
191     @property
192     def is_connected(self):
193         return bool(self._client)
194
195     def _get_client(self):
196         if self.is_connected:
197             return self._client
198         try:
199             self._client = paramiko.SSHClient()
200             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
201             self._client.connect(self.host, username=self.user,
202                                  port=self.port, pkey=self.pkey,
203                                  key_filename=self.key_filename,
204                                  password=self.password,
205                                  allow_agent=False, look_for_keys=False,
206                                  timeout=1)
207             return self._client
208         except Exception as e:
209             message = ("Exception %(exception_type)s was raised "
210                        "during connect. Exception value is: %(exception)r")
211             self._client = False
212             raise SSHError(message % {"exception": e,
213                                       "exception_type": type(e)})
214
215     def _make_dict(self):
216         return {
217             'user': self.user,
218             'host': self.host,
219             'port': self.port,
220             'pkey': self.pkey,
221             'key_filename': self.key_filename,
222             'password': self.password,
223             'name': self.name,
224         }
225
226     def copy(self):
227         return self.get_class()(**self._make_dict())
228
229     def close(self):
230         if self._client:
231             self._client.close()
232             self._client = False
233
234     def run(self, cmd, stdin=None, stdout=None, stderr=None,
235             raise_on_error=True, timeout=3600,
236             keep_stdin_open=False, pty=False):
237         """Execute specified command on the server.
238
239         :param cmd:             Command to be executed.
240         :type cmd:              str
241         :param stdin:           Open file or string to pass to stdin.
242         :param stdout:          Open file to connect to stdout.
243         :param stderr:          Open file to connect to stderr.
244         :param raise_on_error:  If False then exit code will be return. If True
245                                 then exception will be raized if non-zero code.
246         :param timeout:         Timeout in seconds for command execution.
247                                 Default 1 hour. No timeout if set to 0.
248         :param keep_stdin_open: don't close stdin on empty reads
249         :type keep_stdin_open:  bool
250         :param pty:             Request a pseudo terminal for this connection.
251                                 This allows passing control characters.
252                                 Default False.
253         :type pty:              bool
254         """
255
256         client = self._get_client()
257
258         if isinstance(stdin, six.string_types):
259             stdin = six.moves.StringIO(stdin)
260
261         return self._run(client, cmd, stdin=stdin, stdout=stdout,
262                          stderr=stderr, raise_on_error=raise_on_error,
263                          timeout=timeout,
264                          keep_stdin_open=keep_stdin_open, pty=pty)
265
266     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
267              raise_on_error=True, timeout=3600,
268              keep_stdin_open=False, pty=False):
269
270         transport = client.get_transport()
271         session = transport.open_session()
272         if pty:
273             session.get_pty()
274         session.exec_command(cmd)
275         start_time = time.time()
276
277         # encode on transmit, decode on receive
278         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
279         stderr_data = None
280
281         # If we have data to be sent to stdin then `select' should also
282         # check for stdin availability.
283         if stdin and not stdin.closed:
284             writes = [session]
285         else:
286             writes = []
287
288         while True:
289             # Block until data can be read/write.
290             r, w, e = select.select([session], writes, [session], 1)
291
292             if session.recv_ready():
293                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
294                 self.log.debug("stdout: %r", data)
295                 if stdout is not None:
296                     stdout.write(data)
297                 continue
298
299             if session.recv_stderr_ready():
300                 stderr_data = encodeutils.safe_decode(
301                     session.recv_stderr(4096), 'utf-8')
302                 self.log.debug("stderr: %r", stderr_data)
303                 if stderr is not None:
304                     stderr.write(stderr_data)
305                 continue
306
307             if session.send_ready():
308                 if stdin is not None and not stdin.closed:
309                     if not data_to_send:
310                         stdin_txt = stdin.read(4096)
311                         if stdin_txt is None:
312                             stdin_txt = ''
313                         data_to_send = encodeutils.safe_encode(
314                             stdin_txt, incoming='utf-8')
315                         if not data_to_send:
316                             # we may need to keep stdin open
317                             if not keep_stdin_open:
318                                 stdin.close()
319                                 session.shutdown_write()
320                                 writes = []
321                     if data_to_send:
322                         sent_bytes = session.send(data_to_send)
323                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
324                         data_to_send = data_to_send[sent_bytes:]
325
326             if session.exit_status_ready():
327                 break
328
329             if timeout and (time.time() - timeout) > start_time:
330                 args = {"cmd": cmd, "host": self.host}
331                 raise SSHTimeout("Timeout executing command "
332                                  "'%(cmd)s' on host %(host)s" % args)
333             if e:
334                 raise SSHError("Socket error.")
335
336         exit_status = session.recv_exit_status()
337         if exit_status != 0 and raise_on_error:
338             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
339             details = fmt % {"cmd": cmd, "status": exit_status}
340             if stderr_data:
341                 details += " Last stderr data: '%s'." % stderr_data
342             raise SSHError(details)
343         return exit_status
344
345     def execute(self, cmd, stdin=None, timeout=3600):
346         """Execute the specified command on the server.
347
348         :param cmd:     Command to be executed.
349         :param stdin:   Open file to be sent on process stdin.
350         :param timeout: Timeout for execution of the command.
351
352         :returns: tuple (exit_status, stdout, stderr)
353         """
354         stdout = six.moves.StringIO()
355         stderr = six.moves.StringIO()
356
357         exit_status = self.run(cmd, stderr=stderr,
358                                stdout=stdout, stdin=stdin,
359                                timeout=timeout, raise_on_error=False)
360         stdout.seek(0)
361         stderr.seek(0)
362         return exit_status, stdout.read(), stderr.read()
363
364     def wait(self, timeout=120, interval=1):
365         """Wait for the host will be available via ssh."""
366         start_time = time.time()
367         while True:
368             try:
369                 return self.execute("uname")
370             except (socket.error, SSHError) as e:
371                 self.log.debug("Ssh is still unavailable: %r", e)
372                 time.sleep(interval)
373             if time.time() > (start_time + timeout):
374                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
375
376     def put(self, files, remote_path=b'.', recursive=False):
377         client = self._get_client()
378
379         with SCPClient(client.get_transport()) as scp:
380             scp.put(files, remote_path, recursive)
381
382     # keep shell running in the background, e.g. screen
383     def send_command(self, command):
384         client = self._get_client()
385         client.exec_command(command, get_pty=True)
386
387     def _put_file_sftp(self, localpath, remotepath, mode=None):
388         client = self._get_client()
389
390         with client.open_sftp() as sftp:
391             sftp.put(localpath, remotepath)
392             if mode is None:
393                 mode = 0o777 & os.stat(localpath).st_mode
394             sftp.chmod(remotepath, mode)
395
396     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
397
398     def _put_file_shell(self, localpath, remotepath, mode=None):
399         # quote to stop wordpslit
400         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
401         if not tilde:
402             tilde = ''
403         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
404         if mode is not None:
405             # use -- so no options
406             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
407
408         with open(localpath, "rb") as localfile:
409             # only chmod on successful cat
410             self.run("&& ".join(cmd), stdin=localfile)
411
412     def put_file(self, localpath, remotepath, mode=None):
413         """Copy specified local file to the server.
414
415         :param localpath:   Local filename.
416         :param remotepath:  Remote filename.
417         :param mode:        Permissions to set after upload
418         """
419         try:
420             self._put_file_sftp(localpath, remotepath, mode=mode)
421         except (paramiko.SSHException, socket.error):
422             self._put_file_shell(localpath, remotepath, mode=mode)
423
424     def provision_tool(self, tool_path, tool_file=None):
425         return provision_tool(self, tool_path, tool_file)
426
427     def put_file_obj(self, file_obj, remotepath, mode=None):
428         client = self._get_client()
429
430         with client.open_sftp() as sftp:
431             sftp.putfo(file_obj, remotepath)
432             if mode is not None:
433                 sftp.chmod(remotepath, mode)
434
435     def get_file_obj(self, remotepath, file_obj):
436         client = self._get_client()
437
438         with client.open_sftp() as sftp:
439             sftp.getfo(remotepath, file_obj)
440
441
442 class AutoConnectSSH(SSH):
443
444     # always wait or we will get OpenStack SSH errors
445     def __init__(self, user, host, port=None, pkey=None,
446                  key_filename=None, password=None, name=None, wait=True):
447         super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
448         self._wait = wait
449
450     def _make_dict(self):
451         data = super(AutoConnectSSH, self)._make_dict()
452         data.update({
453             'wait': self._wait
454         })
455         return data
456
457     def _connect(self):
458         if not self.is_connected:
459             self._get_client()
460             if self._wait:
461                 self.wait()
462
463     def drop_connection(self):
464         """ Don't close anything, just force creation of a new client """
465         self._client = False
466
467     def execute(self, cmd, stdin=None, timeout=3600):
468         self._connect()
469         return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
470
471     def run(self, cmd, stdin=None, stdout=None, stderr=None,
472             raise_on_error=True, timeout=3600,
473             keep_stdin_open=False, pty=False):
474         self._connect()
475         return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
476                                                timeout, keep_stdin_open, pty)
477
478     def put(self, files, remote_path=b'.', recursive=False):
479         self._connect()
480         return super(AutoConnectSSH, self).put(files, remote_path, recursive)
481
482     def put_file(self, local_path, remote_path, mode=None):
483         self._connect()
484         return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
485
486     def put_file_obj(self, file_obj, remote_path, mode=None):
487         self._connect()
488         return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
489
490     def get_file_obj(self, remote_path, file_obj):
491         self._connect()
492         return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)
493
494     def provision_tool(self, tool_path, tool_file=None):
495         self._connect()
496         return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
497
498     @staticmethod
499     def get_class():
500         # must return static class name, anything else refers to the calling class
501         # i.e. the subclass, not the superclass
502         return AutoConnectSSH