927ca94dbb2d045516a1267af5039024e07ada93
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print stdout.splitlines()
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import os
66 import select
67 import socket
68 import time
69 import re
70
71 import logging
72 import paramiko
73 from scp import SCPClient
74 import six
75
76
77 DEFAULT_PORT = 22
78
79
80 class SSHError(Exception):
81     pass
82
83
84 class SSHTimeout(SSHError):
85     pass
86
87
88 class SSH(object):
89     """Represent ssh connection."""
90
91     def __init__(self, user, host, port=DEFAULT_PORT, pkey=None,
92                  key_filename=None, password=None, name=None):
93         """Initialize SSH client.
94
95         :param user: ssh username
96         :param host: hostname or ip address of remote ssh server
97         :param port: remote ssh port
98         :param pkey: RSA or DSS private key string or file object
99         :param key_filename: private key filename
100         :param password: password
101         """
102         self.name = name
103         if name:
104             self.log = logging.getLogger(__name__ + '.' + self.name)
105         else:
106             self.log = logging.getLogger(__name__)
107
108         self.user = user
109         self.host = host
110         # we may get text port from YAML, convert to int
111         self.port = int(port)
112         self.pkey = self._get_pkey(pkey) if pkey else None
113         self.password = password
114         self.key_filename = key_filename
115         self._client = False
116         # paramiko loglevel debug will output ssh protocl debug
117         # we don't ever really want that unless we are debugging paramiko
118         # ssh issues
119         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
120             logging.getLogger("paramiko").setLevel(logging.DEBUG)
121         else:
122             logging.getLogger("paramiko").setLevel(logging.WARN)
123
124     def _get_pkey(self, key):
125         if isinstance(key, six.string_types):
126             key = six.moves.StringIO(key)
127         errors = []
128         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
129             try:
130                 return key_class.from_private_key(key)
131             except paramiko.SSHException as e:
132                 errors.append(e)
133         raise SSHError("Invalid pkey: %s" % (errors))
134
135     def _get_client(self):
136         if self._client:
137             return self._client
138         try:
139             self._client = paramiko.SSHClient()
140             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
141             self._client.connect(self.host, username=self.user,
142                                  port=self.port, pkey=self.pkey,
143                                  key_filename=self.key_filename,
144                                  password=self.password,
145                                  allow_agent=False, look_for_keys=False,
146                                  timeout=1)
147             return self._client
148         except Exception as e:
149             message = ("Exception %(exception_type)s was raised "
150                        "during connect. Exception value is: %(exception)r")
151             self._client = False
152             raise SSHError(message % {"exception": e,
153                                       "exception_type": type(e)})
154
155     def close(self):
156         self._client.close()
157         self._client = False
158
159     def run(self, cmd, stdin=None, stdout=None, stderr=None,
160             raise_on_error=True, timeout=3600,
161             keep_stdin_open=False, pty=False):
162         """Execute specified command on the server.
163
164         :param cmd:             Command to be executed.
165         :type cmd:              str
166         :param stdin:           Open file or string to pass to stdin.
167         :param stdout:          Open file to connect to stdout.
168         :param stderr:          Open file to connect to stderr.
169         :param raise_on_error:  If False then exit code will be return. If True
170                                 then exception will be raized if non-zero code.
171         :param timeout:         Timeout in seconds for command execution.
172                                 Default 1 hour. No timeout if set to 0.
173         :param keep_stdin_open: don't close stdin on empty reads
174         :type keep_stdin_open:  bool
175         :param pty:             Request a pseudo terminal for this connection.
176                                 This allows passing control characters.
177                                 Default False.
178         :type pty:              bool
179         """
180
181         client = self._get_client()
182
183         if isinstance(stdin, six.string_types):
184             stdin = six.moves.StringIO(stdin)
185
186         return self._run(client, cmd, stdin=stdin, stdout=stdout,
187                          stderr=stderr, raise_on_error=raise_on_error,
188                          timeout=timeout,
189                          keep_stdin_open=keep_stdin_open, pty=pty)
190
191     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
192              raise_on_error=True, timeout=3600,
193              keep_stdin_open=False, pty=False):
194
195         transport = client.get_transport()
196         session = transport.open_session()
197         if pty:
198             session.get_pty()
199         session.exec_command(cmd)
200         start_time = time.time()
201
202         data_to_send = ""
203         stderr_data = None
204
205         # If we have data to be sent to stdin then `select' should also
206         # check for stdin availability.
207         if stdin and not stdin.closed:
208             writes = [session]
209         else:
210             writes = []
211
212         while True:
213             # Block until data can be read/write.
214             r, w, e = select.select([session], writes, [session], 1)
215
216             if session.recv_ready():
217                 data = session.recv(4096)
218                 self.log.debug("stdout: %r", data)
219                 if stdout is not None:
220                     stdout.write(data)
221                 continue
222
223             if session.recv_stderr_ready():
224                 stderr_data = session.recv_stderr(4096)
225                 self.log.debug("stderr: %r", stderr_data)
226                 if stderr is not None:
227                     stderr.write(stderr_data)
228                 continue
229
230             if session.send_ready():
231                 if stdin is not None and not stdin.closed:
232                     if not data_to_send:
233                         data_to_send = stdin.read(4096)
234                         if not data_to_send:
235                             # we may need to keep stdin open
236                             if not keep_stdin_open:
237                                 stdin.close()
238                                 session.shutdown_write()
239                                 writes = []
240                     if data_to_send:
241                         sent_bytes = session.send(data_to_send)
242                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
243                         data_to_send = data_to_send[sent_bytes:]
244
245             if session.exit_status_ready():
246                 break
247
248             if timeout and (time.time() - timeout) > start_time:
249                 args = {"cmd": cmd, "host": self.host}
250                 raise SSHTimeout("Timeout executing command "
251                                  "'%(cmd)s' on host %(host)s" % args)
252             if e:
253                 raise SSHError("Socket error.")
254
255         exit_status = session.recv_exit_status()
256         if exit_status != 0 and raise_on_error:
257             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
258             details = fmt % {"cmd": cmd, "status": exit_status}
259             if stderr_data:
260                 details += " Last stderr data: '%s'." % stderr_data
261             raise SSHError(details)
262         return exit_status
263
264     def execute(self, cmd, stdin=None, timeout=3600):
265         """Execute the specified command on the server.
266
267         :param cmd:     Command to be executed.
268         :param stdin:   Open file to be sent on process stdin.
269         :param timeout: Timeout for execution of the command.
270
271         :returns: tuple (exit_status, stdout, stderr)
272         """
273         stdout = six.moves.StringIO()
274         stderr = six.moves.StringIO()
275
276         exit_status = self.run(cmd, stderr=stderr,
277                                stdout=stdout, stdin=stdin,
278                                timeout=timeout, raise_on_error=False)
279         stdout.seek(0)
280         stderr.seek(0)
281         return (exit_status, stdout.read(), stderr.read())
282
283     def wait(self, timeout=120, interval=1):
284         """Wait for the host will be available via ssh."""
285         start_time = time.time()
286         while True:
287             try:
288                 return self.execute("uname")
289             except (socket.error, SSHError) as e:
290                 self.log.debug("Ssh is still unavailable: %r", e)
291                 time.sleep(interval)
292             if time.time() > (start_time + timeout):
293                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
294
295     def put(self, files, remote_path=b'.', recursive=False):
296         client = self._get_client()
297
298         with SCPClient(client.get_transport()) as scp:
299             scp.put(files, remote_path, recursive)
300
301     # keep shell running in the background, e.g. screen
302     def send_command(self, command):
303         client = self._get_client()
304         client.exec_command(command, get_pty=True)
305
306     def _put_file_sftp(self, localpath, remotepath, mode=None):
307         client = self._get_client()
308
309         with client.open_sftp() as sftp:
310             sftp.put(localpath, remotepath)
311             if mode is None:
312                 mode = 0o777 & os.stat(localpath).st_mode
313             sftp.chmod(remotepath, mode)
314
315     TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")
316
317     def _put_file_shell(self, localpath, remotepath, mode=None):
318         # quote to stop wordpslit
319         tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
320         if not tilde:
321             tilde = ''
322         cmd = ['cat > %s"%s"' % (tilde, remotepath)]
323         if mode is not None:
324             # use -- so no options
325             cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))
326
327         with open(localpath, "rb") as localfile:
328             # only chmod on successful cat
329             self.run("&& ".join(cmd), stdin=localfile)
330
331     def put_file(self, localpath, remotepath, mode=None):
332         """Copy specified local file to the server.
333
334         :param localpath:   Local filename.
335         :param remotepath:  Remote filename.
336         :param mode:        Permissions to set after upload
337         """
338         try:
339             self._put_file_sftp(localpath, remotepath, mode=mode)
340         except (paramiko.SSHException, socket.error):
341             self._put_file_shell(localpath, remotepath, mode=mode)