Added a check for cloud-init completion before final ssh client check.
[snaps.git] / snaps / file_utils.py
index 34eb30c..284ae15 100644 (file)
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import ssl
+
 import os
-import urllib2
 import logging
 
+from cryptography.hazmat.primitives import serialization
+
+try:
+    import urllib.request as urllib
+except ImportError:
+    import urllib2 as urllib
+
 import yaml
 
 __author__ = 'spisarski'
@@ -29,50 +37,118 @@ logger = logging.getLogger('file_utils')
 
 def file_exists(file_path):
     """
-    Returns True if the image file already exists and throws an exception if the path is a directory
+    Returns True if the image file already exists and throws an exception if
+    the path is a directory
     :return:
     """
-    if os.path.exists(file_path):
-        if os.path.isdir(file_path):
+    expanded_path = os.path.expanduser(file_path)
+    if os.path.exists(expanded_path):
+        if os.path.isdir(expanded_path):
             return False
-        return os.path.isfile(file_path)
+        return os.path.isfile(expanded_path)
     return False
 
 
-def get_file(file_path):
-    """
-    Returns True if the image file has already been downloaded
-    :return: the image file object
-    :raise Exception when file cannot be found
-    """
-    if file_exists(file_path):
-        return open(file_path, 'r')
-    else:
-        raise Exception('File with path cannot be found - ' + file_path)
-
-
 def download(url, dest_path, name=None):
     """
     Download a file to a destination path given a URL
+    :param url: the endpoint to the file to download
+    :param dest_path: the directory to save the file
+    :param name: the file name (optional)
     :rtype : File object
     """
     if not name:
         name = url.rsplit('/')[-1]
     dest = dest_path + '/' + name
+    logger.debug('Downloading file from - ' + url)
+    # Override proxy settings to use localhost to download file
+    download_file = None
+
+    if not os.path.isdir(dest_path):
+        try:
+            os.mkdir(dest_path)
+        except:
+            raise
     try:
-        logger.debug('Downloading file from - ' + url)
-        # Override proxy settings to use localhost to download file
-        proxy_handler = urllib2.ProxyHandler({})
-        opener = urllib2.build_opener(proxy_handler)
-        urllib2.install_opener(opener)
-        response = urllib2.urlopen(url)
-    except (urllib2.HTTPError, urllib2.URLError):
-        raise Exception
-
-    with open(dest, 'wb') as f:
-        logger.debug('Saving file to - ' + dest)
-        f.write(response.read())
-    return f
+        with open(dest, 'wb') as download_file:
+            logger.debug('Saving file to - %s',
+                         os.path.abspath(download_file.name))
+            response = __get_url_response(url)
+            download_file.write(response.read())
+        return download_file
+    finally:
+        if download_file:
+            download_file.close()
+
+
+def save_keys_to_files(keys=None, pub_file_path=None, priv_file_path=None):
+    """
+    Saves the generated RSA generated keys to the filesystem
+    :param keys: the keys to save generated by cryptography
+    :param pub_file_path: the path to the public keys
+    :param priv_file_path: the path to the private keys
+    """
+    if keys:
+        if pub_file_path:
+            # To support '~'
+            pub_expand_file = os.path.expanduser(pub_file_path)
+            pub_dir = os.path.dirname(pub_expand_file)
+
+            if not os.path.isdir(pub_dir):
+                os.mkdir(pub_dir)
+
+            public_handle = None
+            try:
+                public_handle = open(pub_expand_file, 'wb')
+                public_bytes = keys.public_key().public_bytes(
+                    serialization.Encoding.OpenSSH,
+                    serialization.PublicFormat.OpenSSH)
+                public_handle.write(public_bytes)
+            finally:
+                if public_handle:
+                    public_handle.close()
+
+            os.chmod(pub_expand_file, 0o400)
+            logger.info("Saved public key to - " + pub_expand_file)
+        if priv_file_path:
+            # To support '~'
+            priv_expand_file = os.path.expanduser(priv_file_path)
+            priv_dir = os.path.dirname(priv_expand_file)
+            if not os.path.isdir(priv_dir):
+                os.mkdir(priv_dir)
+
+            private_handle = None
+            try:
+                private_handle = open(priv_expand_file, 'wb')
+                private_handle.write(
+                    keys.private_bytes(
+                        encoding=serialization.Encoding.PEM,
+                        format=serialization.PrivateFormat.TraditionalOpenSSL,
+                        encryption_algorithm=serialization.NoEncryption()))
+            finally:
+                if private_handle:
+                    private_handle.close()
+
+            os.chmod(priv_expand_file, 0o400)
+            logger.info("Saved private key to - " + priv_expand_file)
+
+
+def save_string_to_file(string, file_path, mode=None):
+    """
+    Stores
+    :param string: the string contents to store
+    :param file_path: the file path to create
+    :param mode: the file's mode
+    :return: the file object
+    """
+    save_file = open(file_path, 'w')
+    try:
+        save_file.write(string)
+        if mode:
+            os.chmod(file_path, mode)
+        return save_file
+    finally:
+        save_file.close()
 
 
 def get_content_length(url):
