From 227729c7e0f0ffc17bd596c2d6d20e68c7cb34c2 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:19:15 +0000 Subject: [PATCH 1/3] :sparkles: store prompt result --- nonebot/consts.py | 4 ++ nonebot/internal/matcher/matcher.py | 34 ++++++--- nonebot/internal/params.py | 43 ++++++++--- nonebot/params.py | 25 +++++++ tests/plugins/param/param_arg.py | 8 ++- tests/plugins/param/param_matcher.py | 17 ++++- tests/test_param.py | 32 +++++++++ website/docs/advanced/dependency.mdx | 103 +++++++++++++++++++++++++++ 8 files changed, 243 insertions(+), 23 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index 701307d3f18a..2b7a88fe9050 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -22,6 +22,10 @@ """当前 `reject` 目标存储 key""" REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target" """下一个 `reject` 目标存储 key""" +PAUSE_PROMPT_RESULT_KEY: Literal["_pause_{key}_result"] = "_pause_{key}_result" +"""`pause` prompt 发送结果存储 key""" +REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result" +"""`reject` prompt 发送结果存储 key""" # used by Rule PREFIX_KEY: Literal["_prefix"] = "_prefix" diff --git a/nonebot/internal/matcher/matcher.py b/nonebot/internal/matcher/matcher.py index 7f18effa5ce2..f539679a86a6 100644 --- a/nonebot/internal/matcher/matcher.py +++ b/nonebot/internal/matcher/matcher.py @@ -27,8 +27,10 @@ from nonebot.consts import ( ARG_KEY, LAST_RECEIVE_KEY, + PAUSE_PROMPT_RESULT_KEY, RECEIVE_KEY, REJECT_CACHE_TARGET, + REJECT_PROMPT_RESULT_KEY, REJECT_TARGET, ) from nonebot.dependencies import Dependent, Param @@ -560,8 +562,8 @@ async def send( """ bot = current_bot.get() event = current_event.get() - state = current_matcher.get().state if isinstance(message, MessageTemplate): + state = current_matcher.get().state _message = message.format(**state) else: _message = message @@ -597,8 +599,15 @@ async def pause( kwargs: {ref}`nonebot.adapters.Bot.send` 的参数, 请参考对应 adapter 的 bot 对象 api """ + try: + matcher = current_matcher.get() + except Exception: + matcher = None + if prompt is not None: - await cls.send(prompt, **kwargs) + result = await cls.send(prompt, **kwargs) + if matcher is not None: + matcher.state[PAUSE_PROMPT_RESULT_KEY] = result raise PausedException @classmethod @@ -615,8 +624,19 @@ async def reject( kwargs: {ref}`nonebot.adapters.Bot.send` 的参数, 请参考对应 adapter 的 bot 对象 api """ + try: + matcher = current_matcher.get() + key = matcher.get_target() + except Exception: + matcher = None + key = None + + key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None + if prompt is not None: - await cls.send(prompt, **kwargs) + result = await cls.send(prompt, **kwargs) + if key is not None and matcher: + matcher.state[key] = result raise RejectedException @classmethod @@ -637,9 +657,7 @@ async def reject_arg( """ matcher = current_matcher.get() matcher.set_target(ARG_KEY.format(key=key)) - if prompt is not None: - await cls.send(prompt, **kwargs) - raise RejectedException + await cls.reject(prompt, **kwargs) @classmethod async def reject_receive( @@ -659,9 +677,7 @@ async def reject_receive( """ matcher = current_matcher.get() matcher.set_target(RECEIVE_KEY.format(id=id)) - if prompt is not None: - await cls.send(prompt, **kwargs) - raise RejectedException + await cls.reject(prompt, **kwargs) @classmethod def skip(cls) -> NoReturn: diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 89d11990aa21..86f776d63ee2 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -18,6 +18,7 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info +from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY from nonebot.dependencies import Dependent, Param from nonebot.dependencies.utils import check_field_type from nonebot.exception import SkippedException @@ -39,7 +40,7 @@ ) if TYPE_CHECKING: - from nonebot.adapters import Bot, Event + from nonebot.adapters import Bot, Event, Message from nonebot.matcher import Matcher @@ -522,10 +523,10 @@ async def _check( # pyright: ignore[reportIncompatibleMethodOverride] class ArgInner: def __init__( - self, key: Optional[str], type: Literal["message", "str", "plaintext"] + self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"] ) -> None: self.key: Optional[str] = key - self.type: Literal["message", "str", "plaintext"] = type + self.type: Literal["message", "str", "plaintext", "prompt"] = type def __repr__(self) -> str: return f"ArgInner(key={self.key!r}, type={self.type!r})" @@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str: return ArgInner(key, "plaintext") # type: ignore +def ArgPromptResult(key: Optional[str] = None) -> Any: + """`arg` prompt 发送结果""" + return ArgInner(key, "prompt") + + class ArgParam(Param): """Arg 注入参数 @@ -559,7 +565,7 @@ def __init__( self, *args, key: str, - type: Literal["message", "str", "plaintext"], + type: Literal["message", "str", "plaintext", "prompt"], **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) @@ -584,15 +590,32 @@ def _check_param( async def _solve( # pyright: ignore[reportIncompatibleMethodOverride] self, matcher: "Matcher", **kwargs: Any ) -> Any: - message = matcher.get_arg(self.key) - if message is None: - return message if self.type == "message": - return message + return self._solve_message(matcher) elif self.type == "str": - return str(message) + return self._solve_str(matcher) + elif self.type == "plaintext": + return self._solve_plaintext(matcher) + elif self.type == "prompt": + return self._solve_prompt(matcher) else: - return message.extract_plain_text() + raise ValueError(f"Unknown Arg type: {self.type}") + + def _solve_message(self, matcher: "Matcher") -> Optional["Message"]: + return matcher.get_arg(self.key) + + def _solve_str(self, matcher: "Matcher") -> Optional[str]: + message = matcher.get_arg(self.key) + return str(message) if message is not None else None + + def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]: + message = matcher.get_arg(self.key) + return message.extract_plain_text() if message is not None else None + + def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]: + return matcher.state.get( + REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key)) + ) class ExceptionParam(Param): diff --git a/nonebot/params.py b/nonebot/params.py index a4400e7d5132..c25010511178 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -19,9 +19,12 @@ ENDSWITH_KEY, FULLMATCH_KEY, KEYWORD_KEY, + PAUSE_PROMPT_RESULT_KEY, PREFIX_KEY, RAW_CMD_KEY, + RECEIVE_KEY, REGEX_MATCHED, + REJECT_PROMPT_RESULT_KEY, SHELL_ARGS, SHELL_ARGV, STARTSWITH_KEY, @@ -29,6 +32,7 @@ from nonebot.internal.params import Arg as Arg from nonebot.internal.params import ArgParam as ArgParam from nonebot.internal.params import ArgPlainText as ArgPlainText +from nonebot.internal.params import ArgPromptResult as ArgPromptResult from nonebot.internal.params import ArgStr as ArgStr from nonebot.internal.params import BotParam as BotParam from nonebot.internal.params import DefaultParam as DefaultParam @@ -252,6 +256,26 @@ def _last_received(matcher: "Matcher") -> Any: return Depends(_last_received, use_cache=False) +def ReceivePromptResult(id: Optional[str] = None) -> Any: + """`receive` prompt 发送结果""" + + def _receive_prompt_result(matcher: "Matcher") -> Any: + return matcher.state.get( + REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id)) + ) + + return Depends(_receive_prompt_result, use_cache=False) + + +def PausePromptResult() -> Any: + """`pause` prompt 发送结果""" + + def _pause_prompt_result(matcher: "Matcher") -> Any: + return matcher.state.get(PAUSE_PROMPT_RESULT_KEY) + + return Depends(_pause_prompt_result, use_cache=False) + + __autodoc__ = { "Arg": True, "ArgStr": True, @@ -265,4 +289,5 @@ def _last_received(matcher: "Matcher") -> Any: "DefaultParam": True, "MatcherParam": True, "ExceptionParam": True, + "ArgPromptResult": True, } diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index 6bf64ded3742..c807228cf789 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -1,7 +1,7 @@ -from typing import Annotated +from typing import Annotated, Any from nonebot.adapters import Message -from nonebot.params import Arg, ArgPlainText, ArgStr +from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr async def arg(key: Message = Arg()) -> Message: @@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: return key +async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any: + return key + + # test dependency priority async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()): return key diff --git a/tests/plugins/param/param_matcher.py b/tests/plugins/param/param_matcher.py index 6e8ec2fcab5e..dd90b1f68f54 100644 --- a/tests/plugins/param/param_matcher.py +++ b/tests/plugins/param/param_matcher.py @@ -1,8 +1,13 @@ -from typing import TypeVar, Union +from typing import Any, TypeVar, Union from nonebot.adapters import Event from nonebot.matcher import Matcher -from nonebot.params import LastReceived, Received +from nonebot.params import ( + LastReceived, + PausePromptResult, + Received, + ReceivePromptResult, +) async def matcher(m: Matcher) -> Matcher: @@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event: async def last_receive(e: Event = LastReceived()) -> Event: return e + + +async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any: + return result + + +async def pause_prompt_result(result: Any = PausePromptResult()) -> Any: + return result diff --git a/tests/test_param.py b/tests/test_param.py index c2001d56a254..5583f8e7c66d 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -5,6 +5,7 @@ import pytest from nonebot.consts import ( + ARG_KEY, CMD_ARG_KEY, CMD_KEY, CMD_START_KEY, @@ -12,9 +13,12 @@ ENDSWITH_KEY, FULLMATCH_KEY, KEYWORD_KEY, + PAUSE_PROMPT_RESULT_KEY, PREFIX_KEY, RAW_CMD_KEY, + RECEIVE_KEY, REGEX_MATCHED, + REJECT_PROMPT_RESULT_KEY, SHELL_ARGS, SHELL_ARGV, STARTSWITH_KEY, @@ -469,8 +473,10 @@ async def test_matcher(app: App): matcher, not_legacy_matcher, not_matcher, + pause_prompt_result, postpone_matcher, receive, + receive_prompt_result, sub_matcher, union_matcher, ) @@ -538,12 +544,31 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(event_next) + fake_matcher.state[ + REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id="test")) + ] = True + + async with app.test_dependent( + receive_prompt_result, allow_types=[MatcherParam, DependParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(True) + + fake_matcher.state[PAUSE_PROMPT_RESULT_KEY] = True + + async with app.test_dependent( + pause_prompt_result, allow_types=[MatcherParam, DependParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(True) + @pytest.mark.anyio async def test_arg(app: App): from plugins.param.param_arg import ( annotated_arg, annotated_arg_plain_text, + annotated_arg_prompt_result, annotated_arg_str, annotated_multi_arg, annotated_prior_arg, @@ -555,6 +580,7 @@ async def test_arg(app: App): matcher = Matcher() message = FakeMessage("text") matcher.set_arg("key", message) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key="key"))] = True async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) @@ -582,6 +608,12 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + async with app.test_dependent( + annotated_arg_prompt_result, allow_types=[ArgParam] + ) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(True) + async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) diff --git a/website/docs/advanced/dependency.mdx b/website/docs/advanced/dependency.mdx index 3efc35cd2263..e7946b39c89c 100644 --- a/website/docs/advanced/dependency.mdx +++ b/website/docs/advanced/dependency.mdx @@ -1224,6 +1224,37 @@ async def _(foo: Event = LastReceived()): ... +### ReceivePromptResult + +获取某次 `receive` 发送提示消息的结果。 + + + + +```python {6} +from typing import Any, Annotated + +from nonebot.params import ReceivePromptResult + +@matcher.receive("id", prompt="prompt") +async def _(result: Annotated[Any, ReceivePromptResult("id")]): ... +``` + + + + +```python {6} +from typing import Any + +from nonebot.params import ReceivePromptResult + +@matcher.receive("id", prompt="prompt") +async def _(result: Any = ReceivePromptResult("id")): ... +``` + + + + ### Arg 获取某次 `got` 接收的参数。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。 @@ -1318,3 +1349,75 @@ async def _(foo: str = ArgPlainText("key")): ... + +### ArgPromptResult + +获取某次 `got` 发送提示消息的结果。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。 + + + + +```python {6,7} +from typing import Any, Annotated + +from nonebot.params import ArgPromptResult + +@matcher.got("key", prompt="prompt") +async def _(result: Annotated[Any, ArgPromptResult()]): ... +async def _(result: Annotated[Any, ArgPromptResult("key")]): ... +``` + + + + +```python {6,7} +from typing import Any + +from nonebot.params import ArgPromptResult + +@matcher.got("key", prompt="prompt") +async def _(result: Any = ArgPromptResult()): ... +async def _(result: Any = ArgPromptResult("key")): ... +``` + + + + +### PausePromptResult + +获取最近一次 `pause` 发送提示消息的结果。 + + + + +```python {6} +from typing import Any, Annotated + +from nonebot.params import PausePromptResult + +@matcher.handle() +async def _(): + await matcher.pause(prompt="prompt") + +@matcher.handle() +async def _(result: Annotated[Any, PausePromptResult()]): ... +``` + + + + +```python {6} +from typing import Any + +from nonebot.params import PausePromptResult + +@matcher.handle() +async def _(): + await matcher.pause(prompt="prompt") + +@matcher.handle() +async def _(result: Any = PausePromptResult()): ... +``` + + + From 27cca3008add749cde118587e5bddb0d13b5fbc6 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Wed, 4 Dec 2024 06:53:06 +0000 Subject: [PATCH 2/3] :bug: fix reject target error --- nonebot/internal/matcher/matcher.py | 18 +++++++++++---- tests/test_param.py | 35 ++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/nonebot/internal/matcher/matcher.py b/nonebot/internal/matcher/matcher.py index f539679a86a6..164b7b89edcc 100644 --- a/nonebot/internal/matcher/matcher.py +++ b/nonebot/internal/matcher/matcher.py @@ -656,8 +656,13 @@ async def reject_arg( 请参考对应 adapter 的 bot 对象 api """ matcher = current_matcher.get() - matcher.set_target(ARG_KEY.format(key=key)) - await cls.reject(prompt, **kwargs) + arg_key = ARG_KEY.format(key=key) + matcher.set_target(arg_key) + + if prompt is not None: + result = await cls.send(prompt, **kwargs) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result + raise RejectedException @classmethod async def reject_receive( @@ -676,8 +681,13 @@ async def reject_receive( 请参考对应 adapter 的 bot 对象 api """ matcher = current_matcher.get() - matcher.set_target(RECEIVE_KEY.format(id=id)) - await cls.reject(prompt, **kwargs) + receive_key = RECEIVE_KEY.format(id=id) + matcher.set_target(receive_key) + + if prompt is not None: + result = await cls.send(prompt, **kwargs) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result + raise RejectedException @classmethod def skip(cls) -> NoReturn: diff --git a/tests/test_param.py b/tests/test_param.py index 5583f8e7c66d..1e1bab62be31 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -1,3 +1,4 @@ +from contextlib import suppress import re from exceptiongroup import BaseExceptionGroup @@ -13,18 +14,16 @@ ENDSWITH_KEY, FULLMATCH_KEY, KEYWORD_KEY, - PAUSE_PROMPT_RESULT_KEY, PREFIX_KEY, RAW_CMD_KEY, RECEIVE_KEY, REGEX_MATCHED, - REJECT_PROMPT_RESULT_KEY, SHELL_ARGS, SHELL_ARGV, STARTSWITH_KEY, ) from nonebot.dependencies import Dependent -from nonebot.exception import TypeMisMatch +from nonebot.exception import PausedException, RejectedException, TypeMisMatch from nonebot.matcher import Matcher from nonebot.params import ( ArgParam, @@ -544,9 +543,14 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(event_next) - fake_matcher.state[ - REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id="test")) - ] = True + fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with fake_matcher.ensure_context(bot, event): + with suppress(RejectedException): + await fake_matcher.reject("test") async with app.test_dependent( receive_prompt_result, allow_types=[MatcherParam, DependParam] @@ -554,7 +558,13 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(True) - fake_matcher.state[PAUSE_PROMPT_RESULT_KEY] = True + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with fake_matcher.ensure_context(bot, event): + fake_matcher.set_target("test") + with suppress(PausedException): + await fake_matcher.pause("test") async with app.test_dependent( pause_prompt_result, allow_types=[MatcherParam, DependParam] @@ -578,9 +588,9 @@ async def test_arg(app: App): ) matcher = Matcher() + event = make_fake_event()() message = FakeMessage("text") matcher.set_arg("key", message) - matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key="key"))] = True async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) @@ -608,6 +618,15 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + matcher.set_target(ARG_KEY.format(key="key"), cache=False) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with matcher.ensure_context(bot, event): + with suppress(RejectedException): + await matcher.reject("test") + async with app.test_dependent( annotated_arg_prompt_result, allow_types=[ArgParam] ) as ctx: From 9f997bfd436da9acc64213ad7d021d1a4eaf2322 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:44:01 +0800 Subject: [PATCH 3/3] :pencil2: improve --- nonebot/consts.py | 2 +- tests/test_param.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index 2b7a88fe9050..0cf4056fe06f 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -22,7 +22,7 @@ """当前 `reject` 目标存储 key""" REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target" """下一个 `reject` 目标存储 key""" -PAUSE_PROMPT_RESULT_KEY: Literal["_pause_{key}_result"] = "_pause_{key}_result" +PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result" """`pause` prompt 发送结果存储 key""" REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result" """`reject` prompt 发送结果存储 key""" diff --git a/tests/test_param.py b/tests/test_param.py index 1e1bab62be31..bf561c9ecb87 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -560,7 +560,7 @@ async def test_matcher(app: App): async with app.test_api() as ctx: bot = ctx.create_bot() - ctx.should_call_send(event, "test", result=True, bot=bot) + ctx.should_call_send(event, "test", result=False, bot=bot) with fake_matcher.ensure_context(bot, event): fake_matcher.set_target("test") with suppress(PausedException): @@ -570,7 +570,7 @@ async def test_matcher(app: App): pause_prompt_result, allow_types=[MatcherParam, DependParam] ) as ctx: ctx.pass_params(matcher=fake_matcher) - ctx.should_return(True) + ctx.should_return(False) @pytest.mark.anyio @@ -622,7 +622,7 @@ async def test_arg(app: App): async with app.test_api() as ctx: bot = ctx.create_bot() - ctx.should_call_send(event, "test", result=True, bot=bot) + ctx.should_call_send(event, "test", result="arg", bot=bot) with matcher.ensure_context(bot, event): with suppress(RejectedException): await matcher.reject("test") @@ -631,7 +631,7 @@ async def test_arg(app: App): annotated_arg_prompt_result, allow_types=[ArgParam] ) as ctx: ctx.pass_params(matcher=matcher) - ctx.should_return(True) + ctx.should_return("arg") async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher)