Merge "Create API to update openrc"
[functest.git] / functest / api / server.py
index e246333..1d47b0d 100644 (file)
@@ -12,6 +12,7 @@ Used to launch Functest RestApi
 
 """
 
+import inspect
 import logging
 import socket
 from urlparse import urljoin
@@ -21,12 +22,28 @@ from flask import Flask
 from flask_restful import Api
 
 from functest.api.base import ApiResource
-from functest.api.urls import URLPATTERNS
 from functest.api.common import api_utils
+from functest.api.database.db import BASE
+from functest.api.database.db import DB_SESSION
+from functest.api.database.db import ENGINE
+from functest.api.database.v1 import models
+from functest.api.urls import URLPATTERNS
 
 
 LOGGER = logging.getLogger(__name__)
 
+APP = Flask(__name__)
+API = Api(APP)
+
+
+@APP.teardown_request
+def shutdown_session(exception=None):  # pylint: disable=unused-argument
+    """
+    To be called at the end of each request whether it is successful
+    or an exception is raised
+    """
+    DB_SESSION.remove()
+
 
 def get_resource(resource_name):
     """ Obtain the required resource according to resource name """
@@ -41,7 +58,7 @@ def get_endpoint(url):
     return urljoin('http://{}:5000'.format(address), url)
 
 
-def api_add_resource(api):
+def api_add_resource():
     """
     The resource has multiple URLs and you can pass multiple URLs to the
     add_resource() method on the Api object. Each one will be routed to
@@ -49,19 +66,38 @@ def api_add_resource(api):
     """
     for url_pattern in URLPATTERNS:
         try:
-            api.add_resource(
+            API.add_resource(
                 get_resource(url_pattern.target), url_pattern.url,
                 endpoint=get_endpoint(url_pattern.url))
         except StopIteration:
             LOGGER.error('url resource not found: %s', url_pattern.url)
 
 
+def init_db():
+    """
+    Import all modules here that might define models so that
+    they will be registered properly on the metadata, and then
+    create a database
+    """
+    def func(subcls):
+        """ To check the subclasses of BASE"""
+        try:
+            if issubclass(subcls[1], BASE):
+                return True
+        except TypeError:
+            pass
+        return False
+    # pylint: disable=bad-builtin
+    subclses = filter(func, inspect.getmembers(models, inspect.isclass))
+    LOGGER.debug('Import models: %s', [subcls[1] for subcls in subclses])
+    BASE.metadata.create_all(bind=ENGINE)
+
+
 def main():
     """Entry point"""
     logging.config.fileConfig(pkg_resources.resource_filename(
         'functest', 'ci/logging.ini'))
     LOGGER.info('Starting Functest server')
-    app = Flask(__name__)
-    api = Api(app)
-    api_add_resource(api)
-    app.run(host='0.0.0.0')
+    api_add_resource()
+    init_db()
+    APP.run(host='0.0.0.0')