From 27cca3008add749cde118587e5bddb0d13b5fbc6 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Wed, 4 Dec 2024 06:53:06 +0000 Subject: [PATCH] :bug: fix reject target error --- nonebot/internal/matcher/matcher.py | 18 +++++++++++---- tests/test_param.py | 35 ++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/nonebot/internal/matcher/matcher.py b/nonebot/internal/matcher/matcher.py index f539679a86a6..164b7b89edcc 100644 --- a/nonebot/internal/matcher/matcher.py +++ b/nonebot/internal/matcher/matcher.py @@ -656,8 +656,13 @@ async def reject_arg( 请参考对应 adapter 的 bot 对象 api """ matcher = current_matcher.get() - matcher.set_target(ARG_KEY.format(key=key)) - await cls.reject(prompt, **kwargs) + arg_key = ARG_KEY.format(key=key) + matcher.set_target(arg_key) + + if prompt is not None: + result = await cls.send(prompt, **kwargs) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result + raise RejectedException @classmethod async def reject_receive( @@ -676,8 +681,13 @@ async def reject_receive( 请参考对应 adapter 的 bot 对象 api """ matcher = current_matcher.get() - matcher.set_target(RECEIVE_KEY.format(id=id)) - await cls.reject(prompt, **kwargs) + receive_key = RECEIVE_KEY.format(id=id) + matcher.set_target(receive_key) + + if prompt is not None: + result = await cls.send(prompt, **kwargs) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result + raise RejectedException @classmethod def skip(cls) -> NoReturn: diff --git a/tests/test_param.py b/tests/test_param.py index 5583f8e7c66d..1e1bab62be31 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -1,3 +1,4 @@ +from contextlib import suppress import re from exceptiongroup import BaseExceptionGroup @@ -13,18 +14,16 @@ ENDSWITH_KEY, FULLMATCH_KEY, KEYWORD_KEY, - PAUSE_PROMPT_RESULT_KEY, PREFIX_KEY, RAW_CMD_KEY, RECEIVE_KEY, REGEX_MATCHED, - REJECT_PROMPT_RESULT_KEY, SHELL_ARGS, SHELL_ARGV, STARTSWITH_KEY, ) from nonebot.dependencies import Dependent -from nonebot.exception import TypeMisMatch +from nonebot.exception import PausedException, RejectedException, TypeMisMatch from nonebot.matcher import Matcher from nonebot.params import ( ArgParam, @@ -544,9 +543,14 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(event_next) - fake_matcher.state[ - REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id="test")) - ] = True + fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with fake_matcher.ensure_context(bot, event): + with suppress(RejectedException): + await fake_matcher.reject("test") async with app.test_dependent( receive_prompt_result, allow_types=[MatcherParam, DependParam] @@ -554,7 +558,13 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(True) - fake_matcher.state[PAUSE_PROMPT_RESULT_KEY] = True + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with fake_matcher.ensure_context(bot, event): + fake_matcher.set_target("test") + with suppress(PausedException): + await fake_matcher.pause("test") async with app.test_dependent( pause_prompt_result, allow_types=[MatcherParam, DependParam] @@ -578,9 +588,9 @@ async def test_arg(app: App): ) matcher = Matcher() + event = make_fake_event()() message = FakeMessage("text") matcher.set_arg("key", message) - matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key="key"))] = True async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) @@ -608,6 +618,15 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + matcher.set_target(ARG_KEY.format(key="key"), cache=False) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with matcher.ensure_context(bot, event): + with suppress(RejectedException): + await matcher.reject("test") + async with app.test_dependent( annotated_arg_prompt_result, allow_types=[ArgParam] ) as ctx: