cf9adf0dc6d5c994fbc57301b18e94a3fd459885
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print(stdout.splitlines())
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 from __future__ import absolute_import
66 import os
67 import select
68 import socket
69 import time
70 import re
71
72 import logging
73
74 import paramiko
75 from chainmap import ChainMap
76 from oslo_utils import encodeutils
77 from scp import SCPClient
78 import six
79
80
81 SSH_PORT = paramiko.config.SSH_PORT
82
83
84 class SSHError(Exception):
85     pass
86
87
88 class SSHTimeout(SSHError):
89     pass
90
91
92 class SSH(object):
93     """Represent ssh connection."""
94
95     def __init__(self, user, host, port=SSH_PORT, pkey=None,
96                  key_filename=None, password=None, name=None):
97         """Initialize SSH client.
98
99         :param user: ssh username
100         :param host: hostname or ip address of remote ssh server
101         :param port: remote ssh port
102         :param pkey: RSA or DSS private key string or file object
103         :param key_filename: private key filename
104         :param password: password
105         """
106         self.name = name
107         if name:
108             self.log = logging.getLogger(__name__ + '.' + self.name)
109         else:
110             self.log = logging.getLogger(__name__)
111
112         self.user = user
113         self.host = host
114         # everybody wants to debug this in the caller, do it here instead
115         self.log.debug("user:%s host:%s", user, host)
116
117         # we may get text port from YAML, convert to int
118         self.port = int(port)
119         self.pkey = self._get_pkey(pkey) if pkey else None
120         self.password = password
121         self.key_filename = key_filename
122         self._client = False
123         # paramiko loglevel debug will output ssh protocl debug
124         # we don't ever really want that unless we are debugging paramiko
125         # ssh issues
126         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
127             logging.getLogger("paramiko").setLevel(logging.DEBUG)
128         else:
129             logging.getLogger("paramiko").setLevel(logging.WARN)
130
131     @classmethod
132     def from_node(cls, node, overrides=None, defaults=None):
133         if overrides is None:
134             overrides = {}
135         if defaults is None:
136             defaults = {}
137         params = ChainMap(overrides, node, defaults)
138         return cls(
139             user=params['user'],
140             host=params['ip'],
141             # paramiko doesn't like None default, requires SSH_PORT default
142             port=params.get('ssh_port', SSH_PORT),
143             pkey=params.get('pkey'),
144             key_filename=params.get('key_filename'),
145             password=params.get('password'),
146             name=params.get('name'))
147
148     def _get_pkey(self, key):
149         if isinstance(key, six.string_types):
150             key = six.moves.StringIO(key)
151         errors = []
152         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
153             try:
154                 return key_class.from_private_key(key)
155             except paramiko.SSHException as e:
156                 errors.append(e)
157         raise SSHError("Invalid pkey: %s" % (errors))
158
159     def _get_client(self):
160         if self._client:
161             return self._client
162         try:
163             self._client = paramiko.SSHClient()
164             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
165             self._client.connect(self.host, username=self.user,
166                                  port=self.port, pkey=self.pkey,
167                                  key_filename=self.key_filename,
168                                  password=self.password,
169                                  allow_agent=False, look_for_keys=False,
170                                  timeout=1)
171             return self._client
172         except Exception as e:
173             message = ("Exception %(exception_type)s was raised "
174                        "during connect. Exception value is: %(exception)r")
175             self._client = False
176             raise SSHError(message % {"exception": e,
177                                       "exception_type": type(e)})
178
179     def close(self):
180         self._client.close()
181         self._client = False
182
183     def run(self, cmd, stdin=None, stdout=None, stderr=None,
184             raise_on_error=True, timeout=3600,
185             keep_stdin_open=False, pty=False):
186         """Execute specified command on the server.
187
188         :param cmd:             Command to be executed.
189         :type cmd:              str
190         :param stdin:           Open file or string to pass to stdin.
191         :param stdout:          Open file to connect to stdout.
192         :param stderr:          Open file to connect to stderr.
193         :param raise_on_error:  If False then exit code will be return. If True
194                                 then exception will be raized if non-zero code.
195         :param timeout:         Timeout in seconds for command execution.
196                                 Default 1 hour. No timeout if set to 0.
197         :param keep_stdin_open: don't close stdin on empty reads
198         :type keep_stdin_open:  bool
199         :param pty:             Request a pseudo terminal for this connection.
200                                 This allows passing control characters.
201                                 Default False.
202         :type pty:              bool
203         """
204
205         client = self._get_client()
206
207         if isinstance(stdin, six.string_types):
208             stdin = six.moves.StringIO(stdin)
209
210         return self._run(client, cmd, stdin=stdin, stdout=stdout,
211                          stderr=stderr, raise_on_error=raise_on_error,
212                          timeout=timeout,
213                          keep_stdin_open=keep_stdin_open, pty=pty)
214
215     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
216              raise_on_error=True, timeout=3600,
217              keep_stdin_open=False, pty=False):
218
219         transport = client.get_transport()
220         session = transport.open_session()
221         if pty:
222             session.get_pty()
223         session.exec_command(cmd)
224         start_time = time.time()
225
226         # encode on transmit, decode on receive
227         data_to_send = encodeutils.safe_encode("", incoming='utf-8')
228         stderr_data = None
229
230         # If we have data to be sent to stdin then `select' should also
231         # check for stdin availability.
232         if stdin and not stdin.closed:
233             writes = [session]
234         else:
235             writes = []
236
237         while True:
238             # Block until data can be read/write.
239             r, w, e = select.select([session], writes, [session], 1)
240
241             if session.recv_ready():
242                 data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
243                 self.log.debug("stdout: %r", data)
244                 if stdout is not None:
245                     stdout.write(data)
246                 continue
247
248             if session.recv_stderr_ready():
249                 stderr_data = encodeutils.safe_decode(
250                     session.recv_stderr(4096), 'utf-8')
251                 self.log.debug("stderr: %r", stderr_data)
252                 if stderr is not None:
253                     stderr.write(stderr_data)
254                 continue
255
256             if session.send_ready():
257                 if stdin is not None and not stdin.closed:
258                     if not data_to_send:
259                         stdin_txt = stdin.read(4096)
260                         if stdin_txt is None:
261                             stdin_txt = ''
262                         data_to_send = encodeutils.safe_encode(
263                             stdin_txt, incoming='utf-8')
264                         if not data_to_send:
265                             # we may need to keep stdin open
266                             if not keep_stdin_open:
267                                 stdin.close()
268                                 session.shutdown_write()
269                                 writes = []
270                     if data_to_send:
271                         sent_bytes = session.send(data_to_send)
272                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
273                         data_to_send = data_to_send[sent_bytes:]
274
275             if session.exit_status_ready():
276                 break
277
278             if timeout and (time.time() - timeout) > start_time:
279                 args = {"cmd": cmd, "host": self.host}
280                 raise SSHTimeout("Timeout executing command "
281                                  "'%(cmd)s' on host %(host)s" % args)
282             if e:
283                 raise SSHError("Socket error.")
284
285         exit_status = session.recv_exit_status()
286         if exit_status != 0 and raise_on_error:
287             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
288             details = fmt % {"cmd": cmd, "status": exit_status}
289             if stderr_data:
290                 details += " Last stderr data: '%s'." % stderr_data
291             raise SSHError(details)
292         return exit_status
293
294     def execute(self, cmd, stdin=None, timeout=3600):
295         """Execute the specified command on the server.
296
297         :param cmd:     Command to be executed.
298         :param stdin:   Open file to be sent on process stdin.
299         :param timeout: Timeout for execution of the command.
300
301         :returns: tuple (exit_status, stdout, stderr)
302         """
303         stdout = six.moves.StringIO()
304         stderr = six.moves.StringIO()
305
306         exit_status = self.run(cmd, stderr=stderr,
307                                stdout=stdout, stdin=stdin,
308                                timeout=timeout, raise_on_error=False)
309         stdout.seek(0)
310         stderr.seek(0)
311         return (exit_status, stdout.read(), stderr.read())
312
313     def wait(self, timeout=120, interval=1):
314         """Wait for the host will be available via ssh."""
315         start_time = time.time()
316         while True:
317             try:
318                 return self.execute("uname")
319             except (socket.error, SSHError) as e:
320                 self.log.debug("Ssh is still unavailable: %r", e)
321                 time.sleep(interval)
322             if time.time() > (start_time + timeout):
323                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
324
325     def put(self, files, remote_path=b'.', recursive=False):
326         client = self._get_client()
327
328         with SCPClient(client.get_transport()) as scp:
329             scp.put(files, remote_path, recursive)
330
331     # keep shell running in the background, e.g. screen
332     def send_command(self, command):
333         client = self._get_client()
334         client.exec_command(command, get_pty=True)
335
336     def _put_file_sftp(self, localpath, remotepath, mode=None):
337         client = self._get_client()
338
339         with client.open_sftp() as sftp:
340             sftp.put(localpath, remotepath)
341             if mode is None:
342                 mode = 0o777 & os.stat(localpath).st_mode
343             sftp.chmod(remotepath, mode)
344
345     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
346
347     def _put_file_shell(self, localpath, remotepath, mode=None):
348         # quote to stop wordpslit
349         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
350         if not tilde:
351             tilde = ''
352         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
353         if mode is not None:
354             # use -- so no options
355             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
356
357         with open(localpath, "rb") as localfile:
358             # only chmod on successful cat
359             self.run("&& ".join(cmd), stdin=localfile)
360
361     def put_file(self, localpath, remotepath, mode=None):
362         """Copy specified local file to the server.
363
364         :param localpath:   Local filename.
365         :param remotepath:  Remote filename.
366         :param mode:        Permissions to set after upload
367         """
368         try:
369             self._put_file_sftp(localpath, remotepath, mode=mode)
370         except (paramiko.SSHException, socket.error):
371             self._put_file_shell(localpath, remotepath, mode=mode)