@@ -81,13 +157,23 @@ def get_content_length(url):
     :param url: the URL to inspect
     :return: the number of bytes
     """
-    proxy_handler = urllib2.ProxyHandler({})
-    opener = urllib2.build_opener(proxy_handler)
-    urllib2.install_opener(opener)
-    response = urllib2.urlopen(url)
+    response = __get_url_response(url)
     return response.headers['Content-Length']
 
 
+def __get_url_response(url):
+    """
+    Returns a response object for a given URL
+    :param url: the URL
+    :return: the response
+    """
+    proxy_handler = urllib.ProxyHandler({})
+    opener = urllib.build_opener(proxy_handler)
+    urllib.install_opener(opener)
+    context = ssl._create_unverified_context()
+    return urllib.urlopen(url, context=context)
+
+
 def read_yaml(config_file_path):
     """
     Reads the yaml file and returns a dictionary object representation
@@ -95,30 +181,75 @@ def read_yaml(config_file_path):
     :return: a dictionary
     """
     logger.debug('Attempting to load configuration file - ' + config_file_path)
-    with open(config_file_path) as config_file:
-        config = yaml.safe_load(config_file)
-        logger.info('Loaded configuration')
-    config_file.close()
-    logger.info('Closing configuration file')
-    return config
+    config_file = None
+    try:
+        with open(config_file_path, 'r') as config_file:
+            config = yaml.safe_load(config_file)
+            logger.info('Loaded configuration')
+        return config
+    finally:
+        if config_file:
+            logger.info('Closing configuration file')
+            config_file.close()
+
+
+def persist_dict_to_yaml(the_dict, file_name):
+    """
+    Creates a YAML file from a dict
+    :param the_dict: the dictionary to store
+    :param conf_dir: the directory used to store the config file
+    :return: the file object
+    """
+    logger.info('Persisting %s to [%s]', the_dict, file_name)
+    file_path = os.path.expanduser(file_name)
+    yaml_from_dict = yaml.dump(
+        the_dict, default_flow_style=False, default_style='')
+    return save_string_to_file(yaml_from_dict, file_path)
 
 
 def read_os_env_file(os_env_filename):
     """
     Reads the OS environment source file and returns a map of each key/value
-    Will ignore lines beginning with a '#' and will replace any single or double quotes contained within the value
+    Will ignore lines beginning with a '#' and will replace any single or
+    double quotes contained within the value
     :param os_env_filename: The name of the OS environment file to read
     :return: a dictionary
     """
     if os_env_filename:
-        logger.info('Attempting to read OS environment file - ' + os_env_filename)
+        logger.info('Attempting to read OS environment file - %s',
+                    os_env_filename)
         out = {}
-        for line in open(os_env_filename):
-            line = line.lstrip()
-            if not line.startswith('#') and line.startswith('export '):
-                line = line.lstrip('export ').strip()
-                tokens = line.split('=')
-                if len(tokens) > 1:
-                    # Remove leading and trailing ' & " characters from value
-                    out[tokens[0]] = tokens[1].lstrip('\'').lstrip('\"').rstrip('\'').rstrip('\"')
+        env_file = None
+        try:
+            env_file = open(os_env_filename)
+            for line in env_file:
+                line = line.lstrip()
+                if not line.startswith('#') and line.startswith('export '):
+                    line = line.lstrip('export ').strip()
+                    tokens = line.split('=')
+                    if len(tokens) > 1:
+                        # Remove leading and trailing ' & " characters from
+                        # value
+                        out[tokens[0]] = tokens[1].lstrip('\'').lstrip('\"').rstrip('\'').rstrip('\"')
+        finally:
+            if env_file:
+                env_file.close()
+        return out
+
+
+def read_file(filename):
+    """
+    Returns the contents of a file as a string
+    :param filename: the name of the file
+    :return:
+    """
+    out = str()
+    the_file = None
+    try:
+        the_file = open(filename)
+        for line in the_file:
+            out += line
         return out
+    finally:
+        if the_file:
+            the_file.close()