From f70073d12ee9a1e06769226b04f9516ea604d781 Mon Sep 17 00:00:00 2001 From: Abdullah Darwish Date: Thu, 1 Sep 2022 20:19:46 +0200 Subject: [PATCH] More mongodb utilities (#124) * delete_many(), count() and support params for insert_many() * batch_upsert() * super().publish('batch_upsert', items, **kwargs) * super().publish('delete_many', data, **kwargs) * fix lint issues * batch_delete * fixed tests after merging * removed delete_many Co-authored-by: Paul Cruse III --- syngenta_digital_dta/mongo/adapter.py | 34 +++++++++-- tests/syngenta_digital_dta/mongo/mock_data.py | 31 +++++----- .../syngenta_digital_dta/mongo/test_mongo.py | 57 ++++++++++++++++--- 3 files changed, 95 insertions(+), 27 deletions(-) diff --git a/syngenta_digital_dta/mongo/adapter.py b/syngenta_digital_dta/mongo/adapter.py index c989c11..50c4240 100644 --- a/syngenta_digital_dta/mongo/adapter.py +++ b/syngenta_digital_dta/mongo/adapter.py @@ -1,6 +1,6 @@ from functools import lru_cache -from pymongo import MongoClient +from pymongo import MongoClient, operations from syngenta_digital_dta.common.base_adapter import BaseAdapter from syngenta_digital_dta.common import dict_merger @@ -40,16 +40,29 @@ def create(self, **kwargs): super().publish('create', data, **kwargs) return data - def batch_create(self, **kwargs): + def __map_documents(self, **kwargs): items = [] for item in kwargs['data']: item = schema_mapper.map_to_schema(item, self.__model_schema_file, self.__model_schema) item['_id'] = item[self.__model_identifier] items.append(item) - self.__collection.insert_many(items) - super().publish('batch_create', items, **kwargs) return items + def batch_create(self, **kwargs): + items = self.__map_documents(**kwargs) + insert_result = self.__collection.insert_many(items, **kwargs.get('params', {})) + super().publish('batch_create', items, **kwargs) + return insert_result + + def batch_upsert(self, **kwargs): + items = self.__map_documents(**kwargs) + + bulk_operations = [ + operations.ReplaceOne(filter={'_id': item['_id']}, replacement=item, upsert=True) for item in items + ] + super().publish('batch_upsert', items, **kwargs) + return self.__collection.bulk_write(bulk_operations, **kwargs.get('params', {})) + def read(self, **kwargs): if kwargs.get('operation') == 'query': return self.find(**kwargs) @@ -68,6 +81,9 @@ def find(self, **kwargs): results = self.__collection.find(kwargs['query'], **kwargs.get('params', {})) return list(results) + def count(self, **kwargs): + return self.__collection.count_documents(kwargs.get('query', {}), **kwargs.get('params', {})) + def update(self, **kwargs): original_data = self.find_one(**kwargs) if not original_data: @@ -95,3 +111,13 @@ def delete(self, **kwargs): result = self.__collection.delete_one(kwargs['query']) super().publish('delete', data, **kwargs) return result + + def batch_delete(self, **kwargs): + items = self.__map_documents(**kwargs) + bulk_operations = [] + for item in items: + bulk_operations.append(operations.DeleteOne(filter={'_id': item['_id']})) + + results = self.__collection.bulk_write(bulk_operations, **kwargs.get('params', {})) + super().publish('batch_delete', items, **kwargs) + return results diff --git a/tests/syngenta_digital_dta/mongo/mock_data.py b/tests/syngenta_digital_dta/mongo/mock_data.py index 7c110d6..21d8adb 100644 --- a/tests/syngenta_digital_dta/mongo/mock_data.py +++ b/tests/syngenta_digital_dta/mongo/mock_data.py @@ -21,22 +21,23 @@ def get_standard(): def get_items(): - return [{ - 'test_id': str(uuid.uuid4()), - 'test_query_id': str(uuid.uuid4()), - 'object_key': { - 'string_key': 'nothing' + return [ + { + 'test_id': str(uuid.uuid4()), + 'test_query_id': str(uuid.uuid4()), + 'object_key': { + 'string_key': 'nothing' + }, + 'array_number': [1, 2, 3], + 'array_objects': [ + { + 'array_string_key': 'a', + 'array_number_key': 1 + } + ], + 'created': '2020-10-05', + 'modified': '2020-10-05' }, - 'array_number': [1, 2, 3], - 'array_objects': [ - { - 'array_string_key': 'a', - 'array_number_key': 1 - } - ], - 'created': '2020-10-05', - 'modified': '2020-10-05' - }, { 'test_id': str(uuid.uuid4()), 'test_query_id': str(uuid.uuid4()), diff --git a/tests/syngenta_digital_dta/mongo/test_mongo.py b/tests/syngenta_digital_dta/mongo/test_mongo.py index bf70501..9cf9aeb 100644 --- a/tests/syngenta_digital_dta/mongo/test_mongo.py +++ b/tests/syngenta_digital_dta/mongo/test_mongo.py @@ -37,13 +37,30 @@ def test_create_succeed(self): def test_batch_create_succeed(self): data = mock_data.get_items() - result = self.adapter.batch_create(data=data) - for item in result: - item.pop('_id') - self.assertListEqual(result, data) - for item in result: + insert_result = self.adapter.batch_create(data=data) + + for item in data: self.adapter.delete(query={'test_id': item['test_id']}) + self.assertEqual(len(insert_result.inserted_ids), len(data)) + + def test_batch_upsert_succeed(self): + data = mock_data.get_items() + + insert_result = self.adapter.batch_create(data=data) + for item in data: + item['test_query_id'] = 'update_query_id' + + batch_upsert_result = self.adapter.batch_upsert(data=data) + + results = self.adapter.find(query={'test_query_id': 'update_query_id'}) + for item in data: + self.adapter.delete(query={'test_id': item['test_id']}) + + affected_documents_count = batch_upsert_result.inserted_count + batch_upsert_result.modified_count + batch_upsert_result.upserted_count + + self.assertTrue(len(results) == len(data) and affected_documents_count == len(data)) + def test_create_fail_non_unique(self): data = mock_data.get_standard() data['test_id'] = 'fail-non-unique' @@ -98,11 +115,35 @@ def test_read_many_pagination(self): data['test_query_id'] = 'some-query' self.adapter.create(data=data) count += 1 - results = self.adapter.read(query={'test_query_id': 'some-query'}, operation='query', params={'skip': 5, 'limit': 5}) - for result in results: - self.adapter.delete(query={'test_id': result['test_id']}) # clean up + results = self.adapter.read(query={'test_query_id': 'some-query'}, operation='query', + params={'skip': 5, 'limit': 5}) + result = self.adapter._MongoAdapter__collection.delete_many({'test_query_id': 'some-query'}) # clean up self.assertEqual(len(results), 5) + def test_count(self): + data = mock_data.get_items() + count_before = self.adapter.count() + self.adapter.batch_create(data=data) + count_after = self.adapter.count() + + # cleanup + for item in data: + self.adapter.delete(query={'test_id': item['test_id']}) + + self.assertEqual(count_after - count_before, 2) + + def test_count_query(self): + data = mock_data.get_items() + self.adapter.batch_create(data=data) + test_id = data[0]['test_id'] + count = self.adapter.count(query={'test_id': test_id}) + + # cleanup + for item in data: + self.adapter.delete(query={'test_id': item['test_id']}) + + self.assertEqual(count, 1) + def test_update_success(self): data = mock_data.get_standard() self.adapter.create(data=data)