Make "IterationIPC" MQ producer for VNF control messages
[yardstick.git] / yardstick / tests / unit / test_ssh.py
index 615783f..b727e82 100644 (file)
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-# yardstick comment: this file is a modified copy of
-# rally/tests/unit/common/test_sshutils.py
-
-from __future__ import absolute_import
 import os
 import socket
 import unittest
@@ -26,8 +22,8 @@ from itertools import count
 import mock
 from oslo_utils import encodeutils
 
+from yardstick.common import exceptions
 from yardstick import ssh
-from yardstick.ssh import SSHError, SSHTimeout
 from yardstick.ssh import SSH
 from yardstick.ssh import AutoConnectSSH
 
@@ -127,7 +123,7 @@ class SSHTestCase(unittest.TestCase):
         dss = mock_paramiko.dsskey.DSSKey
         rsa.from_private_key.side_effect = mock_paramiko.SSHException
         dss.from_private_key.side_effect = mock_paramiko.SSHException
-        self.assertRaises(ssh.SSHError, self.test_client._get_pkey, "key")
+        self.assertRaises(exceptions.SSHError, self.test_client._get_pkey, "key")
 
     @mock.patch("yardstick.ssh.six.moves.StringIO")
     @mock.patch("yardstick.ssh.paramiko")
@@ -194,13 +190,13 @@ class SSHTestCase(unittest.TestCase):
 
         test_ssh = ssh.SSH("admin", "example.net", pkey="key")
 
-        with self.assertRaises(SSHError) as raised:
+        with self.assertRaises(exceptions.SSHError) as raised:
             test_ssh._get_client()
 
-        self.assertEqual(mock_paramiko.SSHClient.call_count, 1)
-        self.assertEqual(mock_paramiko.AutoAddPolicy.call_count, 1)
-        self.assertEqual(fake_client.set_missing_host_key_policy.call_count, 1)
-        self.assertEqual(fake_client.connect.call_count, 1)
+        mock_paramiko.SSHClient.assert_called_once()
+        mock_paramiko.AutoAddPolicy.assert_called_once()
+        fake_client.set_missing_host_key_policy.assert_called_once()
+        fake_client.connect.assert_called_once()
         exc_str = str(raised.exception)
         self.assertIn('raised during connect', exc_str)
         self.assertIn('MyError', exc_str)
@@ -242,21 +238,40 @@ class SSHTestCase(unittest.TestCase):
         self.assertEqual("stdout fake data", stdout)
         self.assertEqual("stderr fake data", stderr)
 
+    @mock.patch("yardstick.ssh.six.moves.StringIO")
+    def test_execute_raise_on_error_passed(self, mock_string_io):
+        mock_string_io.side_effect = stdio = [mock.Mock(), mock.Mock()]
+        stdio[0].read.return_value = "stdout fake data"
+        stdio[1].read.return_value = "stderr fake data"
+        with mock.patch.object(self.test_client, "run", return_value=0) \
+                as mock_run:
+            status, stdout, stderr = self.test_client.execute(
+                "cmd",
+                stdin="fake_stdin",
+                timeout=43,
+                raise_on_error=True)
+        mock_run.assert_called_once_with(
+            "cmd", stdin="fake_stdin", stdout=stdio[0],
+            stderr=stdio[1], timeout=43, raise_on_error=True)
+        self.assertEqual(0, status)
+        self.assertEqual("stdout fake data", stdout)
+        self.assertEqual("stderr fake data", stderr)
+
     @mock.patch("yardstick.ssh.time")
     def test_wait_timeout(self, mock_time):
         mock_time.time.side_effect = [1, 50, 150]
-        self.test_client.execute = mock.Mock(side_effect=[ssh.SSHError,
-                                                          ssh.SSHError,
+        self.test_client.execute = mock.Mock(side_effect=[exceptions.SSHError,
+                                                          exceptions.SSHError,
                                                           0])
-        self.assertRaises(ssh.SSHTimeout, self.test_client.wait)
+        self.assertRaises(exceptions.SSHTimeout, self.test_client.wait)
         self.assertEqual([mock.call("uname")] * 2,
                          self.test_client.execute.mock_calls)
 
     @mock.patch("yardstick.ssh.time")
     def test_wait(self, mock_time):
         mock_time.time.side_effect = [1, 50, 100]
-        self.test_client.execute = mock.Mock(side_effect=[ssh.SSHError,
-                                                          ssh.SSHError,
+        self.test_client.execute = mock.Mock(side_effect=[exceptions.SSHError,
+                                                          exceptions.SSHError,
                                                           0])
         self.test_client.wait()
         self.assertEqual([mock.call("uname")] * 3,
@@ -333,7 +348,7 @@ class SSHRunTestCase(unittest.TestCase):
     def test_run_nonzero_status(self, mock_select):
         mock_select.select.return_value = ([], [], [])
         self.fake_session.recv_exit_status.return_value = 1
-        self.assertRaises(ssh.SSHError, self.test_client.run, "cmd")
+        self.assertRaises(exceptions.SSHError, self.test_client.run, "cmd")
         self.assertEqual(1, self.test_client.run("cmd", raise_on_error=False))
 
     @mock.patch("yardstick.ssh.select")
@@ -401,7 +416,7 @@ class SSHRunTestCase(unittest.TestCase):
     def test_run_select_error(self, mock_select):
         self.fake_session.exit_status_ready.return_value = False
         mock_select.select.return_value = ([], [], [True])
-        self.assertRaises(ssh.SSHError, self.test_client.run, "cmd")
+        self.assertRaises(exceptions.SSHError, self.test_client.run, "cmd")
 
     @mock.patch("yardstick.ssh.time")
     @mock.patch("yardstick.ssh.select")
@@ -409,7 +424,7 @@ class SSHRunTestCase(unittest.TestCase):
         mock_time.time.side_effect = [1, 3700]
         mock_select.select.return_value = ([], [], [])
         self.fake_session.exit_status_ready.return_value = False
-        self.assertRaises(ssh.SSHTimeout, self.test_client.run, "cmd")
+        self.assertRaises(exceptions.SSHTimeout, self.test_client.run, "cmd")
 
     @mock.patch("yardstick.ssh.open", create=True)
     def test__put_file_shell(self, mock_open):
@@ -514,7 +529,7 @@ class TestAutoConnectSSH(unittest.TestCase):
         auto_connect_ssh._get_client = mock__get_client = mock.Mock()
 
         auto_connect_ssh._connect()
-        self.assertEqual(mock__get_client.call_count, 1)
+        mock__get_client.assert_called_once()
 
     def test___init___negative(self):
         with self.assertRaises(TypeError):
@@ -529,9 +544,9 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh = AutoConnectSSH('user1', 'host1', wait=10)
         auto_connect_ssh._get_client = mock__get_client = mock.Mock()
-        mock__get_client.side_effect = SSHError
+        mock__get_client.side_effect = exceptions.SSHError
 
-        with self.assertRaises(SSHTimeout):
+        with self.assertRaises(exceptions.SSHTimeout):
             auto_connect_ssh._connect()
 
         self.assertEqual(mock_time.time.call_count, 12)
@@ -547,7 +562,7 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.get_file_obj('remote/path', mock.Mock())
 
-        self.assertEqual(mock_sftp.getfo.call_count, 1)
+        mock_sftp.getfo.assert_called_once()
 
     def test__make_dict(self):
         auto_connect_ssh = AutoConnectSSH('user1', 'host1')
@@ -584,7 +599,7 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.put('a', 'z')
         with mock_scp_client_type() as mock_scp_client:
-            self.assertEqual(mock_scp_client.put.call_count, 1)
+            mock_scp_client.put.assert_called_once()
 
     @mock.patch('yardstick.ssh.SCPClient')
     def test_get(self, mock_scp_client_type):
@@ -593,7 +608,7 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.get('a', 'z')
         with mock_scp_client_type() as mock_scp_client:
-            self.assertEqual(mock_scp_client.get.call_count, 1)
+            mock_scp_client.get.assert_called_once()
 
     def test_put_file(self):
         auto_connect_ssh = AutoConnectSSH('user1', 'host1')
@@ -601,12 +616,4 @@ class TestAutoConnectSSH(unittest.TestCase):
         auto_connect_ssh._put_file_sftp = mock_put_sftp = mock.Mock()
 
         auto_connect_ssh.put_file('a', 'b')
-        self.assertEqual(mock_put_sftp.call_count, 1)
-
-
-def main():
-    unittest.main()
-
-
-if __name__ == '__main__':
-    main()
+        mock_put_sftp.assert_called_once()