Merge "test_kubernetes: mock file operations in test_ssh_key"
[yardstick.git] / tests / unit / test_ssh.py
index 574da03..27ed68c 100644 (file)
 # yardstick comment: this file is a modified copy of
 # rally/tests/unit/common/test_sshutils.py
 
+from __future__ import absolute_import
 import os
+import socket
 import unittest
+from io import StringIO
+
 import mock
+from oslo_utils import encodeutils
 
 from yardstick import ssh
+from yardstick.ssh import SSHError
+from yardstick.ssh import SSH
+from yardstick.ssh import AutoConnectSSH
 
 
 class FakeParamikoException(Exception):
@@ -47,6 +55,62 @@ class SSHTestCase(unittest.TestCase):
         self.assertEqual("kf", test_ssh.key_filename)
         self.assertEqual("secret", test_ssh.password)
 
+    @mock.patch("yardstick.ssh.SSH._get_pkey")
+    def test_ssh_from_node(self, mock_ssh__get_pkey):
+        mock_ssh__get_pkey.return_value = "pkey"
+        node = {
+            "user": "root", "ip": "example.net", "ssh_port": 33,
+            "key_filename": "kf", "password": "secret"
+        }
+        test_ssh = ssh.SSH.from_node(node)
+        self.assertEqual("root", test_ssh.user)
+        self.assertEqual("example.net", test_ssh.host)
+        self.assertEqual(33, test_ssh.port)
+        self.assertEqual("kf", test_ssh.key_filename)
+        self.assertEqual("secret", test_ssh.password)
+
+    @mock.patch("yardstick.ssh.SSH._get_pkey")
+    def test_ssh_from_node_password_default(self, mock_ssh__get_pkey):
+        mock_ssh__get_pkey.return_value = "pkey"
+        node = {
+            "user": "root", "ip": "example.net", "ssh_port": 33,
+            "key_filename": "kf"
+        }
+        test_ssh = ssh.SSH.from_node(node)
+        self.assertEqual("root", test_ssh.user)
+        self.assertEqual("example.net", test_ssh.host)
+        self.assertEqual(33, test_ssh.port)
+        self.assertEqual("kf", test_ssh.key_filename)
+        self.assertIsNone(test_ssh.password)
+
+    @mock.patch("yardstick.ssh.SSH._get_pkey")
+    def test_ssh_from_node_ssh_port_default(self, mock_ssh__get_pkey):
+        mock_ssh__get_pkey.return_value = "pkey"
+        node = {
+            "user": "root", "ip": "example.net",
+            "key_filename": "kf", "password": "secret"
+        }
+        test_ssh = ssh.SSH.from_node(node)
+        self.assertEqual("root", test_ssh.user)
+        self.assertEqual("example.net", test_ssh.host)
+        self.assertEqual(ssh.SSH.SSH_PORT, test_ssh.port)
+        self.assertEqual("kf", test_ssh.key_filename)
+        self.assertEqual("secret", test_ssh.password)
+
+    @mock.patch("yardstick.ssh.SSH._get_pkey")
+    def test_ssh_from_node_key_filename_default(self, mock_ssh__get_pkey):
+        mock_ssh__get_pkey.return_value = "pkey"
+        node = {
+            "user": "root", "ip": "example.net", "ssh_port": 33,
+            "password": "secret"
+        }
+        test_ssh = ssh.SSH.from_node(node)
+        self.assertEqual("root", test_ssh.user)
+        self.assertEqual("example.net", test_ssh.host)
+        self.assertEqual(33, test_ssh.port)
+        self.assertIsNone(test_ssh.key_filename)
+        self.assertEqual("secret", test_ssh.password)
+
     def test_construct_default(self):
         self.assertEqual("root", self.test_client.user)
         self.assertEqual("example.net", self.test_client.host)
@@ -108,10 +172,51 @@ class SSHTestCase(unittest.TestCase):
             mock.call.set_missing_host_key_policy("autoadd"),
             mock.call.connect("example.net", username="admin",
                               port=22, pkey="key", key_filename=None,
-                              password=None, timeout=1),
+                              password=None,
+                              allow_agent=False, look_for_keys=False,
+                              timeout=1),
         ]
         self.assertEqual(client_calls, client.mock_calls)
 
