Ensure library and tests close all necessary resources. 93/38293/1
authorspisarski <s.pisarski@cablelabs.com>
Thu, 27 Jul 2017 16:27:14 +0000 (10:27 -0600)
committerspisarski <s.pisarski@cablelabs.com>
Thu, 27 Jul 2017 16:27:14 +0000 (10:27 -0600)
The SNAPS-OO library and tests had left open files, ssh, and scp
connections. These have all now been wrapped with try/finally
blocks.

JIRA: SNAPS-152

Change-Id: I43e09978b5c075bd78ff3279c0799556b8758878
Signed-off-by: spisarski <s.pisarski@cablelabs.com>
snaps/file_utils.py
snaps/openstack/create_instance.py
snaps/openstack/tests/conf/os_credentials_tests.py
snaps/openstack/tests/create_instance_tests.py
snaps/openstack/tests/create_keypairs_tests.py
snaps/openstack/utils/glance_utils.py
snaps/openstack/utils/nova_utils.py
snaps/playbook_runner.py
snaps/provisioning/tests/ansible_utils_tests.py
snaps/tests/file_utils_tests.py

index a7ed13c..ff2f1b3 100644 (file)
@@ -32,7 +32,8 @@ logger = logging.getLogger('file_utils')
 
 def file_exists(file_path):
     """
-    Returns True if the image file already exists and throws an exception if the path is a directory
+    Returns True if the image file already exists and throws an exception if
+    the path is a directory
     :return:
     """
     if os.path.exists(file_path):
@@ -55,7 +56,7 @@ def download(url, dest_path, name=None):
     dest = dest_path + '/' + name
     logger.debug('Downloading file from - ' + url)
     # Override proxy settings to use localhost to download file
-    f = None
+    download_file = None
 
     if not os.path.isdir(dest_path):
         try:
@@ -63,14 +64,14 @@ def download(url, dest_path, name=None):
         except:
             raise
     try:
-        with open(dest, 'wb') as f:
-            logger.debug('Saving file to - ' + os.path.abspath(f.name))
+        with open(dest, 'wb') as download_file:
+            logger.debug('Saving file to - ' + os.path.abspath(download_file.name))
             response = __get_url_response(url)
-            f.write(response.read())
-        return f
+            download_file.write(response.read())
+        return download_file
     finally:
-        if f:
-            f.close()
+        if download_file:
+            download_file.close()
 
 
 def get_content_length(url):
@@ -102,32 +103,45 @@ def read_yaml(config_file_path):
     :return: a dictionary
     """
     logger.debug('Attempting to load configuration file - ' + config_file_path)
-    with open(config_file_path) as config_file:
-        config = yaml.safe_load(config_file)
-        logger.info('Loaded configuration')
-    config_file.close()
-    logger.info('Closing configuration file')
-    return config
+    config_file = None
+    try:
+        with open(config_file_path) as config_file:
+            config = yaml.safe_load(config_file)
+            logger.info('Loaded configuration')
+        return config
+    finally:
+        if config_file:
+            logger.info('Closing configuration file')
+            config_file.close()
 
 
 def read_os_env_file(os_env_filename):
     """
     Reads the OS environment source file and returns a map of each key/value
-    Will ignore lines beginning with a '#' and will replace any single or double quotes contained within the value
+    Will ignore lines beginning with a '#' and will replace any single or
+    double quotes contained within the value
     :param os_env_filename: The name of the OS environment file to read
     :return: a dictionary
     """
     if os_env_filename:
-        logger.info('Attempting to read OS environment file - ' + os_env_filename)
+        logger.info('Attempting to read OS environment file - %s',
+                    os_env_filename)
         out = {}
-        for line in open(os_env_filename):
-            line = line.lstrip()
-            if not line.startswith('#') and line.startswith('export '):
-                line = line.lstrip('export ').strip()
-                tokens = line.split('=')
-                if len(tokens) > 1:
-                    # Remove leading and trailing ' & " characters from value
-                    out[tokens[0]] = tokens[1].lstrip('\'').lstrip('\"').rstrip('\'').rstrip('\"')
+        env_file = None
+        try:
+            env_file = open(os_env_filename)
+            for line in env_file:
+                line = line.lstrip()
+                if not line.startswith('#') and line.startswith('export '):
+                    line = line.lstrip('export ').strip()
+                    tokens = line.split('=')
+                    if len(tokens) > 1:
+                        # Remove leading and trailing ' & " characters from
+                        # value
+                        out[tokens[0]] = tokens[1].lstrip('\'').lstrip('\"').rstrip('\'').rstrip('\"')
+        finally:
+            if env_file:
+                env_file.close()
         return out
 
 
@@ -138,7 +152,12 @@ def read_file(filename):
     :return:
     """
     out = str()
-    for line in open(filename):
-        out += line
-
-    return out
+    the_file = None
+    try:
+        the_file = open(filename)
+        for line in the_file:
+            out += line
+        return out
+    finally:
+        if the_file:
+            the_file.close()
index d5917a8..997b5a5 100644 (file)
@@ -618,6 +618,7 @@ class OpenStackVmInstance:
         if len(self.__floating_ips) > 0:
             ssh = self.ssh_client()
             if ssh:
+                ssh.close()
                 return True
         return False
 
index e7c34b9..4a2ce3d 100644 (file)
@@ -56,17 +56,17 @@ class ProxySettingsUnitTests(unittest.TestCase):
     def test_minimum(self):
         proxy_settings = ProxySettings(host='foo', port=1234)
         self.assertEqual('foo', proxy_settings.host)
-        self.assertEqual(1234, proxy_settings.port)
+        self.assertEqual('1234', proxy_settings.port)
         self.assertEqual('foo', proxy_settings.https_host)
-        self.assertEqual(1234, proxy_settings.https_port)
+        self.assertEqual('1234', proxy_settings.https_port)
         self.assertIsNone(proxy_settings.ssh_proxy_cmd)
 
     def test_minimum_kwargs(self):
         proxy_settings = ProxySettings(**{'host': 'foo', 'port': 1234})
         self.assertEqual('foo', proxy_settings.host)
-        self.assertEqual(1234, proxy_settings.port)
+        self.assertEqual('1234', proxy_settings.port)
         self.assertEqual('foo', proxy_settings.https_host)
-        self.assertEqual(1234, proxy_settings.https_port)
+        self.assertEqual('1234', proxy_settings.https_port)
         self.assertIsNone(proxy_settings.ssh_proxy_cmd)
 
     def test_all(self):
@@ -74,9 +74,9 @@ class ProxySettingsUnitTests(unittest.TestCase):
             host='foo', port=1234, https_host='bar', https_port=2345,
             ssh_proxy_cmd='proxy command')
         self.assertEqual('foo', proxy_settings.host)
-        self.assertEqual(1234, proxy_settings.port)
+        self.assertEqual('1234', proxy_settings.port)
         self.assertEqual('bar', proxy_settings.https_host)
-        self.assertEqual(2345, proxy_settings.https_port)
+        self.assertEqual('2345', proxy_settings.https_port)
         self.assertEqual('proxy command', proxy_settings.ssh_proxy_cmd)
 
     def test_all_kwargs(self):
@@ -84,9 +84,9 @@ class ProxySettingsUnitTests(unittest.TestCase):
             **{'host': 'foo', 'port': 1234, 'https_host': 'bar',
                'https_port': 2345, 'ssh_proxy_cmd': 'proxy command'})
         self.assertEqual('foo', proxy_settings.host)
-        self.assertEqual(1234, proxy_settings.port)
+        self.assertEqual('1234', proxy_settings.port)
         self.assertEqual('bar', proxy_settings.https_host)
-        self.assertEqual(2345, proxy_settings.https_port)
+        self.assertEqual('2345', proxy_settings.https_port)
         self.assertEqual('proxy command', proxy_settings.ssh_proxy_cmd)
 
 
