Improve NSB Standalone XML generation 53/50653/19
authorRodolfo Alonso Hernandez <rodolfo.alonso.hernandez@intel.com>
Mon, 15 Jan 2018 14:25:46 +0000 (14:25 +0000)
committerRodolfo Alonso Hernandez <rodolfo.alonso.hernandez@intel.com>
Wed, 21 Mar 2018 09:12:49 +0000 (09:12 +0000)
Delayed the generation of the XML file until the last step. The
following functions will return a XML string insted:
- Libvirt.build_vm_xml
- SriovContext._enable_interfaces
- OvsDpdkContext._enable_interfaces

The XML file will be written just before copying the file to the
compute hosting the VMs.

JIRA: YARDSTICK-939

Change-Id: Icc80f4741903bbe335db4ebccab395b72fa87e82
Signed-off-by: Rodolfo Alonso Hernandez <rodolfo.alonso.hernandez@intel.com>
yardstick/benchmark/contexts/standalone/model.py
yardstick/benchmark/contexts/standalone/ovs_dpdk.py
yardstick/benchmark/contexts/standalone/sriov.py
yardstick/tests/unit/benchmark/contexts/standalone/test_model.py
yardstick/tests/unit/benchmark/contexts/standalone/test_ovs_dpdk.py
yardstick/tests/unit/benchmark/contexts/standalone/test_sriov.py

index ac81ee7..7eab1c9 100644 (file)
@@ -29,7 +29,7 @@ from yardstick.common import exceptions
 from yardstick.common.yaml_loader import yaml_load
 from yardstick.network_services.utils import PciAddress
 from yardstick.network_services.helpers.cpu import CpuSysCores
-from yardstick.common.utils import write_file
+
 
 LOG = logging.getLogger(__name__)
 
@@ -132,7 +132,7 @@ class Libvirt(object):
         return vm_pci
 
     @classmethod
-    def add_ovs_interface(cls, vpath, port_num, vpci, vports_mac, xml):
+    def add_ovs_interface(cls, vpath, port_num, vpci, vports_mac, xml_str):
         """Add a DPDK OVS 'interface' XML node in 'devices' node
 
         <devices>
@@ -156,7 +156,7 @@ class Libvirt(object):
 
         vhost_path = ('{0}/var/run/openvswitch/dpdkvhostuser{1}'.
                       format(vpath, port_num))
-        root = ET.parse(xml)
+        root = ET.fromstring(xml_str)
         pci_address = PciAddress(vpci.strip())
         device = root.find('devices')
 
@@ -181,10 +181,10 @@ class Libvirt(object):
 
         cls._add_interface_address(interface, pci_address)
 
-        root.write(xml)
+        return ET.tostring(root)
 
     @classmethod
-    def add_sriov_interfaces(cls, vm_pci, vf_pci, vf_mac, xml):
+    def add_sriov_interfaces(cls, vm_pci, vf_pci, vf_mac, xml_str):
         """Add a SR-IOV 'interface' XML node in 'devices' node
 
         <devices>
@@ -207,7 +207,7 @@ class Libvirt(object):
             -sr_iov-how_sr_iov_libvirt_works
         """
 
-        root = ET.parse(xml)
+        root = ET.fromstring(xml_str)
         device = root.find('devices')
 
         interface = ET.SubElement(device, 'interface')
@@ -224,7 +224,7 @@ class Libvirt(object):
         pci_vm_address = PciAddress(vm_pci.strip())
         cls._add_interface_address(interface, pci_vm_address)
 
-        root.write(xml)
+        return ET.tostring(root)
 
     @staticmethod
     def create_snapshot_qemu(connection, index, vm_image):
@@ -237,7 +237,8 @@ class Libvirt(object):
         return image
 
     @classmethod
-    def build_vm_xml(cls, connection, flavor, cfg, vm_name, index):
+    def build_vm_xml(cls, connection, flavor, vm_name, index):
+        """Build the XML from the configuration parameters"""
         memory = flavor.get('ram', '4096')
         extra_spec = flavor.get('extra_specs', {})
         cpu = extra_spec.get('hw:cpu_cores', '2')
@@ -261,9 +262,7 @@ class Libvirt(object):
             socket=socket, threads=threads,
             vm_image=image, cpuset=cpuset, cputune=cputune)
 
-        write_file(cfg, vm_xml)
-
-        return [vcpu, mac]
+        return vm_xml, mac
 
     @staticmethod
     def update_interrupts_hugepages_perf(connection):
@@ -283,6 +282,13 @@ class Libvirt(object):
         cpuset = "%s,%s" % (cores, threads)
         return cpuset
 
+    @classmethod
+    def write_file(cls, file_name, xml_str):
+        """Dump a XML string to a file"""
+        root = ET.fromstring(xml_str)
+        et = ET.ElementTree(element=root)
+        et.write(file_name, encoding='utf-8', method='xml')
+
 
 class StandaloneContextHelper(object):
     """ This class handles all the common code for standalone
