diff --git a/syngenta_digital_dta/mongo/adapter.py b/syngenta_digital_dta/mongo/adapter.py index 50c4240..0da3494 100644 --- a/syngenta_digital_dta/mongo/adapter.py +++ b/syngenta_digital_dta/mongo/adapter.py @@ -55,13 +55,23 @@ def batch_create(self, **kwargs): return insert_result def batch_upsert(self, **kwargs): - items = self.__map_documents(**kwargs) + data = kwargs['data'] + batch_size = kwargs.get('batch_size', 25) + + if not isinstance(data, list): + raise Exception('Batched data must be contained within a list') + + batched_data = (data[pos:pos + batch_size] for pos in range(0, len(data), batch_size)) + results = [] + for items in batched_data: + bulk_operations = [ + operations.ReplaceOne(filter={'_id': item[self.__model_identifier]}, replacement=item, upsert=True) for item in items + ] + batch_results = self.__collection.bulk_write(bulk_operations, **kwargs.get('params', {})) + results.append(batch_results) + super().publish('batch_upsert', items, **kwargs) - bulk_operations = [ - operations.ReplaceOne(filter={'_id': item['_id']}, replacement=item, upsert=True) for item in items - ] - super().publish('batch_upsert', items, **kwargs) - return self.__collection.bulk_write(bulk_operations, **kwargs.get('params', {})) + return results def read(self, **kwargs): if kwargs.get('operation') == 'query': @@ -70,7 +80,8 @@ def read(self, **kwargs): def query(self, **kwargs): if kwargs['operation'] not in self.__allowed_queries: - raise Exception('query method is for read-only operations; please use another function for destructive operations') + raise Exception( + 'query method is for read-only operations; please use another function for destructive operations') query = getattr(self.__collection, kwargs['operation']) return query(kwargs['query']) diff --git a/tests/syngenta_digital_dta/mongo/test_mongo.py b/tests/syngenta_digital_dta/mongo/test_mongo.py index 9cf9aeb..6f7c944 100644 --- a/tests/syngenta_digital_dta/mongo/test_mongo.py +++ b/tests/syngenta_digital_dta/mongo/test_mongo.py @@ -57,7 +57,7 @@ def test_batch_upsert_succeed(self): for item in data: self.adapter.delete(query={'test_id': item['test_id']}) - affected_documents_count = batch_upsert_result.inserted_count + batch_upsert_result.modified_count + batch_upsert_result.upserted_count + affected_documents_count = batch_upsert_result[0].inserted_count + batch_upsert_result[0].modified_count + batch_upsert_result[0].upserted_count self.assertTrue(len(results) == len(data) and affected_documents_count == len(data))