@@ -245,7 +245,7 @@ class OSCredsUnitTests(unittest.TestCase):
         self.assertEqual('admin', os_creds.interface)
         self.assertFalse(os_creds.cacert)
         self.assertEqual('foo', os_creds.proxy_settings.host)
-        self.assertEqual(1234, os_creds.proxy_settings.port)
+        self.assertEqual('1234', os_creds.proxy_settings.port)
         self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd)
         self.assertIsNone(os_creds.region_name)
 
@@ -269,7 +269,7 @@ class OSCredsUnitTests(unittest.TestCase):
         self.assertEqual('admin', os_creds.interface)
         self.assertFalse(os_creds.cacert)
         self.assertEqual('foo', os_creds.proxy_settings.host)
-        self.assertEqual(1234, os_creds.proxy_settings.port)
+        self.assertEqual('1234', os_creds.proxy_settings.port)
         self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd)
         self.assertEqual('test_region', os_creds.region_name)
 
@@ -290,7 +290,7 @@ class OSCredsUnitTests(unittest.TestCase):
         self.assertEqual('admin', os_creds.interface)
         self.assertFalse(os_creds.cacert)
         self.assertEqual('foo', os_creds.proxy_settings.host)
-        self.assertEqual(1234, os_creds.proxy_settings.port)
+        self.assertEqual('1234', os_creds.proxy_settings.port)
         self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd)
 
     def test_proxy_settings_dict_kwargs(self):
@@ -312,6 +312,6 @@ class OSCredsUnitTests(unittest.TestCase):
         self.assertEqual('admin', os_creds.interface)
         self.assertFalse(os_creds.cacert)
         self.assertEqual('foo', os_creds.proxy_settings.host)
-        self.assertEqual(1234, os_creds.proxy_settings.port)
+        self.assertEqual('1234', os_creds.proxy_settings.port)
         self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd)
         self.assertEqual('test_region', os_creds.region_name)
index 75b0ed3..1922146 100644 (file)
@@ -1717,7 +1717,10 @@ def validate_ssh_client(instance_creator):
     if ssh_active:
         ssh_client = instance_creator.ssh_client()
         if ssh_client:
-            out = ssh_client.exec_command('pwd')[1]
+            try:
+                out = ssh_client.exec_command('pwd')[1]
+            finally:
+                ssh_client.close()
         else:
             return False
 
index 0b35095..7b75d05 100644 (file)
@@ -285,9 +285,15 @@ class CreateKeypairsTests(OSIntegrationTestCase):
                                             self.keypair_creator.get_keypair())
         self.assertEqual(self.keypair_creator.get_keypair(), keypair)
 
-        file_key = open(os.path.expanduser(self.pub_file_path)).read()
-        self.assertEqual(self.keypair_creator.get_keypair().public_key,
-                         file_key)
+        pub_file = None
+        try:
+            pub_file = open(os.path.expanduser(self.pub_file_path))
+            file_key = pub_file.read()
+            self.assertEqual(self.keypair_creator.get_keypair().public_key,
+                             file_key)
+        finally:
+            if pub_file:
+                pub_file.close()
 
     def test_create_keypair_save_both(self):
         """
@@ -305,7 +311,16 @@ class CreateKeypairsTests(OSIntegrationTestCase):
                                             self.keypair_creator.get_keypair())
         self.assertEqual(self.keypair_creator.get_keypair(), keypair)
 
-        file_key = open(os.path.expanduser(self.pub_file_path)).read()
+        pub_file = None
+        try:
+            pub_file = open(os.path.expanduser(self.pub_file_path))
+            file_key = pub_file.read()
+            self.assertEqual(self.keypair_creator.get_keypair().public_key,
+                             file_key)
+        finally:
+            if pub_file:
+                pub_file.close()
+
         self.assertEqual(self.keypair_creator.get_keypair().public_key,
                          file_key)
 
@@ -328,7 +343,16 @@ class CreateKeypairsTests(OSIntegrationTestCase):
                                             self.keypair_creator.get_keypair())
         self.assertEqual(self.keypair_creator.get_keypair(), keypair)
 
-        file_key = open(os.path.expanduser(self.pub_file_path)).read()
+        pub_file = None
+        try:
+            pub_file = open(os.path.expanduser(self.pub_file_path))
+            file_key = pub_file.read()
+            self.assertEqual(self.keypair_creator.get_keypair().public_key,
+                             file_key)
+        finally:
+            if pub_file:
+                pub_file.close()
+
         self.assertEqual(self.keypair_creator.get_keypair().public_key,
                          file_key)
 
index 49bfe95..ad9c5e5 100644 (file)
@@ -124,22 +124,30 @@ def __create_image_v1(glance, image_settings):
         'name': image_settings.name, 'disk_format': image_settings.format,
         'container_format': 'bare', 'is_public': image_settings.public}
 
-    if image_settings.extra_properties:
-        kwargs['properties'] = image_settings.extra_properties
-
-    if image_settings.url:
-        kwargs['location'] = image_settings.url
-    elif image_settings.image_file:
-        image_file = open(image_settings.image_file, 'rb')
-        kwargs['data'] = image_file
-    else:
-        logger.warn('Unable to create image with name - %s. No file or URL',
-                    image_settings.name)
-        return None
+    image_file = None
 
-    created_image = glance.images.create(**kwargs)
-    return Image(name=image_settings.name, image_id=created_image.id,
-                 size=created_image.size, properties=created_image.properties)
+    try:
+        if image_settings.extra_properties:
+            kwargs['properties'] = image_settings.extra_properties
+
+        if image_settings.url:
+            kwargs['location'] = image_settings.url
+        elif image_settings.image_file:
+            image_file = open(image_settings.image_file, 'rb')
+            kwargs['data'] = image_file
+        else:
+            logger.warn(
+                'Unable to create image with name - %s. No file or URL',
+                image_settings.name)
+            return None
+
+        created_image = glance.images.create(**kwargs)
+        return Image(name=image_settings.name, image_id=created_image.id,
+                     size=created_image.size,
+                     properties=created_image.properties)
+    finally:
+        if image_file:
+            image_file.close()
 
 
 def __create_image_v2(glance, image_settings):
index ab434f1..b148bc5 100644 (file)
@@ -232,12 +232,18 @@ def save_keys_to_files(keys=None, pub_file_path=None, priv_file_path=None):
 
             if not os.path.isdir(pub_dir):
                 os.mkdir(pub_dir)
-            public_handle = open(pub_expand_file, 'wb')
-            public_bytes = keys.public_key().public_bytes(
-                serialization.Encoding.OpenSSH,
-                serialization.PublicFormat.OpenSSH)
-            public_handle.write(public_bytes)
-            public_handle.close()
+
+            public_handle = None
+            try:
+                public_handle = open(pub_expand_file, 'wb')
+                public_bytes = keys.public_key().public_bytes(
+                    serialization.Encoding.OpenSSH,
+                    serialization.PublicFormat.OpenSSH)
+                public_handle.write(public_bytes)
+            finally:
+                if public_handle:
+                    public_handle.close()
+
             os.chmod(pub_expand_file, 0o400)
             logger.info("Saved public key to - " + pub_expand_file)
         if priv_file_path:
@@ -246,13 +252,19 @@ def save_keys_to_files(keys=None, pub_file_path=None, priv_file_path=None):
             priv_dir = os.path.dirname(priv_expand_file)
             if not os.path.isdir(priv_dir):
                 os.mkdir(priv_dir)
-            private_handle = open(priv_expand_file, 'wb')
-            private_handle.write(
-                keys.private_bytes(
-                    encoding=serialization.Encoding.PEM,
-                    format=serialization.PrivateFormat.TraditionalOpenSSL,
-                    encryption_algorithm=serialization.NoEncryption()))
-            private_handle.close()
+
+            private_handle = None
+            try:
+                private_handle = open(priv_expand_file, 'wb')
+                private_handle.write(
+                    keys.private_bytes(
+                        encoding=serialization.Encoding.PEM,
+                        format=serialization.PrivateFormat.TraditionalOpenSSL,
+                        encryption_algorithm=serialization.NoEncryption()))
+            finally:
+                if private_handle:
+                    private_handle.close()
+
             os.chmod(priv_expand_file, 0o400)
             logger.info("Saved private key to - " + priv_expand_file)
 
@@ -265,9 +277,14 @@ def upload_keypair_file(nova, name, file_path):
     :param file_path: the path to the public key file
     :return: the keypair object
     """
-    with open(os.path.expanduser(file_path), 'rb') as fpubkey:
-        logger.info('Saving keypair to - ' + file_path)
-        return upload_keypair(nova, name, fpubkey.read())
+    fpubkey = None
+    try:
+        with open(os.path.expanduser(file_path), 'rb') as fpubkey:
+            logger.info('Saving keypair to - ' + file_path)
+            return upload_keypair(nova, name, fpubkey.read())
+    finally:
+        if fpubkey:
+            fpubkey.close()
 
 
 def upload_keypair(nova, name, key):
index 3710309..4dba550 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2016 Cable Television Laboratories, Inc. ("CableLabs")
+# Copyright (c) 2017 Cable Television Laboratories, Inc. ("CableLabs")
 #                    and others.  All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,7 +27,8 @@ logger = logging.getLogger('playbook_runner')
 
 def main(parsed_args):
     """
-    Uses ansible_utils for applying Ansible Playbooks to machines with a private key
+    Uses ansible_utils for applying Ansible Playbooks to machines with a
+    private key
     """
     logging.basicConfig(level=logging.DEBUG)
     logger.info('Starting Playbook Runner')
@@ -35,24 +36,36 @@ def main(parsed_args):
     proxy_settings = None
     if parsed_args.http_proxy:
         tokens = re.split(':', parsed_args.http_proxy)
-        proxy_settings = ProxySettings(tokens[0], tokens[1], parsed_args.ssh_proxy_cmd)
+        proxy_settings = ProxySettings(host=tokens[0], port=tokens[1],
+                                       ssh_proxy_cmd=parsed_args.ssh_proxy_cmd)
 
     # Ensure can get an SSH client
-    ansible_utils.ssh_client(parsed_args.ip_addr, parsed_args.host_user, parsed_args.priv_key, proxy_settings)
+    ssh = ansible_utils.ssh_client(parsed_args.ip_addr, parsed_args.host_user,
+                                   parsed_args.priv_key, proxy_settings)
+    if ssh:
+        ssh.close()
 
-    retval = ansible_utils.apply_playbook(parsed_args.playbook, [parsed_args.ip_addr], parsed_args.host_user,
-                                          parsed_args.priv_key, variables={'name': 'Foo'}, proxy_setting=proxy_settings)
+    retval = ansible_utils.apply_playbook(
+        parsed_args.playbook, [parsed_args.ip_addr], parsed_args.host_user,
+        parsed_args.priv_key, variables={'name': 'Foo'},
+        proxy_setting=proxy_settings)
     exit(retval)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-a', '--ip-addr', dest='ip_addr', required=True, help='The Host IP Address')
-    parser.add_argument('-k', '--priv-key', dest='priv_key', required=True, help='The location of the private key file')
-    parser.add_argument('-u', '--host-user', dest='host_user', required=True, help='Host user account')
-    parser.add_argument('-b', '--playbook', dest='playbook', required=True, help='Playbook Location')
-    parser.add_argument('-p', '--http-proxy', dest='http_proxy', required=False, help='<host>:<port>')
-    parser.add_argument('-s', '--ssh-proxy-cmd', dest='ssh_proxy_cmd', required=False)
+    parser.add_argument('-a', '--ip-addr', dest='ip_addr', required=True,
+                        help='The Host IP Address')
+    parser.add_argument('-k', '--priv-key', dest='priv_key', required=True,
+                        help='The location of the private key file')
+    parser.add_argument('-u', '--host-user', dest='host_user', required=True,
+                        help='Host user account')
+    parser.add_argument('-b', '--playbook', dest='playbook', required=True,
+                        help='Playbook Location')
+    parser.add_argument('-p', '--http-proxy', dest='http_proxy',
+                        required=False, help='<host>:<port>')
+    parser.add_argument('-s', '--ssh-proxy-cmd', dest='ssh_proxy_cmd',
+                        required=False)
     args = parser.parse_args()
 
     main(args)