index ee0eb9e..30b685e 100644 (file)
@@ -359,7 +359,7 @@ class OvsDpdkContext(Context):
         self.networks = portlist
         LOG.info("Ports %s", self.networks)
 
-    def _enable_interfaces(self, index, vfs, cfg):
+    def _enable_interfaces(self, index, vfs, xml_str):
         vpath = self.ovs_properties.get("vpath", "/usr/local")
         vf = self.networks[vfs[0]]
         port_num = vf.get('port_num', 0)
@@ -368,8 +368,8 @@ class OvsDpdkContext(Context):
         slot = index + port_num + 10
         vf['vpci'] = \
             "{}:{}:{:02x}.{}".format(vpci.domain, vpci.bus, slot, vpci.function)
-        model.Libvirt.add_ovs_interface(
-            vpath, port_num, vf['vpci'], vf['mac'], str(cfg))
+        return model.Libvirt.add_ovs_interface(
+            vpath, port_num, vf['vpci'], vf['mac'], xml_str)
 
     def setup_ovs_dpdk_context(self):
         nodes = []
@@ -384,17 +384,16 @@ class OvsDpdkContext(Context):
             # 1. Check and delete VM if already exists
             model.Libvirt.check_if_vm_exists_and_delete(vm_name,
                                                         self.connection)
+            xml_str, mac = model.Libvirt.build_vm_xml(
+                self.connection, self.vm_flavor, vm_name, index)
 
-            _, mac = model.Libvirt.build_vm_xml(
-                self.connection, self.vm_flavor, cfg, vm_name, index)
             # 2: Cleanup already available VMs
-            for vkey, vfs in collections.OrderedDict(
-                    vnf["network_ports"]).items():
-                if vkey == "mgmt":
-                    continue
-                self._enable_interfaces(index, vfs, cfg)
+            for vfs in [vfs for vfs_name, vfs in vnf["network_ports"].items()
+                        if vfs_name != 'mgmt']:
+                xml_str = self._enable_interfaces(index, vfs, xml_str)
 
             # copy xml to target...
+            model.Libvirt.write_file(cfg, xml_str)
             self.connection.put(cfg, cfg)
 
             # NOTE: launch through libvirt
index d762055..5db419e 100644 (file)
@@ -16,15 +16,12 @@ from __future__ import absolute_import
 import os
 import logging
 import collections
-from collections import OrderedDict
 
 from yardstick import ssh
 from yardstick.network_services.utils import get_nsb_option
 from yardstick.network_services.utils import provision_tool
 from yardstick.benchmark.contexts.base import Context
-from yardstick.benchmark.contexts.standalone.model import Libvirt
-from yardstick.benchmark.contexts.standalone.model import StandaloneContextHelper
-from yardstick.benchmark.contexts.standalone.model import Server
+from yardstick.benchmark.contexts.standalone import model
 from yardstick.network_services.utils import PciAddress
 
 LOG = logging.getLogger(__name__)
@@ -49,8 +46,8 @@ class SriovContext(Context):
         self.attrs = {}
         self.vm_flavor = None
         self.servers = None
-        self.helper = StandaloneContextHelper()
-        self.vnf_node = Server()
+        self.helper = model.StandaloneContextHelper()
+        self.vnf_node = model.Server()
         self.drivers = []
         super(SriovContext, self).__init__()
 
@@ -87,15 +84,14 @@ class SriovContext(Context):
             os.path.join(get_nsb_option("bin_path"), "dpdk-devbind.py"))
 
         #    Todo: NFVi deploy (sriov, vswitch, ovs etc) based on the config.
