import new _put_file_shell method from upstream rally
[yardstick.git] / tests / unit / test_ssh.py
1 # Copyright 2013: Mirantis Inc.
2 # All Rights Reserved.
3 #
4 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
5 #    not use this file except in compliance with the License. You may obtain
6 #    a copy of the License at
7 #
8 #         http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #    Unless required by applicable law or agreed to in writing, software
11 #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 #    License for the specific language governing permissions and limitations
14 #    under the License.
15
16 # yardstick comment: this file is a modified copy of
17 # rally/tests/unit/common/test_sshutils.py
18
19 import os
20 import socket
21 import unittest
22 import mock
23
24 from yardstick import ssh
25
26
class FakeParamikoException(Exception):
    """Stand-in for ``paramiko.SSHException`` when paramiko is mocked out.

    The key-parsing tests assign this class to ``mock_paramiko.SSHException``
    so that ``except paramiko.SSHException`` clauses still catch a real
    exception type.
    """
30
class SSHTestCase(unittest.TestCase):
    """Test all small SSH methods."""

    def setUp(self):
        super(SSHTestCase, self).setUp()
        # Client built with only user and host; every other option defaulted.
        self.test_client = ssh.SSH("root", "example.net")

    @mock.patch("yardstick.ssh.SSH._get_pkey")
    def test_construct(self, mock_ssh__get_pkey):
        """Explicit constructor arguments are stored on the instance.

        The ``pkey`` argument is passed through ``_get_pkey`` and its
        return value is what ends up in ``test_ssh.pkey``.
        """
        mock_ssh__get_pkey.return_value = "pkey"
        test_ssh = ssh.SSH("root", "example.net", port=33, pkey="key",
                           key_filename="kf", password="secret")
        mock_ssh__get_pkey.assert_called_once_with("key")
        self.assertEqual("root", test_ssh.user)
        self.assertEqual("example.net", test_ssh.host)
        self.assertEqual(33, test_ssh.port)
        self.assertEqual("pkey", test_ssh.pkey)
        self.assertEqual("kf", test_ssh.key_filename)
        self.assertEqual("secret", test_ssh.password)

    def test_construct_default(self):
        """Omitted constructor arguments fall back to port 22 and None."""
        self.assertEqual("root", self.test_client.user)
        self.assertEqual("example.net", self.test_client.host)
        self.assertEqual(22, self.test_client.port)
        self.assertIsNone(self.test_client.pkey)
        self.assertIsNone(self.test_client.key_filename)
        self.assertIsNone(self.test_client.password)

    @mock.patch("yardstick.ssh.paramiko")
    def test__get_pkey_invalid(self, mock_paramiko):
        """SSHError is raised when neither RSA nor DSS can parse the key."""
        mock_paramiko.SSHException = FakeParamikoException
        rsa = mock_paramiko.rsakey.RSAKey
        dss = mock_paramiko.dsskey.DSSKey
        rsa.from_private_key.side_effect = mock_paramiko.SSHException
        dss.from_private_key.side_effect = mock_paramiko.SSHException
        self.assertRaises(ssh.SSHError, self.test_client._get_pkey, "key")

    @mock.patch("yardstick.ssh.six.moves.StringIO")
    @mock.patch("yardstick.ssh.paramiko")
    def test__get_pkey_dss(self, mock_paramiko, mock_string_io):
        """A DSS key is returned when RSA parsing fails."""
        mock_paramiko.SSHException = FakeParamikoException
        mock_string_io.return_value = "string_key"
        mock_paramiko.dsskey.DSSKey.from_private_key.return_value = "dss_key"
        rsa = mock_paramiko.rsakey.RSAKey
        rsa.from_private_key.side_effect = mock_paramiko.SSHException
        key = self.test_client._get_pkey("key")
        dss_calls = mock_paramiko.dsskey.DSSKey.from_private_key.mock_calls
        self.assertEqual([mock.call("string_key")], dss_calls)
        self.assertEqual(key, "dss_key")
        mock_string_io.assert_called_once_with("key")

    @mock.patch("yardstick.ssh.six.moves.StringIO")
    @mock.patch("yardstick.ssh.paramiko")
    def test__get_pkey_rsa(self, mock_paramiko, mock_string_io):
        """An RSA key is returned when RSA parsing succeeds."""
        mock_paramiko.SSHException = FakeParamikoException
        mock_string_io.return_value = "string_key"
        mock_paramiko.rsakey.RSAKey.from_private_key.return_value = "rsa_key"
        dss = mock_paramiko.dsskey.DSSKey
        dss.from_private_key.side_effect = mock_paramiko.SSHException
        key = self.test_client._get_pkey("key")
        rsa_calls = mock_paramiko.rsakey.RSAKey.from_private_key.mock_calls
        self.assertEqual([mock.call("string_key")], rsa_calls)
        self.assertEqual(key, "rsa_key")
        mock_string_io.assert_called_once_with("key")

    @mock.patch("yardstick.ssh.SSH._get_pkey")
    @mock.patch("yardstick.ssh.paramiko")
    def test__get_client(self, mock_paramiko, mock_ssh__get_pkey):
        """_get_client sets the host key policy and connects with our args."""
        mock_ssh__get_pkey.return_value = "key"
        fake_client = mock.Mock()
        mock_paramiko.SSHClient.return_value = fake_client
        mock_paramiko.AutoAddPolicy.return_value = "autoadd"

        test_ssh = ssh.SSH("admin", "example.net", pkey="key")
        client = test_ssh._get_client()

        self.assertEqual(fake_client, client)
        client_calls = [
            mock.call.set_missing_host_key_policy("autoadd"),
            mock.call.connect("example.net", username="admin",
                              port=22, pkey="key", key_filename=None,
                              password=None,
                              allow_agent=False, look_for_keys=False,
                              timeout=1),
        ]
        self.assertEqual(client_calls, client.mock_calls)

    def test_close(self):
        """close() closes the underlying client and clears the attribute."""
        with mock.patch.object(self.test_client, "_client") as m_client:
            self.test_client.close()
            # Assert while the patch is still active: when the ``with``
            # block exits, mock.patch restores the original attribute,
            # which would make an assertion placed afterwards vacuous.
            self.assertFalse(self.test_client._client)
        m_client.close.assert_called_once_with()

    @mock.patch("yardstick.ssh.six.moves.StringIO")
    def test_execute(self, mock_string_io):
        """execute() delegates to run() and returns (status, out, err)."""
        # The first StringIO created collects stdout, the second stderr.
        mock_string_io.side_effect = stdio = [mock.Mock(), mock.Mock()]
        stdio[0].read.return_value = "stdout fake data"
        stdio[1].read.return_value = "stderr fake data"
        with mock.patch.object(self.test_client, "run", return_value=0)\
                as mock_run:
            status, stdout, stderr = self.test_client.execute(
                "cmd",
                stdin="fake_stdin",
                timeout=43)
        mock_run.assert_called_once_with(
            "cmd", stdin="fake_stdin", stdout=stdio[0],
            stderr=stdio[1], timeout=43, raise_on_error=False)
        self.assertEqual(0, status)
        self.assertEqual("stdout fake data", stdout)
        self.assertEqual("stderr fake data", stderr)

    @mock.patch("yardstick.ssh.time")
    def test_wait_timeout(self, mock_time):
        """wait() raises SSHTimeout once the mocked clock runs past its limit.

        With clock readings 1, 50, 150, only two "uname" probes are made
        before the timeout is reported.
        """
        mock_time.time.side_effect = [1, 50, 150]
        self.test_client.execute = mock.Mock(side_effect=[ssh.SSHError,
                                                          ssh.SSHError,
                                                          0])
        self.assertRaises(ssh.SSHTimeout, self.test_client.wait)
        self.assertEqual([mock.call("uname")] * 2,
                         self.test_client.execute.mock_calls)

    @mock.patch("yardstick.ssh.time")
    def test_wait(self, mock_time):
        """wait() keeps probing with "uname" until a probe succeeds."""
        mock_time.time.side_effect = [1, 50, 100]
        self.test_client.execute = mock.Mock(side_effect=[ssh.SSHError,
                                                          ssh.SSHError,
                                                          0])
        self.test_client.wait()
        self.assertEqual([mock.call("uname")] * 3,
                         self.test_client.execute.mock_calls)

    @mock.patch("yardstick.ssh.paramiko")
    def test_send_command(self, mock_paramiko):
        """send_command() runs the command on the client with a PTY."""
        paramiko_sshclient = self.test_client._get_client()
        with mock.patch.object(paramiko_sshclient, "exec_command") \
                as mock_paramiko_exec_command:
            self.test_client.send_command('cmd')
        mock_paramiko_exec_command.assert_called_once_with('cmd',
                                                           get_pty=True)