index 203ba33..da056b2 100644 (file)
@@ -239,9 +239,14 @@ class AnsibleProvisioningTests(OSIntegrationTestCase):
 
         ssh_client = self.inst_creator.ssh_client()
         self.assertIsNotNone(ssh_client)
-        out = ssh_client.exec_command('pwd')[1].channel.in_buffer.read(1024)
-        self.assertIsNotNone(out)
-        self.assertGreater(len(out), 1)
+
+        try:
+            out = ssh_client.exec_command('pwd')[1].channel.in_buffer.read(
+                1024)
+            self.assertIsNotNone(out)
+            self.assertGreater(len(out), 1)
+        finally:
+            ssh_client.close()
 
         # Need to use the first floating IP as subsequent ones are currently
         # broken with Apex CO
@@ -257,14 +262,25 @@ class AnsibleProvisioningTests(OSIntegrationTestCase):
         ssh = ansible_utils.ssh_client(ip, user, priv_key,
                                        self.os_creds.proxy_settings)
         self.assertIsNotNone(ssh)
-        scp = SCPClient(ssh.get_transport())
-        scp.get('~/hello.txt', self.test_file_local_path)
+
+        try:
+            scp = SCPClient(ssh.get_transport())
+            scp.get('~/hello.txt', self.test_file_local_path)
+        finally:
+            scp.close()
+            ssh.close()
 
         self.assertTrue(os.path.isfile(self.test_file_local_path))
 
-        with open(self.test_file_local_path) as f:
-            file_contents = f.readline()
-            self.assertEqual('Hello World!', file_contents)
+        test_file = None
+
+        try:
+            with open(self.test_file_local_path) as test_file:
+                file_contents = test_file.readline()
+                self.assertEqual('Hello World!', file_contents)
+        finally:
+            if test_file:
+                test_file.close()
 
     def test_apply_template_playbook(self):
         """
@@ -310,11 +326,21 @@ class AnsibleProvisioningTests(OSIntegrationTestCase):
         ssh = ansible_utils.ssh_client(ip, user, priv_key,
                                        self.os_creds.proxy_settings)
         self.assertIsNotNone(ssh)
-        scp = SCPClient(ssh.get_transport())
-        scp.get('/tmp/hello.txt', self.test_file_local_path)
+
+        try:
+            scp = SCPClient(ssh.get_transport())
+            scp.get('/tmp/hello.txt', self.test_file_local_path)
+        finally:
+            scp.close()
+            ssh.close()
 
         self.assertTrue(os.path.isfile(self.test_file_local_path))
 
-        with open(self.test_file_local_path) as f:
-            file_contents = f.readline()
-            self.assertEqual('Hello Foo!', file_contents)
+        test_file = None
+        try:
+            with open(self.test_file_local_path) as test_file:
+                file_contents = test_file.readline()
+                self.assertEqual('Hello Foo!', file_contents)
+        finally:
+            if test_file:
+                test_file.close()
index f3a622a..ef8b4ae 100644 (file)
@@ -37,10 +37,14 @@ class FileUtilsTests(unittest.TestCase):
             os.makedirs(self.test_dir)
 
         self.tmpFile = self.test_dir + '/bar.txt'
+        self.tmp_file_opened = None
         if not os.path.exists(self.tmpFile):
-            open(self.tmpFile, 'wb')
+            self.tmp_file_opened = open(self.tmpFile, 'wb')
 
     def tearDown(self):
+        if self.tmp_file_opened:
+            self.tmp_file_opened.close()
+
         if os.path.exists(self.test_dir) and os.path.isdir(self.test_dir):
             shutil.rmtree(self.tmp_dir)