Merge changes from topics 'YARDSTICK-1218', 'YARDSTICK-1216', 'YARDSTICK-1215', ...
[yardstick.git] / yardstick / tests / unit / test_ssh.py
index dbaae8c..080d278 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-# yardstick comment: this file is a modified copy of
-# rally/tests/unit/common/test_sshutils.py
-
-from __future__ import absolute_import
 import os
 import socket
 import unittest
 from io import StringIO
+from itertools import count
 
 import mock
 from oslo_utils import encodeutils
 
+from yardstick.common import exceptions
 from yardstick import ssh
-from yardstick.ssh import SSHError
 from yardstick.ssh import SSH
 from yardstick.ssh import AutoConnectSSH
 
@@ -126,7 +123,7 @@ class SSHTestCase(unittest.TestCase):
         dss = mock_paramiko.dsskey.DSSKey
         rsa.from_private_key.side_effect = mock_paramiko.SSHException
         dss.from_private_key.side_effect = mock_paramiko.SSHException
-        self.assertRaises(ssh.SSHError, self.test_client._get_pkey, "key")
+        self.assertRaises(exceptions.SSHError, self.test_client._get_pkey, "key")
 
     @mock.patch("yardstick.ssh.six.moves.StringIO")
     @mock.patch("yardstick.ssh.paramiko")
@@ -193,7 +190,7 @@ class SSHTestCase(unittest.TestCase):
 
         test_ssh = ssh.SSH("admin", "example.net", pkey="key")
 
-        with self.assertRaises(SSHError) as raised:
+        with self.assertRaises(exceptions.SSHError) as raised:
             test_ssh._get_client()
 
         self.assertEqual(mock_paramiko.SSHClient.call_count, 1)
@@ -244,18 +241,18 @@ class SSHTestCase(unittest.TestCase):
     @mock.patch("yardstick.ssh.time")
     def test_wait_timeout(self, mock_time):
         mock_time.time.side_effect = [1, 50, 150]
-        self.test_client.execute = mock.Mock(side_effect=[ssh.SSHError,
-                                                          ssh.SSHError,
+        self.test_client.execute = mock.Mock(side_effect=[exceptions.SSHError,
+                                                          exceptions.SSHError,
                                                           0])
-        self.assertRaises(ssh.SSHTimeout, self.test_client.wait)
+        self.assertRaises(exceptions.SSHTimeout, self.test_client.wait)
         self.assertEqual([mock.call("uname")] * 2,
                          self.test_client.execute.mock_calls)
 
     @mock.patch("yardstick.ssh.time")
     def test_wait(self, mock_time):
         mock_time.time.side_effect = [1, 50, 100]
-        self.test_client.execute = mock.Mock(side_effect=[ssh.SSHError,
-                                                          ssh.SSHError,
+        self.test_client.execute = mock.Mock(side_effect=[exceptions.SSHError,
+                                                          exceptions.SSHError,
                                                           0])
         self.test_client.wait()
         self.assertEqual([mock.call("uname")] * 3,
@@ -332,7 +329,7 @@ class SSHRunTestCase(unittest.TestCase):
     def test_run_nonzero_status(self, mock_select):
         mock_select.select.return_value = ([], [], [])
         self.fake_session.recv_exit_status.return_value = 1
-        self.assertRaises(ssh.SSHError, self.test_client.run, "cmd")
+        self.assertRaises(exceptions.SSHError, self.test_client.run, "cmd")
         self.assertEqual(1, self.test_client.run("cmd", raise_on_error=False))
 
     @mock.patch("yardstick.ssh.select")
@@ -400,7 +397,7 @@ class SSHRunTestCase(unittest.TestCase):
     def test_run_select_error(self, mock_select):
         self.fake_session.exit_status_ready.return_value = False
         mock_select.select.return_value = ([], [], [True])
-        self.assertRaises(ssh.SSHError, self.test_client.run, "cmd")
+        self.assertRaises(exceptions.SSHError, self.test_client.run, "cmd")
 
     @mock.patch("yardstick.ssh.time")
     @mock.patch("yardstick.ssh.select")
@@ -408,7 +405,7 @@ class SSHRunTestCase(unittest.TestCase):
         mock_time.time.side_effect = [1, 3700]
         mock_select.select.return_value = ([], [], [])
         self.fake_session.exit_status_ready.return_value = False
-        self.assertRaises(ssh.SSHTimeout, self.test_client.run, "cmd")
+        self.assertRaises(exceptions.SSHTimeout, self.test_client.run, "cmd")
 
     @mock.patch("yardstick.ssh.open", create=True)
     def test__put_file_shell(self, mock_open):
@@ -508,13 +505,45 @@ class SSHRunTestCase(unittest.TestCase):
 
 class TestAutoConnectSSH(unittest.TestCase):
 
-    def test__connect_with_wait(self):
-        auto_connect_ssh = AutoConnectSSH('user1', 'host1', wait=True)
-        auto_connect_ssh._get_client = mock.Mock()
-        auto_connect_ssh.wait = mock_wait = mock.Mock()
+    def test__connect_loop(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1', wait=0)
+        auto_connect_ssh._get_client = mock__get_client = mock.Mock()
 
         auto_connect_ssh._connect()
-        self.assertEqual(mock_wait.call_count, 1)
+        self.assertEqual(mock__get_client.call_count, 1)
+
+    def test___init___negative(self):
+        with self.assertRaises(TypeError):
+            AutoConnectSSH('user1', 'host1', wait=['wait'])
+
+        with self.assertRaises(ValueError):
+            AutoConnectSSH('user1', 'host1', wait='wait')
+
+    @mock.patch('yardstick.ssh.time')
+    def test__connect_loop_ssh_error(self, mock_time):
+        mock_time.time.side_effect = count()
+
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1', wait=10)
+        auto_connect_ssh._get_client = mock__get_client = mock.Mock()
+        mock__get_client.side_effect = exceptions.SSHError
+
+        with self.assertRaises(exceptions.SSHTimeout):
+            auto_connect_ssh._connect()
+
+        self.assertEqual(mock_time.time.call_count, 12)
+
+    def test_get_file_obj(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1', wait=10)
+        auto_connect_ssh._get_client = mock__get_client = mock.Mock()
+        mock_client = mock__get_client()
+        mock_open_sftp = mock_client.open_sftp()
+        mock_sftp = mock.Mock()
+        mock_open_sftp.__enter__ = mock.Mock(return_value=mock_sftp)
+        mock_open_sftp.__exit__ = mock.Mock()
+
+        auto_connect_ssh.get_file_obj('remote/path', mock.Mock())
+
+        self.assertEqual(mock_sftp.getfo.call_count, 1)
 
     def test__make_dict(self):
         auto_connect_ssh = AutoConnectSSH('user1', 'host1')
@@ -527,7 +556,7 @@ class TestAutoConnectSSH(unittest.TestCase):
             'key_filename': None,
             'password': None,
             'name': None,
-            'wait': True,
+            'wait': AutoConnectSSH.DEFAULT_WAIT_TIMEOUT,
         }
         result = auto_connect_ssh._make_dict()
         self.assertDictEqual(result, expected)
@@ -537,6 +566,13 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         self.assertEqual(auto_connect_ssh.get_class(), AutoConnectSSH)
 
+    def test_drop_connection(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+        self.assertFalse(auto_connect_ssh._client)
+        auto_connect_ssh._client = True
+        auto_connect_ssh.drop_connection()
+        self.assertFalse(auto_connect_ssh._client)
+
     @mock.patch('yardstick.ssh.SCPClient')
     def test_put(self, mock_scp_client_type):
         auto_connect_ssh = AutoConnectSSH('user1', 'host1')
@@ -562,11 +598,3 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.put_file('a', 'b')
         self.assertEqual(mock_put_sftp.call_count, 1)
-
-
-def main():
-    unittest.main()
-
-
-if __name__ == '__main__':
-    main()