Added a check for cloud-init completion before final ssh client check.
[snaps.git] / snaps / file_utils.py
index ff2f1b3..284ae15 100644 (file)
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import ssl
+
 import os
 import logging
+
+from cryptography.hazmat.primitives import serialization
+
 try:
     import urllib.request as urllib
 except ImportError:
@@ -36,10 +41,11 @@ def file_exists(file_path):
     the path is a directory
     :return:
     """
-    if os.path.exists(file_path):
-        if os.path.isdir(file_path):
+    expanded_path = os.path.expanduser(file_path)
+    if os.path.exists(expanded_path):
+        if os.path.isdir(expanded_path):
             return False
-        return os.path.isfile(file_path)
+        return os.path.isfile(expanded_path)
     return False
 
 
@@ -65,7 +71,8 @@ def download(url, dest_path, name=None):
             raise
     try:
         with open(dest, 'wb') as download_file:
-            logger.debug('Saving file to - ' + os.path.abspath(download_file.name))
+            logger.debug('Saving file to - %s',
+                         os.path.abspath(download_file.name))
             response = __get_url_response(url)
             download_file.write(response.read())
         return download_file
@@ -74,6 +81,76 @@ def download(url, dest_path, name=None):
             download_file.close()
 
 
+def save_keys_to_files(keys=None, pub_file_path=None, priv_file_path=None):
+    """
+    Saves the generated RSA generated keys to the filesystem
+    :param keys: the keys to save generated by cryptography
+    :param pub_file_path: the path to the public keys
+    :param priv_file_path: the path to the private keys
+    """
+    if keys:
+        if pub_file_path:
+            # To support '~'
+            pub_expand_file = os.path.expanduser(pub_file_path)
+            pub_dir = os.path.dirname(pub_expand_file)
+
+            if not os.path.isdir(pub_dir):
+                os.mkdir(pub_dir)
+
+            public_handle = None
+            try:
+                public_handle = open(pub_expand_file, 'wb')
+                public_bytes = keys.public_key().public_bytes(
+                    serialization.Encoding.OpenSSH,
+                    serialization.PublicFormat.OpenSSH)
+                public_handle.write(public_bytes)
+            finally:
+                if public_handle:
+                    public_handle.close()
+
+            os.chmod(pub_expand_file, 0o400)
+            logger.info("Saved public key to - " + pub_expand_file)
+        if priv_file_path:
+            # To support '~'
+            priv_expand_file = os.path.expanduser(priv_file_path)
+            priv_dir = os.path.dirname(priv_expand_file)
+            if not os.path.isdir(priv_dir):
+                os.mkdir(priv_dir)
+
+            private_handle = None
+            try:
+                private_handle = open(priv_expand_file, 'wb')
+                private_handle.write(
+                    keys.private_bytes(
+                        encoding=serialization.Encoding.PEM,
+                        format=serialization.PrivateFormat.TraditionalOpenSSL,
+                        encryption_algorithm=serialization.NoEncryption()))
+            finally:
+                if private_handle:
+                    private_handle.close()
+
+            os.chmod(priv_expand_file, 0o400)
+            logger.info("Saved private key to - " + priv_expand_file)
+
+
+def save_string_to_file(string, file_path, mode=None):
+    """
+    Stores
+    :param string: the string contents to store
+    :param file_path: the file path to create
+    :param mode: the file's mode
+    :return: the file object
+    """
+    save_file = open(file_path, 'w')
+    try:
+        save_file.write(string)
+        if mode:
+            os.chmod(file_path, mode)
+        return save_file
+    finally:
+        save_file.close()
+
+
 def get_content_length(url):
     """
     Returns the number of bytes to be downloaded from the given URL
@@ -93,7 +170,8 @@ def __get_url_response(url):
     proxy_handler = urllib.ProxyHandler({})
     opener = urllib.build_opener(proxy_handler)
     urllib.install_opener(opener)
-    return urllib.urlopen(url)
+    context = ssl._create_unverified_context()
+    return urllib.urlopen(url, context=context)
 
 
 def read_yaml(config_file_path):
@@ -105,7 +183,7 @@ def read_yaml(config_file_path):
     logger.debug('Attempting to load configuration file - ' + config_file_path)
     config_file = None
     try:
-        with open(config_file_path) as config_file:
+        with open(config_file_path, 'r') as config_file:
             config = yaml.safe_load(config_file)
             logger.info('Loaded configuration')
         return config
@@ -115,6 +193,20 @@ def read_yaml(config_file_path):
             config_file.close()
 
 
+def persist_dict_to_yaml(the_dict, file_name):
+    """
+    Creates a YAML file from a dict
+    :param the_dict: the dictionary to store
+    :param conf_dir: the directory used to store the config file
+    :return: the file object
+    """
+    logger.info('Persisting %s to [%s]', the_dict, file_name)
+    file_path = os.path.expanduser(file_name)
+    yaml_from_dict = yaml.dump(
+        the_dict, default_flow_style=False, default_style='')
+    return save_string_to_file(yaml_from_dict, file_path)
+
+
 def read_os_env_file(os_env_filename):
     """
     Reads the OS environment source file and returns a map of each key/value