Skip to content

Commit

Permalink
🐛 Fix: 修复结构化并发子依赖取消缓存问题 (#3084)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Oct 29, 2024
1 parent be732cf commit e3cb4c7
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 4 deletions.
12 changes: 10 additions & 2 deletions nonebot/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
from nonebot.utils import (
run_sync,
run_coro_with_shield,
is_coroutine_callable,
flatten_exception_group,
)

from .utils import check_field_type, get_typed_signature

Expand Down Expand Up @@ -207,7 +212,10 @@ async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:

async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)
# shield the task to prevent cancellation
# when one of the tasks raises an exception
# this will improve the dependency cache reusability
tg.start_soon(run_coro_with_shield, _solve_field(field, params))

return result

Expand Down
14 changes: 12 additions & 2 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def __init__(self):
self._exception: Optional[BaseException] = None
self._waiter = anyio.Event()

def done(self) -> bool:
return self._state == CacheState.FINISHED

def result(self) -> Any:
"""获取子依赖结果"""

Expand Down Expand Up @@ -304,11 +307,18 @@ def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
dependency_cache[call] = cache = DependencyCache()
try:
result = await target
cache.set_result(result)
return result
except Exception as e:
cache.set_exception(e)
raise
except BaseException as e:
cache.set_exception(e)
# remove cache when base exception occurs
# e.g. CancelledError
dependency_cache.pop(call, None)
raise
else:
cache.set_result(result)
return result

@override
async def _check(self, **kwargs: Any) -> None:
Expand Down
24 changes: 24 additions & 0 deletions tests/plugins/param/param_depend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated
from dataclasses import dataclass

import anyio
from pydantic import Field

from nonebot import on_message
Expand Down Expand Up @@ -105,3 +106,26 @@ async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))):

async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
return x


async def _dep():
await anyio.sleep(1)
return 1


def _dep_mismatch():
return 1


async def cache_exception_func1(
dep: int = Depends(_dep),
mismatch: dict = Depends(_dep_mismatch),
):
raise RuntimeError("Never reach here")


async def cache_exception_func2(
dep: int = Depends(_dep),
match: int = Depends(_dep_mismatch),
):
return dep
22 changes: 22 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ async def test_depend(app: App):
annotated_depend,
sub_type_mismatch,
validate_field_fail,
cache_exception_func1,
cache_exception_func2,
annotated_class_depend,
annotated_multi_depend,
annotated_prior_depend,
Expand Down Expand Up @@ -130,6 +132,26 @@ async def test_depend(app: App):
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)

# test cache reuse when exception raised
dependency_cache = {}
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(
cache_exception_func1, allow_types=[DependParam]
) as ctx:
ctx.pass_params(dependency_cache=dependency_cache)

if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)

# dependency solve tasks should be shielded even if one of them raises an exception
assert len(dependency_cache) == 2

async with app.test_dependent(
cache_exception_func2, allow_types=[DependParam]
) as ctx:
ctx.pass_params(dependency_cache=dependency_cache)
ctx.should_return(1)


@pytest.mark.anyio
async def test_bot(app: App):
Expand Down

0 comments on commit e3cb4c7

Please sign in to comment.