support keys start with '$' or contain '.' in testAPI 91/15591/5
authorSerenaFeng <feng.xiaowei@zte.com.cn>
Wed, 15 Jun 2016 01:49:35 +0000 (09:49 +0800)
committerSerenaFeng <feng.xiaowei@zte.com.cn>
Wed, 15 Jun 2016 08:05:37 +0000 (16:05 +0800)
set check_keys=False in insert and update db
fix update and insert stub in fake_pymongo.py
add unittest for check_keys in test_fake_pymongo.py

JIRA: FUNCTEST-313

Change-Id: I4051ec4a1c70996c87167643f6ea19993f5b0811
Signed-off-by: SerenaFeng <feng.xiaowei@zte.com.cn>
utils/test/result_collection_api/opnfv_testapi/resources/handlers.py
utils/test/result_collection_api/opnfv_testapi/tests/unit/fake_pymongo.py
utils/test/result_collection_api/opnfv_testapi/tests/unit/test_fake_pymongo.py
utils/test/result_collection_api/opnfv_testapi/tests/unit/test_result.py
utils/test/result_collection_api/opnfv_testapi/tests/unit/test_testcase.py

index cc4a4c3..8737011 100644 (file)
@@ -98,7 +98,8 @@ class GenericApiHandler(RequestHandler):
 
         if self.table != 'results':
             data.creation_date = datetime.now()
-        _id = yield self._eval_db(self.table, 'insert', data.format())
+        _id = yield self._eval_db(self.table, 'insert', data.format(),
+                                  check_keys=False)
         if 'name' in self.json_args:
             resource = data.name
         else:
@@ -174,7 +175,8 @@ class GenericApiHandler(RequestHandler):
         edit_request.update(self._update_requests(data))
 
         """ Updating the DB """
-        yield self._eval_db(self.table, 'update', query, edit_request)
+        yield self._eval_db(self.table, 'update', query, edit_request,
+                            check_keys=False)
         edit_request['_id'] = str(data._id)
         self.finish_request(edit_request)
 
@@ -215,8 +217,8 @@ class GenericApiHandler(RequestHandler):
             query[key] = new
         return equal, query
 
-    def _eval_db(self, table, method, *args):
-        return eval('self.db.%s.%s(*args)' % (table, method))
+    def _eval_db(self, table, method, *args, **kwargs):
+        return eval('self.db.%s.%s(*args, **kwargs)' % (table, method))
 
     def _eval_db_find_one(self, query, table=None):
         if table is None:
index ef9c719..6ab98c7 100644 (file)
@@ -80,11 +80,15 @@ class MemDb(object):
             return_one = True
             docs = [docs]
 
+        if check_keys:
+            for doc in docs:
+                self._check_keys(doc)
+
         ids = []
         for doc in docs:
             if '_id' not in doc:
                 doc['_id'] = str(ObjectId())
-            if not check_keys or not self._find_one(doc['_id']):
+            if not self._find_one(doc['_id']):
                 ids.append(doc['_id'])
                 self.contents.append(doc_or_docs)
 
@@ -131,8 +135,12 @@ class MemDb(object):
     def find(self, *args):
         return MemCursor(self._find(*args))
 
-    def _update(self, spec, document):
+    def _update(self, spec, document, check_keys=True):
         updated = False
+
+        if check_keys:
+            self._check_keys(document)
+
         for index in range(len(self.contents)):
             content = self.contents[index]
             if self._in(content, spec):
@@ -142,8 +150,8 @@ class MemDb(object):
             self.contents[index] = content
         return updated
 
-    def update(self, spec, document):
-        return thread_execute(self._update, spec, document)
+    def update(self, spec, document, check_keys=True):
+        return thread_execute(self._update, spec, document, check_keys)
 
     def _remove(self, spec_or_id=None):
         if spec_or_id is None:
@@ -163,6 +171,17 @@ class MemDb(object):
     def clear(self):
         self._remove()
 
+    def _check_keys(self, doc):
+        for key in doc.keys():
+            print('key', key, 'value', doc.get(key))
+            if '.' in key:
+                raise NameError('key {} must not contain .'.format(key))
+            if key.startswith('$'):
+                raise NameError('key {} must not start with $'.format(key))
+            if isinstance(doc.get(key), dict):
+                self._check_keys(doc.get(key))
+
+
 pods = MemDb()
 projects = MemDb()
 testcases = MemDb()
index 9bc311c..27382f0 100644 (file)
@@ -53,25 +53,70 @@ class MyTest(AsyncHTTPTestCase):
         user = yield self.db.pods.find_one({'_id': '1'})
         self.assertEqual(user.get('name', None), 'new_test1')
 
+    def test_update_dot_error(self):
+        self._update_assert({'_id': '1', 'name': {'1. name': 'test1'}},
+                            'key 1. name must not contain .')
+
+    def test_update_dot_no_error(self):
+        self._update_assert({'_id': '1', 'name': {'1. name': 'test1'}},
+                            None,
+                            check_keys=False)
+
+    def test_update_dollar_error(self):
+        self._update_assert({'_id': '1', 'name': {'$name': 'test1'}},
+                            'key $name must not start with $')
+
+    def test_update_dollar_no_error(self):
+        self._update_assert({'_id': '1', 'name': {'$name': 'test1'}},
+                            None,
+                            check_keys=False)
+
     @gen_test
     def test_remove(self):
         yield self.db.pods.remove({'_id': '1'})
         user = yield self.db.pods.find_one({'_id': '1'})
         self.assertIsNone(user)
 
-    @gen_test
-    def test_insert_check_keys(self):
-        yield self.db.pods.insert({'_id': '1', 'name': 'test1'},
-                                  check_keys=False)
-        cursor = self.db.pods.find({'_id': '1'})
-        names = []
-        while (yield cursor.fetch_next):
-            ob = cursor.next_object()
-            names.append(ob.get('name'))
-        self.assertItemsEqual(names, ['test1', 'test1'])
+    def test_insert_dot_error(self):
+        self._insert_assert({'_id': '1', '2. name': 'test1'},
+                            'key 2. name must not contain .')
+
+    def test_insert_dot_no_error(self):
+        self._insert_assert({'_id': '1', '2. name': 'test1'},
+                            None,
+                            check_keys=False)
+
+    def test_insert_dollar_error(self):
+        self._insert_assert({'_id': '1', '$name': 'test1'},
+                            'key $name must not start with $')
+
+    def test_insert_dollar_no_error(self):
+        self._insert_assert({'_id': '1', '$name': 'test1'},
+                            None,
+                            check_keys=False)
 
     def _clear(self):
         self.db.pods.clear()
 
+    def _update_assert(self, docs, error=None, **kwargs):
+        self._db_assert('update', error, {'_id': '1'}, docs, **kwargs)
+
+    def _insert_assert(self, docs, error=None, **kwargs):
+        self._db_assert('insert', error, docs, **kwargs)
+
+    @gen_test
+    def _db_assert(self, method, error, *args, **kwargs):
+        name_error = None
+        try:
+            yield self._eval_pods_db(method, *args, **kwargs)
+        except NameError as err:
+            name_error = err.args[0]
+        finally:
+            self.assertEqual(name_error, error)
+
+    def _eval_pods_db(self, method, *args, **kwargs):
+        return eval('self.db.pods.%s(*args, **kwargs)' % method)
+
+
 if __name__ == '__main__':
     unittest.main()
index dbc4431..bba3b22 100644 (file)
@@ -7,6 +7,7 @@
 # http://www.apache.org/licenses/LICENSE-2.0
 ##############################################################################
 import unittest
+import copy
 
 from opnfv_testapi.common.constants import HTTP_OK, HTTP_BAD_REQUEST, \
     HTTP_NOT_FOUND
@@ -161,6 +162,13 @@ class TestResultCreate(TestResultBase):
         self.assertEqual(code, HTTP_OK)
         self.assert_href(body)
 
+    def test_key_with_doc(self):
+        req = copy.deepcopy(self.req_d)
+        req.details = {'1.name': 'dot_name'}
+        (code, body) = self.create(req)
+        self.assertEqual(code, HTTP_OK)
+        self.assert_href(body)
+
 
 class TestResultGet(TestResultBase):
     def test_getOne(self):
index a145c00..cb76784 100644 (file)
@@ -7,6 +7,7 @@
 # http://www.apache.org/licenses/LICENSE-2.0
 ##############################################################################
 import unittest
+import copy
 
 from test_base import TestBase
 from opnfv_testapi.resources.testcase_models import TestcaseCreateRequest, \
@@ -168,6 +169,13 @@ class TestCaseUpdate(TestCaseBase):
         self.assertEqual(_id, new_body._id)
         self.assert_update_body(self.req_d, new_body, self.update_e)
 
+    def test_with_dollar(self):
+        self.create_d()
+        update = copy.deepcopy(self.update_d)
+        update.description = {'2. change': 'dollar change'}
+        code, body = self.update(update, self.req_d.name)
+        self.assertEqual(code, HTTP_OK)
+
 
 class TestCaseDelete(TestCaseBase):
     def test_notFound(self):