+    @mock.patch("yardstick.ssh.SSH._get_pkey")
+    @mock.patch("yardstick.ssh.paramiko")
+    def test__get_client_with_exception(self, mock_paramiko, mock_ssh__get_pkey):
+        class MyError(Exception):
+            pass
+
+        mock_ssh__get_pkey.return_value = "pkey"
+        fake_client = mock.Mock()
+        fake_client.connect.side_effect = MyError
+        fake_client.set_missing_host_key_policy.return_value = None
+        mock_paramiko.SSHClient.return_value = fake_client
+        mock_paramiko.AutoAddPolicy.return_value = "autoadd"
+
+        test_ssh = ssh.SSH("admin", "example.net", pkey="key")
+
+        with self.assertRaises(SSHError) as raised:
+            test_ssh._get_client()
+
+        self.assertEqual(mock_paramiko.SSHClient.call_count, 1)
+        self.assertEqual(mock_paramiko.AutoAddPolicy.call_count, 1)
+        self.assertEqual(fake_client.set_missing_host_key_policy.call_count, 1)
+        self.assertEqual(fake_client.connect.call_count, 1)
+        exc_str = str(raised.exception)
+        self.assertIn('raised during connect', exc_str)
+        self.assertIn('MyError', exc_str)
+
+    @mock.patch("yardstick.ssh.SSH._get_pkey")
+    @mock.patch("yardstick.ssh.paramiko")
+    def test_copy(self, mock_paramiko, mock_ssh__get_pkey):
+        mock_ssh__get_pkey.return_value = "pkey"
+        fake_client = mock.Mock()
+        fake_client.connect.side_effect = IOError
+        mock_paramiko.SSHClient.return_value = fake_client
+        mock_paramiko.AutoAddPolicy.return_value = "autoadd"
+
+        test_ssh = ssh.SSH("admin", "example.net", pkey="key")
+        result = test_ssh.copy()
+        self.assertIsNot(test_ssh, result)
+
     def test_close(self):
         with mock.patch.object(self.test_client, "_client") as m_client:
             self.test_client.close()
@@ -160,10 +265,10 @@ class SSHTestCase(unittest.TestCase):
     def test_send_command(self, mock_paramiko):
         paramiko_sshclient = self.test_client._get_client()
         with mock.patch.object(paramiko_sshclient, "exec_command") \
-            as mock_paramiko_exec_command:
+                as mock_paramiko_exec_command:
             self.test_client.send_command('cmd')
         mock_paramiko_exec_command.assert_called_once_with('cmd',
-                                                            get_pty=True)
+                                                           get_pty=True)
 
 
 class SSHRunTestCase(unittest.TestCase):
@@ -269,7 +374,26 @@ class SSHRunTestCase(unittest.TestCase):
         fake_stdin.close = mock.Mock(side_effect=close)
         self.test_client.run("cmd", stdin=fake_stdin)
         call = mock.call
-        send_calls = [call("line1"), call("line2"), call("e2")]
+        send_calls = [call(encodeutils.safe_encode("line1", "utf-8")),
+                      call(encodeutils.safe_encode("line2", "utf-8")),
+                      call(encodeutils.safe_encode("e2", "utf-8"))]
+        self.assertEqual(send_calls, self.fake_session.send.mock_calls)
+
+    @mock.patch("yardstick.ssh.select")
+    def test_run_stdin_keep_open(self, mock_select):
+        """Test run method with stdin.
+
+        Third send call was called with "e2" because only 3 bytes was sent
+        by second call. So remainig 2 bytes of "line2" was sent by third call.
+        """
+        mock_select.select.return_value = ([], [], [])
+        self.fake_session.exit_status_ready.side_effect = [0, 0, 0, True]
+        self.fake_session.send_ready.return_value = True
+        self.fake_session.send.side_effect = len
+        fake_stdin = StringIO(u"line1\nline2\n")
+        self.test_client.run("cmd", stdin=fake_stdin, keep_stdin_open=True)
+        call = mock.call
+        send_calls = [call(encodeutils.safe_encode("line1\nline2\n", "utf-8"))]
         self.assertEqual(send_calls, self.fake_session.send.mock_calls)
 
     @mock.patch("yardstick.ssh.select")
@@ -286,9 +410,154 @@ class SSHRunTestCase(unittest.TestCase):
         self.fake_session.exit_status_ready.return_value = False
         self.assertRaises(ssh.SSHTimeout, self.test_client.run, "cmd")
 