-        StandaloneContextHelper.install_req_libs(self.connection)
-        self.networks = StandaloneContextHelper.get_nic_details(
+        model.StandaloneContextHelper.install_req_libs(self.connection)
+        self.networks = model.StandaloneContextHelper.get_nic_details(
             self.connection, self.networks, self.dpdk_devbind)
         self.nodes = self.setup_sriov_context()
 
         LOG.debug("Waiting for VM to come up...")
-        self.nodes = StandaloneContextHelper.wait_for_vnfs_to_start(self.connection,
-                                                                    self.servers,
-                                                                    self.nodes)
+        self.nodes = model.StandaloneContextHelper.wait_for_vnfs_to_start(
+            self.connection, self.servers, self.nodes)
 
     def undeploy(self):
         """don't need to undeploy"""
@@ -105,7 +101,7 @@ class SriovContext(Context):
 
         # Todo: NFVi undeploy (sriov, vswitch, ovs etc) based on the config.
         for vm in self.vm_names:
-            Libvirt.check_if_vm_exists_and_delete(vm, self.connection)
+            model.Libvirt.check_if_vm_exists_and_delete(vm, self.connection)
 
         # Bind nics back to kernel
         for ports in self.networks.values():
@@ -136,8 +132,8 @@ class SriovContext(Context):
         except StopIteration:
             pass
         else:
-            raise ValueError("Duplicate nodes!!! Nodes: %s %s" %
-                             (node, duplicate))
+            raise ValueError("Duplicate nodes!!! Nodes: %s %s"
+                             (node, duplicate))
 
         node["name"] = attr_name
         return node
@@ -179,7 +175,7 @@ class SriovContext(Context):
             self.connection.execute(build_vfs.format(ports.get('phy_port')))
 
             # configure VFs...
-            mac = StandaloneContextHelper.get_mac_address()
+            mac = model.StandaloneContextHelper.get_mac_address()
             interface = ports.get('interface')
             if interface is not None:
                 self.connection.execute(vf_cmd.format(interface, mac))
@@ -201,7 +197,7 @@ class SriovContext(Context):
         slot = index + idx + 10
         vf['vpci'] = \
             "{}:{}:{:02x}.{}".format(vpci.domain, vpci.bus, slot, vpci.function)
-        Libvirt.add_sriov_interfaces(
+        model.Libvirt.add_sriov_interfaces(
             vf['vpci'], vf['vf_pci']['vf_pci'], vf['mac'], str(cfg))
         self.connection.execute("ifconfig %s up" % vf['interface'])
         self.connection.execute(vf_spoofchk.format(vf['interface']))
@@ -212,34 +208,37 @@ class SriovContext(Context):
         #   1 : modprobe host_driver with num_vfs
         self.configure_nics_for_sriov()
 
-        for index, (key, vnf) in enumerate(OrderedDict(self.servers).items()):
+        for index, (key, vnf) in enumerate(collections.OrderedDict(
+                self.servers).items()):
             cfg = '/tmp/vm_sriov_%s.xml' % str(index)
             vm_name = "vm_%s" % str(index)
 
             # 1. Check and delete VM if already exists
-            Libvirt.check_if_vm_exists_and_delete(vm_name, self.connection)
+            model.Libvirt.check_if_vm_exists_and_delete(vm_name,
+                                                        self.connection)
+            xml_str, mac = model.Libvirt.build_vm_xml(
+                self.connection, self.vm_flavor, vm_name, index)
 
-            _, mac = Libvirt.build_vm_xml(self.connection, self.vm_flavor, cfg, vm_name, index)
             # 2: Cleanup already available VMs
-            for idx, (vkey, vfs) in enumerate(OrderedDict(vnf["network_ports"]).items()):
-                if vkey == "mgmt":
-                    continue
+            network_ports = collections.OrderedDict(
+                {k: v for k, v in vnf["network_ports"].items() if k != 'mgmt'})
+            for idx, vfs in enumerate(network_ports.values()):
                 self._enable_interfaces(index, idx, vfs, cfg)
 
             # copy xml to target...
+            model.Libvirt.write_file(cfg, xml_str)
             self.connection.put(cfg, cfg)
 
             # NOTE: launch through libvirt
             LOG.info("virsh create ...")
-            Libvirt.virsh_create_vm(self.connection, cfg)
+            model.Libvirt.virsh_create_vm(self.connection, cfg)
 
             self.vm_names.append(vm_name)
 
             # build vnf node details
-            nodes.append(self.vnf_node.generate_vnf_instance(self.vm_flavor,
-                                                             self.networks,
-                                                             self.host_mgmt.get('ip'),
-                                                             key, vnf, mac))
+            nodes.append(self.vnf_node.generate_vnf_instance(
+                self.vm_flavor, self.networks, self.host_mgmt.get('ip'),
+                key, vnf, mac))
 
         return nodes
 
@@ -248,7 +247,8 @@ class SriovContext(Context):
             "mac": vfmac,
             "pf_if": pfif
         }
-        vfs = StandaloneContextHelper.get_virtual_devices(self.connection, value)
+        vfs = model.StandaloneContextHelper.get_virtual_devices(
+            self.connection, value)
         for k, v in vfs.items():
             m = PciAddress(k.strip())
             m1 = PciAddress(value.strip())
index 03f4a12..005deb8 100644 (file)
@@ -14,8 +14,9 @@
 
 import copy
 import os
-import unittest
 import mock
+import unittest
+import uuid
 
 from xml.etree import ElementTree
 
@@ -45,19 +46,9 @@ XML_SAMPLE_INTERFACE = """<?xml version="1.0"?>
 class ModelLibvirtTestCase(unittest.TestCase):
 
     def setUp(self):
-        self.xml = ElementTree.ElementTree(
-            element=ElementTree.fromstring(XML_SAMPLE))
         self.pci_address_str = '0001:04:03.2'
         self.pci_address = utils.PciAddress(self.pci_address_str)
         self.mac = '00:00:00:00:00:01'
-        self._mock_write_xml = mock.patch.object(ElementTree.ElementTree,
-                                                 'write')
-        self.mock_write_xml = self._mock_write_xml.start()
-
-        self.addCleanup(self._cleanup)
-
-    def _cleanup(self):
-        self._mock_write_xml.stop()
 
     # TODO: Remove mocking of yardstick.ssh.SSH (here and elsewhere)
     # In this case, we are mocking a param to be passed into other methods
@@ -102,72 +93,65 @@ class ModelLibvirtTestCase(unittest.TestCase):
                          result.get('function'))
 
     def test_add_ovs_interfaces(self):
-        xml_input = mock.Mock()
-        with mock.patch.object(ElementTree, 'parse', return_value=self.xml) \
-                as mock_parse:
-            xml = copy.deepcopy(self.xml)
-            mock_parse.return_value = xml
-            model.Libvirt.add_ovs_interface(
-                '/usr/local', 0, self.pci_address_str, self.mac, xml_input)
-            mock_parse.assert_called_once_with(xml_input)
-            self.mock_write_xml.assert_called_once_with(xml_input)
-            interface = xml.find('devices').find('interface')
-            self.assertEqual('vhostuser', interface.get('type'))
-            mac = interface.find('mac')
-            self.assertEqual(self.mac, mac.get('address'))
-            source = interface.find('source')
-            self.assertEqual('unix', source.get('type'))
-            self.assertEqual('/usr/local/var/run/openvswitch/dpdkvhostuser0',
-                             source.get('path'))
-            self.assertEqual('client', source.get('mode'))
-            _model = interface.find('model')
-            self.assertEqual('virtio', _model.get('type'))
-            driver = interface.find('driver')
-            self.assertEqual('4', driver.get('queues'))
-            host = driver.find('host')
-            self.assertEqual('off', host.get('mrg_rxbuf'))
-            self.assertIsNotNone(interface.find('address'))
+        xml_input = copy.deepcopy(XML_SAMPLE)
+        xml_output = model.Libvirt.add_ovs_interface(
+            '/usr/local', 0, self.pci_address_str, self.mac, xml_input)
+
+        root = ElementTree.fromstring(xml_output)
+        et_out = ElementTree.ElementTree(element=root)
+        interface = et_out.find('devices').find('interface')
+        self.assertEqual('vhostuser', interface.get('type'))
+        mac = interface.find('mac')
+        self.assertEqual(self.mac, mac.get('address'))
+        source = interface.find('source')
+        self.assertEqual('unix', source.get('type'))
+        self.assertEqual('/usr/local/var/run/openvswitch/dpdkvhostuser0',
+                         source.get('path'))
+        self.assertEqual('client', source.get('mode'))
+        _model = interface.find('model')
+        self.assertEqual('virtio', _model.get('type'))
+        driver = interface.find('driver')
+        self.assertEqual('4', driver.get('queues'))
+        host = driver.find('host')
+        self.assertEqual('off', host.get('mrg_rxbuf'))
+        self.assertIsNotNone(interface.find('address'))
 
     def test_add_sriov_interfaces(self):
-        xml_input = mock.Mock()
-        with mock.patch.object(ElementTree, 'parse', return_value=self.xml) \
-                as mock_parse:
-            xml = copy.deepcopy(self.xml)
-            mock_parse.return_value = xml
-            vm_pci = '0001:05:04.2'
-            model.Libvirt.add_sriov_interfaces(
-                vm_pci, self.pci_address_str, self.mac, xml_input)
-            mock_parse.assert_called_once_with(xml_input)
-            self.mock_write_xml.assert_called_once_with(xml_input)
-            interface = xml.find('devices').find('interface')
-            self.assertEqual('yes', interface.get('managed'))
-            self.assertEqual('hostdev', interface.get('type'))
-            mac = interface.find('mac')
-            self.assertEqual(self.mac, mac.get('address'))
-            source = interface.find('source')
-            source_address = source.find('address')
-            self.assertIsNotNone(source.find('address'))
-
-            self.assertEqual('pci', source_address.get('type'))
-            self.assertEqual('0x' + self.pci_address_str.split(':')[0],
-                             source_address.get('domain'))
-            self.assertEqual('0x' + self.pci_address_str.split(':')[1],
-                             source_address.get('bus'))
-            self.assertEqual('0x' + self.pci_address_str.split(':')[2].split('.')[0],
-                             source_address.get('slot'))
-            self.assertEqual('0x' + self.pci_address_str.split(':')[2].split('.')[1],
-                             source_address.get('function'))
-
-            interface_address = interface.find('address')
-            self.assertEqual('pci', interface_address.get('type'))
-            self.assertEqual('0x' + vm_pci.split(':')[0],
-                             interface_address.get('domain'))
-            self.assertEqual('0x' + vm_pci.split(':')[1],
-                             interface_address.get('bus'))
-            self.assertEqual('0x' + vm_pci.split(':')[2].split('.')[0],
-                             interface_address.get('slot'))
-            self.assertEqual('0x' + vm_pci.split(':')[2].split('.')[1],
-                             interface_address.get('function'))
+        xml_input = copy.deepcopy(XML_SAMPLE)
+        vm_pci = '0001:05:04.2'
+        xml_output = model.Libvirt.add_sriov_interfaces(
+            vm_pci, self.pci_address_str, self.mac, xml_input)
+        root = ElementTree.fromstring(xml_output)
+        et_out = ElementTree.ElementTree(element=root)
+        interface = et_out.find('devices').find('interface')
+        self.assertEqual('yes', interface.get('managed'))
+        self.assertEqual('hostdev', interface.get('type'))
+        mac = interface.find('mac')
+        self.assertEqual(self.mac, mac.get('address'))
+        source = interface.find('source')
+        source_address = source.find('address')
+        self.assertIsNotNone(source.find('address'))
+
+        self.assertEqual('pci', source_address.get('type'))
+        self.assertEqual('0x' + self.pci_address_str.split(':')[0],
+                         source_address.get('domain'))
+        self.assertEqual('0x' + self.pci_address_str.split(':')[1],
+                         source_address.get('bus'))
+        self.assertEqual('0x' + self.pci_address_str.split(':')[2].split('.')[0],
+                         source_address.get('slot'))
+        self.assertEqual('0x' + self.pci_address_str.split(':')[2].split('.')[1],
+                         source_address.get('function'))
+
+        interface_address = interface.find('address')
+        self.assertEqual('pci', interface_address.get('type'))
+        self.assertEqual('0x' + vm_pci.split(':')[0],
+                         interface_address.get('domain'))
+        self.assertEqual('0x' + vm_pci.split(':')[1],
+                         interface_address.get('bus'))
+        self.assertEqual('0x' + vm_pci.split(':')[2].split('.')[0],
+                         interface_address.get('slot'))
+        self.assertEqual('0x' + vm_pci.split(':')[2].split('.')[1],
+                         interface_address.get('function'))
 
     def test_create_snapshot_qemu(self):
         result = "/var/lib/libvirt/images/0.qcow2"
@@ -179,24 +163,38 @@ class ModelLibvirtTestCase(unittest.TestCase):
         image = model.Libvirt.create_snapshot_qemu(ssh_mock, "0", "ubuntu.img")
         self.assertEqual(image, result)
 
-    @mock.patch.object(model.Libvirt, 'pin_vcpu_for_perf')
-    @mock.patch.object(model.Libvirt, 'create_snapshot_qemu')
+    @mock.patch.object(model.Libvirt, 'pin_vcpu_for_perf', return_value='4,5')
+    @mock.patch.object(model.Libvirt, 'create_snapshot_qemu',
+                       return_value='qemu_image')
     def test_build_vm_xml(self, mock_create_snapshot_qemu,
-                          *args):
-        # NOTE(ralonsoh): this test doesn't cover function execution. This test
-        # should also check mocked function calls.
-        cfg_file = 'test_config_file.cfg'
-        self.addCleanup(os.remove, cfg_file)
-        result = [4]
-        with mock.patch("yardstick.ssh.SSH") as ssh:
-            ssh_mock = mock.Mock(autospec=ssh.SSH)
-            ssh_mock.execute = \
-                mock.Mock(return_value=(0, "a", ""))
-            ssh.return_value = ssh_mock
-        mock_create_snapshot_qemu.return_value = "0.img"
-
-        status = model.Libvirt.build_vm_xml(ssh_mock, {}, cfg_file, 'vm_0', 0)
-        self.assertEqual(status[0], result[0])
+                          mock_pin_vcpu_for_perf):
+        extra_specs = {'hw:cpu_cores': '4',
+                       'hw:cpu_sockets': '3',
+                       'hw:cpu_threads': '2',
+                       'cputune': 'cool'}
+        flavor = {'ram': '1024',
+                  'extra_specs': extra_specs,
+                  'hw_socket': '1',
+                  'images': 'images'}
+        mac = model.StandaloneContextHelper.get_mac_address(0x00)
+        _uuid = uuid.uuid4()
+        connection = mock.Mock()
+        with mock.patch.object(model.StandaloneContextHelper,
+                               'get_mac_address', return_value=mac) as \
+                mock_get_mac_address, \
+                mock.patch.object(uuid, 'uuid4', return_value=_uuid):
+            xml_out, mac = model.Libvirt.build_vm_xml(
+                connection, flavor, 'vm_name', 100)
+
+        xml_ref = model.VM_TEMPLATE.format(vm_name='vm_name',
+            random_uuid=_uuid, mac_addr=mac, memory='1024', vcpu='8', cpu='4',
+            numa_cpus='0-7', socket='3', threads='2',
+            vm_image='qemu_image', cpuset='4,5', cputune='cool')
+        self.assertEqual(xml_ref, xml_out)
+        mock_get_mac_address.assert_called_once_with(0x00)
+        mock_create_snapshot_qemu.assert_called_once_with(
+            connection, 100, 'images')
+        mock_pin_vcpu_for_perf.assert_called_once_with(connection, '1')
 
     # TODO: Edit this test to test state instead of output
     # update_interrupts_hugepages_perf does not return anything
index 3ca0b9b..bc3bb73 100644 (file)
@@ -365,20 +365,14 @@ class OvsDpdkContextTestCase(unittest.TestCase):
             'fake_path', 0, self.NETWORKS['private_0']['vpci'],
             self.NETWORKS['private_0']['mac'], 'test')
 
+    @mock.patch.object(model.Libvirt, 'write_file')
     @mock.patch.object(model.Libvirt, 'build_vm_xml')
     @mock.patch.object(model.Libvirt, 'check_if_vm_exists_and_delete')
     @mock.patch.object(model.Libvirt, 'virsh_create_vm')
     def test_setup_ovs_dpdk_context(self, mock_create_vm, mock_check_if_exists,
-                                    mock_build_xml):
-        with mock.patch("yardstick.ssh.SSH") as ssh:
-            ssh_mock = mock.Mock(autospec=ssh.SSH)
-            ssh_mock.execute = \
-                mock.Mock(return_value=(0, "a", ""))
-            ssh_mock.put = \
-                mock.Mock(return_value=(0, "a", ""))
-            ssh.return_value = ssh_mock
+                                    mock_build_xml, mock_write_file):
         self.ovs_dpdk.vm_deploy = True
-        self.ovs_dpdk.connection = ssh_mock
+        self.ovs_dpdk.connection = mock.Mock()
         self.ovs_dpdk.vm_names = ['vm_0', 'vm_1']
         self.ovs_dpdk.drivers = []
         self.ovs_dpdk.servers = {
@@ -394,8 +388,9 @@ class OvsDpdkContextTestCase(unittest.TestCase):
         self.ovs_dpdk.host_mgmt = {}
         self.ovs_dpdk.flavor = {}
         self.ovs_dpdk.configure_nics_for_ovs_dpdk = mock.Mock(return_value="")
-        mock_build_xml.return_value = [6, "00:00:00:00:00:01"]
-        self.ovs_dpdk._enable_interfaces = mock.Mock(return_value="")
+        xml_str = mock.Mock()
+        mock_build_xml.return_value = (xml_str, '00:00:00:00:00:01')
+        self.ovs_dpdk._enable_interfaces = mock.Mock(return_value=xml_str)
         vnf_instance = mock.Mock()
         self.ovs_dpdk.vnf_node.generate_vnf_instance = mock.Mock(
             return_value=vnf_instance)
@@ -407,8 +402,8 @@ class OvsDpdkContextTestCase(unittest.TestCase):
         mock_check_if_exists.assert_called_once_with(
             'vm_0', self.ovs_dpdk.connection)
         mock_build_xml.assert_called_once_with(
-            self.ovs_dpdk.connection, self.ovs_dpdk.vm_flavor,
-            '/tmp/vm_ovs_0.xml', 'vm_0', 0)
+            self.ovs_dpdk.connection, self.ovs_dpdk.vm_flavor, 'vm_0', 0)
+        mock_write_file.assert_called_once_with('/tmp/vm_ovs_0.xml', xml_str)
 
     @mock.patch.object(io, 'BytesIO')
     def test__check_hugepages(self, mock_bytesio):
index f0953ef..e70ab0a 100644 (file)
@@ -18,6 +18,7 @@ import mock
 import unittest
 
 from yardstick import ssh
+from yardstick.benchmark.contexts.standalone import model
 from yardstick.benchmark.contexts.standalone import sriov
 
 
@@ -69,10 +70,11 @@ class SriovContextTestCase(unittest.TestCase):
         if self.sriov in self.sriov.list:
             self.sriov._delete_context()
 
-    @mock.patch('yardstick.benchmark.contexts.standalone.sriov.Libvirt')
-    @mock.patch('yardstick.benchmark.contexts.standalone.model.StandaloneContextHelper')
-    @mock.patch('yardstick.benchmark.contexts.standalone.model.Server')
-    def test___init__(self, mock_helper, mock_server, *args):
+    @mock.patch.object(model, 'StandaloneContextHelper')
+    @mock.patch.object(model, 'Libvirt')
+    @mock.patch.object(model, 'Server')
+    def test___init__(self, mock_helper, mock_libvirt, mock_server):
+        # pylint: disable=unused-argument
         # NOTE(ralonsoh): this test doesn't cover function execution.
         self.sriov.helper = mock_helper
         self.sriov.vnf_node = mock_server
@@ -97,9 +99,11 @@ class SriovContextTestCase(unittest.TestCase):
         self.sriov.wait_for_vnfs_to_start = mock.Mock(return_value={})
         self.assertIsNone(self.sriov.deploy())
 
-    @mock.patch('yardstick.benchmark.contexts.standalone.sriov.Libvirt')
     @mock.patch.object(ssh, 'SSH', return_value=(0, "a", ""))
-    def test_undeploy(self, mock_ssh, *args):
+    @mock.patch.object(model, 'Libvirt')
+    def test_undeploy(self, mock_libvirt, mock_ssh):
+        # pylint: disable=unused-argument
+        # NOTE(ralonsoh): the pylint exception should be removed.
         self.sriov.vm_deploy = False
         self.assertIsNone(self.sriov.undeploy())
 
@@ -237,11 +241,11 @@ class SriovContextTestCase(unittest.TestCase):
         self.sriov._get_vf_data = mock.Mock(return_value="")
         self.assertIsNone(self.sriov.configure_nics_for_sriov())
 
-    @mock.patch('yardstick.benchmark.contexts.standalone.sriov.Libvirt')
-    @mock.patch.object(ssh, 'SSH')
-    def test__enable_interfaces(self, mock_ssh, *args):
-        mock_ssh.return_value = 0, "a", ""
-
+    @mock.patch.object(ssh, 'SSH', return_value=(0, "a", ""))
+    @mock.patch.object(model, 'Libvirt')
+    def test__enable_interfaces(self, mock_libvirt, mock_ssh):
+        # pylint: disable=unused-argument
+        # NOTE(ralonsoh): the pylint exception should be removed.
         self.sriov.vm_deploy = True
         self.sriov.connection = mock_ssh
         self.sriov.vm_names = ['vm_0', 'vm_1']
@@ -251,20 +255,12 @@ class SriovContextTestCase(unittest.TestCase):
         self.assertIsNone(self.sriov._enable_interfaces(
             0, 0, ["private_0"], 'test'))
 
-    @mock.patch('yardstick.benchmark.contexts.standalone.model.Server')
-    @mock.patch('yardstick.benchmark.contexts.standalone.sriov.Libvirt')
-    def test_setup_sriov_context(self, mock_libvirt, *args):
-        with mock.patch("yardstick.ssh.SSH") as ssh:
-            ssh_mock = mock.Mock(autospec=ssh.SSH)
-            ssh_mock.execute = \
-                mock.Mock(return_value=(0, "a", ""))
-            ssh_mock.put = \
-                mock.Mock(return_value=(0, "a", ""))
-            ssh.return_value = ssh_mock
-        self.sriov.vm_deploy = True
-        self.sriov.connection = ssh_mock
-        self.sriov.vm_names = ['vm_0', 'vm_1']
-        self.sriov.drivers = []
+    @mock.patch.object(model.Libvirt, 'build_vm_xml')
+    @mock.patch.object(model.Libvirt, 'check_if_vm_exists_and_delete')
+    @mock.patch.object(model.Libvirt, 'write_file')
+    @mock.patch.object(model.Libvirt, 'virsh_create_vm')
+    def test_setup_sriov_context(self, mock_create_vm, mock_write_file,
+                                 mock_check, mock_build_vm_xml):
         self.sriov.servers = {
             'vnf_0': {
                 'network_ports': {
@@ -274,15 +270,31 @@ class SriovContextTestCase(unittest.TestCase):
                 }
             }
         }
-        self.sriov.networks = self.NETWORKS
-        self.sriov.host_mgmt = {}
-        self.sriov.flavor = {}
-        self.sriov.configure_nics_for_sriov = mock.Mock(return_value="")
-        mock_libvirt.build_vm_xml = mock.Mock(
-            return_value=[6, "00:00:00:00:00:01"])
-        self.sriov._enable_interfaces = mock.Mock(return_value="")
-        self.sriov.vnf_node.generate_vnf_instance = mock.Mock(return_value={})
-        self.assertIsNotNone(self.sriov.setup_sriov_context())
+        connection = mock.Mock()
+        self.sriov.connection = connection
+        self.sriov.host_mgmt = {'ip': '1.2.3.4'}
+        self.sriov.vm_flavor = 'flavor'
+        self.sriov.networks = 'networks'
+        self.sriov.configure_nics_for_sriov = mock.Mock()
+        cfg = '/tmp/vm_sriov_0.xml'
+        vm_name = 'vm_0'
+        xml_out = mock.Mock()
+        mock_build_vm_xml.return_value = (xml_out, '00:00:00:00:00:01')
+
+        with mock.patch.object(self.sriov, 'vnf_node') as mock_vnf_node, \
+                mock.patch.object(self.sriov, '_enable_interfaces'):
+            mock_vnf_node.generate_vnf_instance = mock.Mock(
+                return_value='node')
+            nodes_out = self.sriov.setup_sriov_context()
+        self.assertEqual(['node'], nodes_out)
+        mock_vnf_node.generate_vnf_instance.assert_called_once_with(
+            'flavor', 'networks', '1.2.3.4', 'vnf_0',
+            self.sriov.servers['vnf_0'], '00:00:00:00:00:01')
+        mock_build_vm_xml.assert_called_once_with(
+            connection, 'flavor', vm_name, 0)
+        mock_create_vm.assert_called_once_with(connection, cfg)
+        mock_check.assert_called_once_with(vm_name, connection)
+        mock_write_file.assert_called_once_with(cfg, xml_out)
 
     def test__get_vf_data(self):
         with mock.patch("yardstick.ssh.SSH") as ssh: