Add RFC2544 iteration status field
[yardstick.git] / yardstick / tests / unit / test_ssh.py
index 080d278..71929f1 100644 (file)
@@ -193,10 +193,10 @@ class SSHTestCase(unittest.TestCase):
         with self.assertRaises(exceptions.SSHError) as raised:
             test_ssh._get_client()
 
-        self.assertEqual(mock_paramiko.SSHClient.call_count, 1)
-        self.assertEqual(mock_paramiko.AutoAddPolicy.call_count, 1)
-        self.assertEqual(fake_client.set_missing_host_key_policy.call_count, 1)
-        self.assertEqual(fake_client.connect.call_count, 1)
+        mock_paramiko.SSHClient.assert_called_once()
+        mock_paramiko.AutoAddPolicy.assert_called_once()
+        fake_client.set_missing_host_key_policy.assert_called_once()
+        fake_client.connect.assert_called_once()
         exc_str = str(raised.exception)
         self.assertIn('raised during connect', exc_str)
         self.assertIn('MyError', exc_str)
@@ -238,6 +238,25 @@ class SSHTestCase(unittest.TestCase):
         self.assertEqual("stdout fake data", stdout)
         self.assertEqual("stderr fake data", stderr)
 
+    @mock.patch("yardstick.ssh.six.moves.StringIO")
+    def test_execute_raise_on_error_passed(self, mock_string_io):
+        mock_string_io.side_effect = stdio = [mock.Mock(), mock.Mock()]
+        stdio[0].read.return_value = "stdout fake data"
+        stdio[1].read.return_value = "stderr fake data"
+        with mock.patch.object(self.test_client, "run", return_value=0) \
+                as mock_run:
+            status, stdout, stderr = self.test_client.execute(
+                "cmd",
+                stdin="fake_stdin",
+                timeout=43,
+                raise_on_error=True)
+        mock_run.assert_called_once_with(
+            "cmd", stdin="fake_stdin", stdout=stdio[0],
+            stderr=stdio[1], timeout=43, raise_on_error=True)
+        self.assertEqual(0, status)
+        self.assertEqual("stdout fake data", stdout)
+        self.assertEqual("stderr fake data", stderr)
+
     @mock.patch("yardstick.ssh.time")
     def test_wait_timeout(self, mock_time):
         mock_time.time.side_effect = [1, 50, 150]
@@ -510,7 +529,7 @@ class TestAutoConnectSSH(unittest.TestCase):
         auto_connect_ssh._get_client = mock__get_client = mock.Mock()
 
         auto_connect_ssh._connect()
-        self.assertEqual(mock__get_client.call_count, 1)
+        mock__get_client.assert_called_once()
 
     def test___init___negative(self):
         with self.assertRaises(TypeError):
@@ -543,7 +562,7 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.get_file_obj('remote/path', mock.Mock())
 
-        self.assertEqual(mock_sftp.getfo.call_count, 1)
+        mock_sftp.getfo.assert_called_once()
 
     def test__make_dict(self):
         auto_connect_ssh = AutoConnectSSH('user1', 'host1')
@@ -580,7 +599,7 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.put('a', 'z')
         with mock_scp_client_type() as mock_scp_client:
-            self.assertEqual(mock_scp_client.put.call_count, 1)
+            mock_scp_client.put.assert_called_once()
 
     @mock.patch('yardstick.ssh.SCPClient')
     def test_get(self, mock_scp_client_type):
@@ -589,7 +608,7 @@ class TestAutoConnectSSH(unittest.TestCase):
 
         auto_connect_ssh.get('a', 'z')
         with mock_scp_client_type() as mock_scp_client:
-            self.assertEqual(mock_scp_client.get.call_count, 1)
+            mock_scp_client.get.assert_called_once()
 
     def test_put_file(self):
         auto_connect_ssh = AutoConnectSSH('user1', 'host1')
@@ -597,4 +616,27 @@ class TestAutoConnectSSH(unittest.TestCase):
         auto_connect_ssh._put_file_sftp = mock_put_sftp = mock.Mock()
 
         auto_connect_ssh.put_file('a', 'b')
-        self.assertEqual(mock_put_sftp.call_count, 1)
+        mock_put_sftp.assert_called_once()
+
+    def test_execute(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+        auto_connect_ssh._client = mock.Mock()
+        auto_connect_ssh.run = mock.Mock(return_value=0)
+        exit_code, _, _ = auto_connect_ssh.execute('')
+        self.assertEqual(exit_code, 0)
+
+    def _mock_run(self, *args, **kwargs):
+        if args[0] == 'ls':
+            if kwargs.get('raise_on_error'):
+                raise exceptions.SSHError(error_msg='Command error')
+            return 1
+        return 0
+
+    def test_execute_command_error(self):
+        auto_connect_ssh = AutoConnectSSH('user1', 'host1')
+        auto_connect_ssh._client = mock.Mock()
+        auto_connect_ssh.run = mock.Mock(side_effect=self._mock_run)
+        self.assertRaises(exceptions.SSHError, auto_connect_ssh.execute, 'ls',
+                          raise_on_error=True)
+        exit_code, _, _ = auto_connect_ssh.execute('ls')
+        self.assertNotEqual(exit_code, 0)