From 6c608e7f54c2efb2ac31ad78a7ce92b6ce45bc15 Mon Sep 17 00:00:00 2001 From: CMHopeSunshine <277073121@qq.com> Date: Sat, 16 Sep 2023 15:04:35 +0800 Subject: [PATCH] :sparkles: `xx_or_create` add `fetch_related` arg --- cherry/models/models.py | 13 +++++++++++-- cherry/queryset/queryset.py | 4 ++-- tests/test_query.py | 5 +++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/cherry/models/models.py b/cherry/models/models.py index 3d3f0b4..774c118 100644 --- a/cherry/models/models.py +++ b/cherry/models/models.py @@ -10,6 +10,7 @@ Tuple, Type, TYPE_CHECKING, + Union, ) from typing_extensions import dataclass_transform, Self @@ -27,7 +28,6 @@ ) from cherry.fields.types import get_sqlalchemy_type_from_field from cherry.fields.utils import ( - args_and_kwargs_to_clause_list, classproperty, ) from cherry.meta.meta import init_meta_config, MetaConfig, mix_meta_config @@ -553,10 +553,15 @@ async def get_or_create( cls, *args: Any, defaults: Optional[DictStrAny] = None, + fetch_related: Union[bool, Tuple[Any, ...]] = False, **kwargs: Any, ) -> Tuple[Self, bool]: """select one model with filter condition, if not exist, create one""" queryset = cls.filter(*args, **kwargs) + if fetch_related is True or isinstance(fetch_related, tuple): + queryset = queryset.prefetch_related( + () if fetch_related is True else fetch_related, + ) try: return await queryset.get(), True except NoMatchDataError: @@ -574,12 +579,16 @@ async def update_or_create( cls, *args: Any, defaults: Optional[DictStrAny] = None, + fetch_related: Union[bool, Tuple[Any, ...]] = False, **kwargs: Any, ) -> Tuple[Self, bool]: """update one model with filter condition, if not exist, create one with filter and defaults values""" - clause_list = args_and_kwargs_to_clause_list(cls, args, kwargs) queryset = cls.filter(*args, **kwargs) + if fetch_related is True or isinstance(fetch_related, tuple): + queryset = queryset.prefetch_related( + () if fetch_related is True else fetch_related, + ) try: model = await queryset.get() return await model.update(**(defaults or {})), True diff --git a/cherry/queryset/queryset.py b/cherry/queryset/queryset.py index 0b66554..7fe39dc 100644 --- a/cherry/queryset/queryset.py +++ b/cherry/queryset/queryset.py @@ -345,7 +345,7 @@ async def _fetch_one_related(self, conn: AsyncConnection, now_data: Dict[str, An rfield.related_model.__meta__.table.select().where( getattr( rfield.related_model, - rfield.related_field_name, + rfield.related_field.foreign_key_self_name, ) == now_data[target_field.foreign_key], ), @@ -405,7 +405,7 @@ async def _fetch_many_related( rfield.related_model.__meta__.table.select().where( getattr( rfield.related_model, - rfield.related_field_name, + rfield.related_field.foreign_key_self_name, ).in_(related_values), ), ) diff --git a/tests/test_query.py b/tests/test_query.py index 50dd126..302dcea 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -73,3 +73,8 @@ async def test_query_with_one_to_many(): assert len(school1.students) == 0 await school2.fetch_related(School.students) assert len(school2.students) == 1 + + schools = await School.select_related().all() + assert len(schools) == 2 + assert len(schools[0].students) == 0 + assert len(schools[1].students) == 1