170
171
class SSHRunTestCase(unittest.TestCase):
    """Test SSH.run method in different aspects.

    Also tests the "execute" method, which is built on top of "run".
    """

    def setUp(self):
        super(SSHRunTestCase, self).setUp()

        # Fake paramiko objects: client -> transport -> session.
        self.fake_client = mock.Mock()
        self.fake_session = mock.Mock()
        self.fake_transport = mock.Mock()

        self.fake_transport.open_session.return_value = self.fake_session
        self.fake_client.get_transport.return_value = self.fake_transport

        # By default the session has no I/O pending and exits at once with
        # status 0; individual tests override these to script a run.
        self.fake_session.recv_ready.return_value = False
        self.fake_session.recv_stderr_ready.return_value = False
        self.fake_session.send_ready.return_value = False
        self.fake_session.exit_status_ready.return_value = True
        self.fake_session.recv_exit_status.return_value = 0

        self.test_client = ssh.SSH("admin", "example.net")
        self.test_client._get_client = mock.Mock(return_value=self.fake_client)

    def _make_fake_sftp(self):
        """Return a MagicMock SFTP client usable as a context manager.

        ``open_sftp()`` on the fake client returns it, and entering it with
        ``with`` yields the same object, so every call made inside the
        ``with`` block is recorded on the returned mock.
        """
        sftp = self.fake_client.open_sftp.return_value = mock.MagicMock()
        sftp.__enter__.return_value = sftp
        return sftp

    def _check_put_file_falls_back_to_shell(self, exc):
        """Assert put_file() retries via the shell when sftp raises *exc*.

        :param exc: exception instance raised by the mocked _put_file_sftp.
        """
        self.test_client._put_file_sftp = mock.Mock(side_effect=exc)
        self.test_client._put_file_shell = mock.Mock()

        self.test_client.put_file("foo", "bar", 42)
        self.test_client._put_file_sftp.assert_called_once_with("foo", "bar",
                                                                mode=42)
        self.test_client._put_file_shell.assert_called_once_with("foo", "bar",
                                                                 mode=42)

    @mock.patch("yardstick.ssh.select")
    def test_execute(self, mock_select):
        """execute() collects stdout, stderr and the exit status."""
        mock_select.select.return_value = ([], [], [])
        self.fake_session.recv_ready.side_effect = [1, 0, 0]
        self.fake_session.recv_stderr_ready.side_effect = [1, 0]
        self.fake_session.recv.return_value = "ok"
        self.fake_session.recv_stderr.return_value = "error"
        self.fake_session.exit_status_ready.return_value = 1
        self.fake_session.recv_exit_status.return_value = 127
        self.assertEqual((127, "ok", "error"), self.test_client.execute("cmd"))
        self.fake_session.exec_command.assert_called_once_with("cmd")

    @mock.patch("yardstick.ssh.select")
    def test_execute_args(self, mock_select):
        """A command line with quoted arguments is passed through as-is."""
        mock_select.select.return_value = ([], [], [])
        self.fake_session.recv_ready.side_effect = [1, 0, 0]
        self.fake_session.recv_stderr_ready.side_effect = [1, 0]
        self.fake_session.recv.return_value = "ok"
        self.fake_session.recv_stderr.return_value = "error"
        self.fake_session.exit_status_ready.return_value = 1
        self.fake_session.recv_exit_status.return_value = 127

        result = self.test_client.execute("cmd arg1 'arg2 with space'")
        self.assertEqual((127, "ok", "error"), result)
        self.fake_session.exec_command.assert_called_once_with(
            "cmd arg1 'arg2 with space'")

    @mock.patch("yardstick.ssh.select")
    def test_run(self, mock_select):
        """run() returns the exit status of a successful command."""
        mock_select.select.return_value = ([], [], [])
        self.assertEqual(0, self.test_client.run("cmd"))

    @mock.patch("yardstick.ssh.select")
    def test_run_nonzero_status(self, mock_select):
        """Non-zero status raises SSHError unless raise_on_error=False."""
        mock_select.select.return_value = ([], [], [])
        self.fake_session.recv_exit_status.return_value = 1
        self.assertRaises(ssh.SSHError, self.test_client.run, "cmd")
        self.assertEqual(1, self.test_client.run("cmd", raise_on_error=False))

    @mock.patch("yardstick.ssh.select")
    def test_run_stdout(self, mock_select):
        """Each chunk received on stdout is written to the stdout object."""
        mock_select.select.return_value = ([], [], [])
        self.fake_session.recv_ready.side_effect = [True, True, False]
        self.fake_session.recv.side_effect = ["ok1", "ok2"]
        stdout = mock.Mock()
        self.test_client.run("cmd", stdout=stdout)
        self.assertEqual([mock.call("ok1"), mock.call("ok2")],
                         stdout.write.mock_calls)

    @mock.patch("yardstick.ssh.select")
    def test_run_stderr(self, mock_select):
        """Data received on stderr is written to the stderr object."""
        mock_select.select.return_value = ([], [], [])
        self.fake_session.recv_stderr_ready.side_effect = [True, False]
        self.fake_session.recv_stderr.return_value = "error"
        stderr = mock.Mock()
        self.test_client.run("cmd", stderr=stderr)
        stderr.write.assert_called_once_with("error")

    @mock.patch("yardstick.ssh.select")
    def test_run_stdin(self, mock_select):
        """Test run method with stdin.

        The third send call receives "e2": the second call reported only 3
        of the 5 bytes of "line2" as sent, so the remaining 2 bytes are
        sent by the third call.
        """
        mock_select.select.return_value = ([], [], [])
        self.fake_session.exit_status_ready.side_effect = [0, 0, 0, True]
        self.fake_session.send_ready.return_value = True
        # Number of bytes each send() call reports as actually sent.
        self.fake_session.send.side_effect = [5, 3, 2]
        fake_stdin = mock.Mock()
        fake_stdin.read.side_effect = ["line1", "line2", ""]
        fake_stdin.closed = False

        def close():
            fake_stdin.closed = True
        fake_stdin.close = mock.Mock(side_effect=close)
        self.test_client.run("cmd", stdin=fake_stdin)
        call = mock.call
        send_calls = [call("line1"), call("line2"), call("e2")]
        self.assertEqual(send_calls, self.fake_session.send.mock_calls)

    @mock.patch("yardstick.ssh.select")
    def test_run_select_error(self, mock_select):
        """An exceptional condition reported by select raises SSHError."""
        self.fake_session.exit_status_ready.return_value = False
        mock_select.select.return_value = ([], [], [True])
        self.assertRaises(ssh.SSHError, self.test_client.run, "cmd")

    @mock.patch("yardstick.ssh.time")
    @mock.patch("yardstick.ssh.select")
    def test_run_timemout(self, mock_select, mock_time):
        """run() raises SSHTimeout when the mocked clock jumps far ahead."""
        mock_time.time.side_effect = [1, 3700]
        mock_select.select.return_value = ([], [], [])
        self.fake_session.exit_status_ready.return_value = False
        self.assertRaises(ssh.SSHTimeout, self.test_client.run, "cmd")

    @mock.patch("yardstick.ssh.open", create=True)
    def test__put_file_shell(self, mock_open):
        """The file is streamed through "cat"; chmod runs only if cat succeeds."""
        self.test_client.run = mock.Mock()
        self.test_client._put_file_shell("localfile", "remotefile", 0o42)

        self.test_client.run.assert_called_once_with(
            'cat > "remotefile"&& chmod -- 042 "remotefile"',
            stdin=mock_open.return_value.__enter__.return_value)

    @mock.patch("yardstick.ssh.os.stat")
    def test__put_file_sftp(self, mock_stat):
        """Without an explicit mode, the local file's mode is copied."""
        sftp = self._make_fake_sftp()

        mock_stat.return_value = os.stat_result([0o753] + [0] * 9)

        self.test_client._put_file_sftp("localfile", "remotefile")

        sftp.put.assert_called_once_with("localfile", "remotefile")
        mock_stat.assert_called_once_with("localfile")
        sftp.chmod.assert_called_once_with("remotefile", 0o753)
        sftp.__exit__.assert_called_once_with(None, None, None)

    def test__put_file_sftp_mode(self):
        """An explicit mode is applied to the remote file as given."""
        sftp = self._make_fake_sftp()

        self.test_client._put_file_sftp("localfile", "remotefile", mode=0o753)

        sftp.put.assert_called_once_with("localfile", "remotefile")
        sftp.chmod.assert_called_once_with("remotefile", 0o753)
        sftp.__exit__.assert_called_once_with(None, None, None)

    def test_put_file_SSHException(self):
        """put_file() falls back to the shell when sftp raises SSHException."""
        self._check_put_file_falls_back_to_shell(ssh.paramiko.SSHException())

    def test_put_file_socket_error(self):
        """put_file() falls back to the shell when sftp raises socket.error."""
        self._check_put_file_falls_back_to_shell(socket.error())
346
347
def main():
    """Run every test case in this module via the unittest runner."""
    unittest.main()


if __name__ == '__main__':
    main()