From e3cb4c79075c0755cccd354b40bcc22be8bd386a Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:22:41 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20Fix:=20=E4=BF=AE=E5=A4=8D=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E5=8C=96=E5=B9=B6=E5=8F=91=E5=AD=90=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E7=BC=93=E5=AD=98=E9=97=AE=E9=A2=98=20(#3084?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/dependencies/__init__.py | 12 ++++++++++-- nonebot/internal/params.py | 14 ++++++++++++-- tests/plugins/param/param_depend.py | 24 ++++++++++++++++++++++++ tests/test_param.py | 22 ++++++++++++++++++++++ 4 files changed, 68 insertions(+), 4 deletions(-) diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index 1b56089b0730..da02f5313b91 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -21,7 +21,12 @@ from nonebot.typing import _DependentCallable from nonebot.exception import SkippedException from nonebot.compat import FieldInfo, ModelField, PydanticUndefined -from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group +from nonebot.utils import ( + run_sync, + run_coro_with_shield, + is_coroutine_callable, + flatten_exception_group, +) from .utils import check_field_type, get_typed_signature @@ -207,7 +212,10 @@ async def _solve_field(field: ModelField, params: dict[str, Any]) -> None: async with anyio.create_task_group() as tg: for field in self.params: - tg.start_soon(_solve_field, field, params) + # shield the task to prevent cancellation + # when one of the tasks raises an exception + # this will improve the dependency cache reusability + tg.start_soon(run_coro_with_shield, _solve_field(field, params)) return result diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 87f7367b0433..9dbe0b4099eb 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -115,6 +115,9 @@ def __init__(self): self._exception: Optional[BaseException] = None self._waiter = anyio.Event() + def done(self) -> bool: + return self._state == CacheState.FINISHED + def result(self) -> Any: """获取子依赖结果""" @@ -304,11 +307,18 @@ def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]): dependency_cache[call] = cache = DependencyCache() try: result = await target - cache.set_result(result) - return result + except Exception as e: + cache.set_exception(e) + raise except BaseException as e: cache.set_exception(e) + # remove cache when base exception occurs + # e.g. CancelledError + dependency_cache.pop(call, None) raise + else: + cache.set_result(result) + return result @override async def _check(self, **kwargs: Any) -> None: diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py index 20c058925ef3..6f28677f5e41 100644 --- a/tests/plugins/param/param_depend.py +++ b/tests/plugins/param/param_depend.py @@ -1,6 +1,7 @@ from typing import Annotated from dataclasses import dataclass +import anyio from pydantic import Field from nonebot import on_message @@ -105,3 +106,26 @@ async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))): async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))): return x + + +async def _dep(): + await anyio.sleep(1) + return 1 + + +def _dep_mismatch(): + return 1 + + +async def cache_exception_func1( + dep: int = Depends(_dep), + mismatch: dict = Depends(_dep_mismatch), +): + raise RuntimeError("Never reach here") + + +async def cache_exception_func2( + dep: int = Depends(_dep), + match: int = Depends(_dep_mismatch), +): + return dep diff --git a/tests/test_param.py b/tests/test_param.py index eb00996a2469..cdd9420b7baf 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -51,6 +51,8 @@ async def test_depend(app: App): annotated_depend, sub_type_mismatch, validate_field_fail, + cache_exception_func1, + cache_exception_func2, annotated_class_depend, annotated_multi_depend, annotated_prior_depend, @@ -130,6 +132,26 @@ async def test_depend(app: App): if isinstance(exc_info.value, BaseExceptionGroup): assert exc_info.group_contains(TypeMisMatch) + # test cache reuse when exception raised + dependency_cache = {} + with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info: + async with app.test_dependent( + cache_exception_func1, allow_types=[DependParam] + ) as ctx: + ctx.pass_params(dependency_cache=dependency_cache) + + if isinstance(exc_info.value, BaseExceptionGroup): + assert exc_info.group_contains(TypeMisMatch) + + # dependency solve tasks should be shielded even if one of them raises an exception + assert len(dependency_cache) == 2 + + async with app.test_dependent( + cache_exception_func2, allow_types=[DependParam] + ) as ctx: + ctx.pass_params(dependency_cache=dependency_cache) + ctx.should_return(1) + @pytest.mark.anyio async def test_bot(app: App):