Merge "ssh.py: add flag to request for a pseudo terminal (pty) for ssh connection"
[yardstick.git] / yardstick / ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this is a modified copy of rally/rally/common/sshutils.py
17
18 """High level ssh library.
19
20 Usage examples:
21
22 Execute command and get output:
23
24     ssh = sshclient.SSH("root", "example.com", port=33)
25     status, stdout, stderr = ssh.execute("ps ax")
26     if status:
27         raise Exception("Command failed with non-zero status.")
28     print stdout.splitlines()
29
30 Execute command with huge output:
31
32     class PseudoFile(io.RawIOBase):
33         def write(chunk):
34             if "error" in chunk:
35                 email_admin(chunk)
36
37     ssh = SSH("root", "example.com")
38     with PseudoFile() as p:
39         ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)
40
41 Execute local script on remote side:
42
43     ssh = sshclient.SSH("user", "example.com")
44
45     with open("~/myscript.sh", "r") as stdin_file:
46         status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
47                                        stdin=stdin_file)
48
49 Upload file:
50
51     ssh = SSH("user", "example.com")
52     # use rb for binary files
53     with open("/store/file.gz", "rb") as stdin_file:
54         ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)
55
56 Eventlet:
57
58     eventlet.monkey_patch(select=True, time=True)
59     or
60     eventlet.monkey_patch()
61     or
62     sshclient = eventlet.import_patched("yardstick.ssh")
63
64 """
65 import os
66 import select
67 import socket
68 import time
69
70 import logging
71 import paramiko
72 from scp import SCPClient
73 import six
74
75
76 DEFAULT_PORT = 22
77
78
79 class SSHError(Exception):
80     pass
81
82
83 class SSHTimeout(SSHError):
84     pass
85
86
87 class SSH(object):
88     """Represent ssh connection."""
89
90     def __init__(self, user, host, port=DEFAULT_PORT, pkey=None,
91                  key_filename=None, password=None, name=None):
92         """Initialize SSH client.
93
94         :param user: ssh username
95         :param host: hostname or ip address of remote ssh server
96         :param port: remote ssh port
97         :param pkey: RSA or DSS private key string or file object
98         :param key_filename: private key filename
99         :param password: password
100         """
101         self.name = name
102         if name:
103             self.log = logging.getLogger(__name__ + '.' + self.name)
104         else:
105             self.log = logging.getLogger(__name__)
106
107         self.user = user
108         self.host = host
109         # we may get text port from YAML, convert to int
110         self.port = int(port)
111         self.pkey = self._get_pkey(pkey) if pkey else None
112         self.password = password
113         self.key_filename = key_filename
114         self._client = False
115         # paramiko loglevel debug will output ssh protocl debug
116         # we don't ever really want that unless we are debugging paramiko
117         # ssh issues
118         if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
119             logging.getLogger("paramiko").setLevel(logging.DEBUG)
120         else:
121             logging.getLogger("paramiko").setLevel(logging.WARN)
122
123     def _get_pkey(self, key):
124         if isinstance(key, six.string_types):
125             key = six.moves.StringIO(key)
126         errors = []
127         for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
128             try:
129                 return key_class.from_private_key(key)
130             except paramiko.SSHException as e:
131                 errors.append(e)
132         raise SSHError("Invalid pkey: %s" % (errors))
133
134     def _get_client(self):
135         if self._client:
136             return self._client
137         try:
138             self._client = paramiko.SSHClient()
139             self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
140             self._client.connect(self.host, username=self.user,
141                                  port=self.port, pkey=self.pkey,
142                                  key_filename=self.key_filename,
143                                  password=self.password,
144                                  allow_agent=False, look_for_keys=False,
145                                  timeout=1)
146             return self._client
147         except Exception as e:
148             message = ("Exception %(exception_type)s was raised "
149                        "during connect. Exception value is: %(exception)r")
150             self._client = False
151             raise SSHError(message % {"exception": e,
152                                       "exception_type": type(e)})
153
154     def close(self):
155         self._client.close()
156         self._client = False
157
158     def run(self, cmd, stdin=None, stdout=None, stderr=None,
159             raise_on_error=True, timeout=3600,
160             keep_stdin_open=False, pty=False):
161         """Execute specified command on the server.
162
163         :param cmd:             Command to be executed.
164         :type cmd:              str
165         :param stdin:           Open file or string to pass to stdin.
166         :param stdout:          Open file to connect to stdout.
167         :param stderr:          Open file to connect to stderr.
168         :param raise_on_error:  If False then exit code will be return. If True
169                                 then exception will be raized if non-zero code.
170         :param timeout:         Timeout in seconds for command execution.
171                                 Default 1 hour. No timeout if set to 0.
172         :param keep_stdin_open: don't close stdin on empty reads
173         :type keep_stdin_open:  bool
174         :param pty:             Request a pseudo terminal for this connection.
175                                 This allows passing control characters.
176                                 Default False.
177         :type pty:              bool
178         """
179
180         client = self._get_client()
181
182         if isinstance(stdin, six.string_types):
183             stdin = six.moves.StringIO(stdin)
184
185         return self._run(client, cmd, stdin=stdin, stdout=stdout,
186                          stderr=stderr, raise_on_error=raise_on_error,
187                          timeout=timeout,
188                          keep_stdin_open=keep_stdin_open, pty=pty)
189
190     def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
191              raise_on_error=True, timeout=3600,
192              keep_stdin_open=False, pty=False):
193
194         transport = client.get_transport()
195         session = transport.open_session()
196         if pty:
197             session.get_pty()
198         session.exec_command(cmd)
199         start_time = time.time()
200
201         data_to_send = ""
202         stderr_data = None
203
204         # If we have data to be sent to stdin then `select' should also
205         # check for stdin availability.
206         if stdin and not stdin.closed:
207             writes = [session]
208         else:
209             writes = []
210
211         while True:
212             # Block until data can be read/write.
213             r, w, e = select.select([session], writes, [session], 1)
214
215             if session.recv_ready():
216                 data = session.recv(4096)
217                 self.log.debug("stdout: %r", data)
218                 if stdout is not None:
219                     stdout.write(data)
220                 continue
221
222             if session.recv_stderr_ready():
223                 stderr_data = session.recv_stderr(4096)
224                 self.log.debug("stderr: %r", stderr_data)
225                 if stderr is not None:
226                     stderr.write(stderr_data)
227                 continue
228
229             if session.send_ready():
230                 if stdin is not None and not stdin.closed:
231                     if not data_to_send:
232                         data_to_send = stdin.read(4096)
233                         if not data_to_send:
234                             # we may need to keep stdin open
235                             if not keep_stdin_open:
236                                 stdin.close()
237                                 session.shutdown_write()
238                                 writes = []
239                     if data_to_send:
240                         sent_bytes = session.send(data_to_send)
241                         # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
242                         data_to_send = data_to_send[sent_bytes:]
243
244             if session.exit_status_ready():
245                 break
246
247             if timeout and (time.time() - timeout) > start_time:
248                 args = {"cmd": cmd, "host": self.host}
249                 raise SSHTimeout("Timeout executing command "
250                                  "'%(cmd)s' on host %(host)s" % args)
251             if e:
252                 raise SSHError("Socket error.")
253
254         exit_status = session.recv_exit_status()
255         if 0 != exit_status and raise_on_error:
256             fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
257             details = fmt % {"cmd": cmd, "status": exit_status}
258             if stderr_data:
259                 details += " Last stderr data: '%s'." % stderr_data
260             raise SSHError(details)
261         return exit_status
262
263     def execute(self, cmd, stdin=None, timeout=3600):
264         """Execute the specified command on the server.
265
266         :param cmd:     Command to be executed.
267         :param stdin:   Open file to be sent on process stdin.
268         :param timeout: Timeout for execution of the command.
269
270         :returns: tuple (exit_status, stdout, stderr)
271         """
272         stdout = six.moves.StringIO()
273         stderr = six.moves.StringIO()
274
275         exit_status = self.run(cmd, stderr=stderr,
276                                stdout=stdout, stdin=stdin,
277                                timeout=timeout, raise_on_error=False)
278         stdout.seek(0)
279         stderr.seek(0)
280         return (exit_status, stdout.read(), stderr.read())
281
282     def wait(self, timeout=120, interval=1):
283         """Wait for the host will be available via ssh."""
284         start_time = time.time()
285         while True:
286             try:
287                 return self.execute("uname")
288             except (socket.error, SSHError) as e:
289                 self.log.debug("Ssh is still unavailable: %r", e)
290                 time.sleep(interval)
291             if time.time() > (start_time + timeout):
292                 raise SSHTimeout("Timeout waiting for '%s'", self.host)
293
294     def put(self, files, remote_path=b'.', recursive=False):
295         client = self._get_client()
296
297         with SCPClient(client.get_transport()) as scp:
298             scp.put(files, remote_path, recursive)
299
300     # keep shell running in the background, e.g. screen
301     def send_command(self, command):
302         client = self._get_client()
303         client.exec_command(command, get_pty=True)
304
305     def _put_file_sftp(self, localpath, remotepath, mode=None):
306         client = self._get_client()
307
308         with client.open_sftp() as sftp:
309             sftp.put(localpath, remotepath)
310             if mode is None:
311                 mode = 0o777 & os.stat(localpath).st_mode
312             sftp.chmod(remotepath, mode)
313
314     def _put_file_shell(self, localpath, remotepath, mode=None):
315         # quote to stop wordpslit
316         cmd = ['cat > "%s"' % remotepath]
317         if mode is not None:
318             # use -- so no options
319             cmd.append('chmod -- 0%o "%s"' % (mode, remotepath))
320
321         with open(localpath, "rb") as localfile:
322             # only chmod on successful cat
323             cmd = "&& ".join(cmd)
324             self.run(cmd, stdin=localfile)
325
326     def put_file(self, localpath, remotepath, mode=None):
327         """Copy specified local file to the server.
328
329         :param localpath:   Local filename.
330         :param remotepath:  Remote filename.
331         :param mode:        Permissions to set after upload
332         """
333         import socket
334         try:
335             self._put_file_sftp(localpath, remotepath, mode=mode)
336         except (paramiko.SSHException, socket.error):
337             self._put_file_shell(localpath, remotepath, mode=mode)