+    @mock.patch("yardstick.ssh.open", create=True)
+    def test__put_file_shell(self, mock_open):
+        with mock.patch.object(self.test_client, "run") as run_mock:
+            self.test_client._put_file_shell("localfile", "remotefile", 0o42)
+            run_mock.assert_called_once_with(
+                'cat > "remotefile"&& chmod -- 042 "remotefile"',
+                stdin=mock_open.return_value.__enter__.return_value)
+
+    @mock.patch("yardstick.ssh.open", create=True)
+    def test__put_file_shell_space(self, mock_open):
+        with mock.patch.object(self.test_client, "run") as run_mock:
+            self.test_client._put_file_shell("localfile",
+                                             "filename with space", 0o42)
+            run_mock.assert_called_once_with(
+                'cat > "filename with space"&& chmod -- 042 "filename with '
+                'space"',
+                stdin=mock_open.return_value.__enter__.return_value)
+
+    @mock.patch("yardstick.ssh.open", create=True)
+    def test__put_file_shell_tilde(self, mock_open):
+        with mock.patch.object(self.test_client, "run") as run_mock:
+            self.test_client._put_file_shell("localfile", "~/remotefile", 0o42)
+            run_mock.assert_called_once_with(
+                'cat > ~/"remotefile"&& chmod -- 042 ~/"remotefile"',
+                stdin=mock_open.return_value.__enter__.return_value)
+
+    @mock.patch("yardstick.ssh.open", create=True)
+    def test__put_file_shell_tilde_spaces(self, mock_open):
+        with mock.patch.object(self.test_client, "run") as run_mock:
+            self.test_client._put_file_shell("localfile", "~/file with space",
+                                             0o42)
+            run_mock.assert_called_once_with(
+                'cat > ~/"file with space"&& chmod -- 042 ~/"file with space"',
+                stdin=mock_open.return_value.__enter__.return_value)
+
+    @mock.patch("yardstick.ssh.os.stat")
+    def test__put_file_sftp(self, mock_stat):
+        sftp = self.fake_client.open_sftp.return_value = mock.MagicMock()
+        sftp.__enter__.return_value = sftp
+
+        mock_stat.return_value = os.stat_result([0o753] + [0] * 9)
+
+        self.test_client._put_file_sftp("localfile", "remotefile")
+
+        sftp.put.assert_called_once_with("localfile", "remotefile")
+        mock_stat.assert_any_call("localfile")
+        sftp.chmod.assert_any_call("remotefile", 0o753)
+        sftp.__exit__.assert_called_once_with(None, None, None)
+
+    def test__put_file_sftp_mode(self):
+        sftp = self.fake_client.open_sftp.return_value = mock.MagicMock()
+        sftp.__enter__.return_value = sftp
+
+        self.test_client._put_file_sftp("localfile", "remotefile", mode=0o753)
+
+        sftp.put.assert_called_once_with("localfile", "remotefile")
+        sftp.chmod.assert_called_once_with("remotefile", 0o753)
+        sftp.__exit__.assert_called_once_with(None, None, None)
+
+    def test_put_file_SSHException(self):
+        exc = ssh.paramiko.SSHException
+        self.test_client._put_file_sftp = mock.Mock(side_effect=exc())
+        self.test_client._put_file_shell = mock.Mock()
+
+        self.test_client.put_file("foo", "bar", 42)
+        self.test_client._put_file_sftp.assert_called_once_with("foo", "bar",
+                                                                mode=42)
+        self.test_client._put_file_shell.assert_called_once_with("foo", "bar",
+                                                                 mode=42)
+
+    def test_put_file_socket_error(self):
+        exc = socket.error
+        self.test_client._put_file_sftp = mock.Mock(side_effect=exc())
+        self.test_client._put_file_shell = mock.Mock()
+
+        self.test_client.put_file("foo", "bar", 42)
+        self.test_client._put_file_sftp.assert_called_once_with("foo", "bar",
+                                                                mode=42)
+        self.test_client._put_file_shell.assert_called_once_with("foo", "bar",
+                                                                 mode=42)
+
+    @mock.patch("yardstick.ssh.os.stat")
+    def test_put_file_obj_with_mode(self, mock_stat):
+        sftp = self.fake_client.open_sftp.return_value = mock.MagicMock()
+        sftp.__enter__.return_value = sftp
+
+        mock_stat.return_value = os.stat_result([0o753] + [0] * 9)
+
+        self.test_client.put_file_obj("localfile", "remotefile", 'my_mode')
+
+        sftp.__enter__.assert_called_once()
+        sftp.putfo.assert_called_once_with("localfile", "remotefile")
+        sftp.chmod.assert_called_once_with("remotefile", 'my_mode')
+        sftp.__exit__.assert_called_once_with(None, None, None)
+
+
+class TestAutoConnectSSH(unittest.TestCase):
+
+    def test__connect_with_wait(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1', wait=True)
+        auto_connect_ssh._get_client = mock.Mock()
+        auto_connect_ssh.wait = mock_wait = mock.Mock()
+
+        auto_connect_ssh._connect()
+        self.assertEqual(mock_wait.call_count, 1)
+
+    def test__make_dict(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+
+        expected = {
+            'user': 'user1',
+            'host': 'host1',
+            'port': SSH.SSH_PORT,
+            'pkey': None,
+            'key_filename': None,
+            'password': None,
+            'name': None,
+            'wait': False,
+        }
+        result = auto_connect_ssh._make_dict()
+        self.assertDictEqual(result, expected)
+
+    def test_get_class(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+
+        self.assertEqual(auto_connect_ssh.get_class(), AutoConnectSSH)
+
+    @mock.patch('yardstick.ssh.SCPClient')
+    def test_put(self, mock_scp_client_type):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+        auto_connect_ssh._client = mock.Mock()
+
+        auto_connect_ssh.put('a', 'z')
+        with mock_scp_client_type() as mock_scp_client:
+            self.assertEqual(mock_scp_client.put.call_count, 1)
+
+    def test_put_file(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+        auto_connect_ssh._client = mock.Mock()
+        auto_connect_ssh._put_file_sftp = mock_put_sftp = mock.Mock()
+
+        auto_connect_ssh.put_file('a', 'b')
+        self.assertEqual(mock_put_sftp.call_count, 1)
+
 
 def main():
     unittest.main()
 
+
 if __name__ == '__main__':
     main()