diff --git a/connect_extension_utils/db/models.py b/connect_extension_utils/db/models.py index 3c6a97c..07cc1a9 100644 --- a/connect_extension_utils/db/models.py +++ b/connect_extension_utils/db/models.py @@ -95,18 +95,15 @@ def add_next_with_verbose(self, instance, related_id_field): instance_class = instance.__class__ new_suffix = 0 related_id_value = getattr(instance, related_id_field) + base_qs = self.query(instance_class).filter( + instance_class.__dict__[related_id_field] == related_id_value, + ) if self.query( - self.query(instance_class) - .filter(instance_class.__dict__[related_id_field] == related_id_value) - .exists(), + base_qs.exists(), ).scalar(): - last_obj = ( - self.query(instance_class) - .order_by( - instance_class.id.desc(), - ) - .first() - ) + last_obj = base_qs.order_by( + instance_class.id.desc(), + ).first() _instance_id, suffix = last_obj.id.rsplit("-", 1) new_suffix = int(suffix) + 1 else: @@ -121,19 +118,15 @@ def add_all_with_next_verbose(self, instances, related_id_field): instance_class = first_item.__class__ new_suffix = 0 related_id_value = getattr(first_item, related_id_field) - + base_qs = self.query(instance_class).filter( + instance_class.__dict__[related_id_field] == related_id_value, + ) if self.query( - self.query(instance_class) - .filter(instance_class.__dict__[related_id_field] == related_id_value) - .exists(), + base_qs.exists(), ).scalar(): - last_obj = ( - self.query(instance_class) - .order_by( - instance_class.id.desc(), - ) - .first() - ) + last_obj = base_qs.order_by( + instance_class.id.desc(), + ).first() _instance_id, suffix = last_obj.id.rsplit("-", 1) new_suffix = int(suffix) + 1 else: diff --git a/connect_extension_utils/testing/factories.py b/connect_extension_utils/testing/factories.py index 134511d..bfe8153 100644 --- a/connect_extension_utils/testing/factories.py +++ b/connect_extension_utils/testing/factories.py @@ -38,32 +38,32 @@ class Meta: @classmethod def _save(cls, model_class, session, args, kwargs): - obj = model_class(*args, **kwargs) + save_method = None if cls._meta._is_transactional: + obj = model_class(*args, **kwargs) kwargs['id'] = cls.add_next_with_verbose( model_class, session, obj, cls._meta._related_id_field, ) - return super()._save(model_class, session, args, kwargs) + save_method = factory.alchemy.SQLAlchemyModelFactory.__dict__['_save'] + cls.save_method = save_method or super()._save + return cls.save_method(model_class, session, args, kwargs) @classmethod def add_next_with_verbose(cls, model_class, session, obj, related_id_field): new_suffix = 0 related_id_value = getattr(obj, related_id_field) + base_qs = session.query(model_class).filter( + model_class.__dict__[related_id_field] == related_id_value, + ) if session.query( - session.query(model_class) - .filter(model_class.__dict__[related_id_field] == related_id_value) - .exists(), + base_qs.exists(), ).scalar(): - last_obj = ( - session.query(model_class) - .order_by( - model_class.id.desc(), - ) - .first() - ) + last_obj = base_qs.order_by( + model_class.id.desc(), + ).first() _instance_id, suffix = last_obj.id.rsplit("-", 1) new_suffix = int(suffix) + 1 else: diff --git a/tests/db/test_models.py b/tests/db/test_models.py index 5c2aec8..7c66561 100644 --- a/tests/db/test_models.py +++ b/tests/db/test_models.py @@ -39,11 +39,10 @@ def test_add_verbose_bulk(dbsession): def test_add_with_next_verbose(dbsession): - obj = MyModel( - name='Foo', - created_by='Jony', - ) + obj = MyModel(name='Foo', created_by='Jony') + obj_2 = MyModel(name='Bar', created_by='Neri') dbsession.add_with_verbose(obj) + dbsession.add_with_verbose(obj_2) dbsession.commit() trx_obj = TransactionalModel( my_model_id=obj.id, @@ -55,10 +54,20 @@ def test_add_with_next_verbose(dbsession): ) dbsession.add_next_with_verbose(trx_obj_2, related_id_field='my_model_id') dbsession.commit() + trx_obj_3 = TransactionalModel( + my_model_id=obj_2.id, + ) + dbsession.add_next_with_verbose(trx_obj_3, related_id_field='my_model_id') + dbsession.commit() assert trx_obj.id.startswith(TransactionalModel.PREFIX) assert trx_obj.id.endswith('000') + assert obj.id.split('-', 1)[-1] in trx_obj.id assert trx_obj_2.id.startswith(TransactionalModel.PREFIX) assert trx_obj_2.id.endswith('001') + assert obj.id.split('-', 1)[-1] in trx_obj_2.id + assert trx_obj_3.id.startswith(TransactionalModel.PREFIX) + assert trx_obj_3.id.endswith('000') + assert obj_2.id.split('-', 1)[-1] in trx_obj_3.id def test_add_with_next_verbose_bulk(dbsession): @@ -66,13 +75,22 @@ def test_add_with_next_verbose_bulk(dbsession): name='Foo', created_by='Jony', ) - dbsession.add_with_verbose(m_obj) + m_obj2 = MyModel(name='Bar', created_by='Neri') + dbsession.add_all_with_verbose([m_obj, m_obj2]) dbsession.commit() + + trx_1 = TransactionalModel(my_model_id=m_obj.id) + dbsession.add_next_with_verbose( + trx_1, + related_id_field='my_model_id', + ) + dbsession.commit() + instances = [] for _ in range(3): instances.append( TransactionalModel( - my_model_id=m_obj.id, + my_model_id=m_obj2.id, ), ) dbsession.add_all_with_next_verbose(instances, related_id_field='my_model_id') @@ -80,12 +98,14 @@ def test_add_with_next_verbose_bulk(dbsession): for idx, obj in enumerate(instances): assert obj.id.startswith(TransactionalModel.PREFIX) assert obj.id.endswith(f'00{idx}') + assert m_obj2.id.split('-', 1)[-1] in obj.id - new_obj = TransactionalModel(my_model_id=m_obj.id) - dbsession.add_all_with_next_verbose([new_obj], related_id_field='my_model_id') + new_trx_obj = TransactionalModel(my_model_id=m_obj.id) + dbsession.add_all_with_next_verbose([new_trx_obj], related_id_field='my_model_id') dbsession.commit() - assert new_obj.id.startswith(TransactionalModel.PREFIX) - assert new_obj.id.endswith('003') + assert new_trx_obj.id.startswith(TransactionalModel.PREFIX) + assert new_trx_obj.id.endswith('001') + assert m_obj.id.split('-', 1)[-1] in new_trx_obj.id def test_add_with_verbose_bulk_fail_instances_not_same_class(dbsession): diff --git a/tests/testing/test_factories.py b/tests/testing/test_factories.py index 9e94b57..19e4a95 100644 --- a/tests/testing/test_factories.py +++ b/tests/testing/test_factories.py @@ -4,10 +4,12 @@ def test_model_factory(my_model_factory): assert obj.name.startswith("My Model") -def test_related_model_factory(my_model_factory, related_model_factory): +def test_related_model_factory(my_model_factory, related_model_factory, dbsession): rel_obj = related_model_factory() assert rel_obj.id.startswith(related_model_factory._meta.model.PREFIX) assert rel_obj.my_model_id.startswith(my_model_factory._meta.model.PREFIX) + assert dbsession.query(related_model_factory._meta.model).count() == 1 + assert dbsession.query(my_model_factory._meta.model).count() == 1 def test_transactional_model_factory( @@ -24,3 +26,5 @@ def test_transactional_model_factory( _, body = base.split("-", 1) assert trx_obj.my_model_id == f"{my_model_factory._meta.model.PREFIX}-{body}" assert id_suffix == f"00{suffix}" + assert dbsession.query(my_model_factory._meta.model).count() == 1 + assert dbsession.query(transactional_model_factory._meta.model).count() == 3