From 32bc2c314ab0ef2d58db5df22985be90bb372f59 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:55:24 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E5=AD=98=E5=82=A8=20ma?= =?UTF-8?q?tcher=20=E5=8F=91=E9=80=81=20prompt=20=E7=9A=84=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=20(#3155)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/consts.py | 4 ++ nonebot/internal/matcher/matcher.py | 40 +++++++++-- nonebot/internal/params.py | 43 ++++++++--- nonebot/params.py | 25 +++++++ tests/plugins/param/param_arg.py | 8 ++- tests/plugins/param/param_matcher.py | 17 ++++- tests/test_param.py | 53 +++++++++++++- website/docs/advanced/dependency.mdx | 103 +++++++++++++++++++++++++++ 8 files changed, 271 insertions(+), 22 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index 701307d3f18a..0cf4056fe06f 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -22,6 +22,10 @@ """当前 `reject` 目标存储 key""" REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target" """下一个 `reject` 目标存储 key""" +PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result" +"""`pause` prompt 发送结果存储 key""" +REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result" +"""`reject` prompt 发送结果存储 key""" # used by Rule PREFIX_KEY: Literal["_prefix"] = "_prefix" diff --git a/nonebot/internal/matcher/matcher.py b/nonebot/internal/matcher/matcher.py index 7f18effa5ce2..164b7b89edcc 100644 --- a/nonebot/internal/matcher/matcher.py +++ b/nonebot/internal/matcher/matcher.py @@ -27,8 +27,10 @@ from nonebot.consts import ( ARG_KEY, LAST_RECEIVE_KEY, + PAUSE_PROMPT_RESULT_KEY, RECEIVE_KEY, REJECT_CACHE_TARGET, + REJECT_PROMPT_RESULT_KEY, REJECT_TARGET, ) from nonebot.dependencies import Dependent, Param @@ -560,8 +562,8 @@ async def send( """ bot = current_bot.get() event = current_event.get() - state = current_matcher.get().state if isinstance(message, MessageTemplate): + state = current_matcher.get().state _message = message.format(**state) else: _message = message @@ -597,8 +599,15 @@ async def pause( kwargs: {ref}`nonebot.adapters.Bot.send` 的参数, 请参考对应 adapter 的 bot 对象 api """ + try: + matcher = current_matcher.get() + except Exception: + matcher = None + if prompt is not None: - await cls.send(prompt, **kwargs) + result = await cls.send(prompt, **kwargs) + if matcher is not None: + matcher.state[PAUSE_PROMPT_RESULT_KEY] = result raise PausedException @classmethod @@ -615,8 +624,19 @@ async def reject( kwargs: {ref}`nonebot.adapters.Bot.send` 的参数, 请参考对应 adapter 的 bot 对象 api """ + try: + matcher = current_matcher.get() + key = matcher.get_target() + except Exception: + matcher = None + key = None + + key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None + if prompt is not None: - await cls.send(prompt, **kwargs) + result = await cls.send(prompt, **kwargs) + if key is not None and matcher: + matcher.state[key] = result raise RejectedException @classmethod @@ -636,9 +656,12 @@ async def reject_arg( 请参考对应 adapter 的 bot 对象 api """ matcher = current_matcher.get() - matcher.set_target(ARG_KEY.format(key=key)) + arg_key = ARG_KEY.format(key=key) + matcher.set_target(arg_key) + if prompt is not None: - await cls.send(prompt, **kwargs) + result = await cls.send(prompt, **kwargs) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result raise RejectedException @classmethod @@ -658,9 +681,12 @@ async def reject_receive( 请参考对应 adapter 的 bot 对象 api """ matcher = current_matcher.get() - matcher.set_target(RECEIVE_KEY.format(id=id)) + receive_key = RECEIVE_KEY.format(id=id) + matcher.set_target(receive_key) + if prompt is not None: - await cls.send(prompt, **kwargs) + result = await cls.send(prompt, **kwargs) + matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result raise RejectedException @classmethod diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 89d11990aa21..86f776d63ee2 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -18,6 +18,7 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info +from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY from nonebot.dependencies import Dependent, Param from nonebot.dependencies.utils import check_field_type from nonebot.exception import SkippedException @@ -39,7 +40,7 @@ ) if TYPE_CHECKING: - from nonebot.adapters import Bot, Event + from nonebot.adapters import Bot, Event, Message from nonebot.matcher import Matcher @@ -522,10 +523,10 @@ async def _check( # pyright: ignore[reportIncompatibleMethodOverride] class ArgInner: def __init__( - self, key: Optional[str], type: Literal["message", "str", "plaintext"] + self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"] ) -> None: self.key: Optional[str] = key - self.type: Literal["message", "str", "plaintext"] = type + self.type: Literal["message", "str", "plaintext", "prompt"] = type def __repr__(self) -> str: return f"ArgInner(key={self.key!r}, type={self.type!r})" @@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str: return ArgInner(key, "plaintext") # type: ignore +def ArgPromptResult(key: Optional[str] = None) -> Any: + """`arg` prompt 发送结果""" + return ArgInner(key, "prompt") + + class ArgParam(Param): """Arg 注入参数 @@ -559,7 +565,7 @@ def __init__( self, *args, key: str, - type: Literal["message", "str", "plaintext"], + type: Literal["message", "str", "plaintext", "prompt"], **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) @@ -584,15 +590,32 @@ def _check_param( async def _solve( # pyright: ignore[reportIncompatibleMethodOverride] self, matcher: "Matcher", **kwargs: Any ) -> Any: - message = matcher.get_arg(self.key) - if message is None: - return message if self.type == "message": - return message + return self._solve_message(matcher) elif self.type == "str": - return str(message) + return self._solve_str(matcher) + elif self.type == "plaintext": + return self._solve_plaintext(matcher) + elif self.type == "prompt": + return self._solve_prompt(matcher) else: - return message.extract_plain_text() + raise ValueError(f"Unknown Arg type: {self.type}") + + def _solve_message(self, matcher: "Matcher") -> Optional["Message"]: + return matcher.get_arg(self.key) + + def _solve_str(self, matcher: "Matcher") -> Optional[str]: + message = matcher.get_arg(self.key) + return str(message) if message is not None else None + + def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]: + message = matcher.get_arg(self.key) + return message.extract_plain_text() if message is not None else None + + def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]: + return matcher.state.get( + REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key)) + ) class ExceptionParam(Param): diff --git a/nonebot/params.py b/nonebot/params.py index a4400e7d5132..c25010511178 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -19,9 +19,12 @@ ENDSWITH_KEY, FULLMATCH_KEY, KEYWORD_KEY, + PAUSE_PROMPT_RESULT_KEY, PREFIX_KEY, RAW_CMD_KEY, + RECEIVE_KEY, REGEX_MATCHED, + REJECT_PROMPT_RESULT_KEY, SHELL_ARGS, SHELL_ARGV, STARTSWITH_KEY, @@ -29,6 +32,7 @@ from nonebot.internal.params import Arg as Arg from nonebot.internal.params import ArgParam as ArgParam from nonebot.internal.params import ArgPlainText as ArgPlainText +from nonebot.internal.params import ArgPromptResult as ArgPromptResult from nonebot.internal.params import ArgStr as ArgStr from nonebot.internal.params import BotParam as BotParam from nonebot.internal.params import DefaultParam as DefaultParam @@ -252,6 +256,26 @@ def _last_received(matcher: "Matcher") -> Any: return Depends(_last_received, use_cache=False) +def ReceivePromptResult(id: Optional[str] = None) -> Any: + """`receive` prompt 发送结果""" + + def _receive_prompt_result(matcher: "Matcher") -> Any: + return matcher.state.get( + REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id)) + ) + + return Depends(_receive_prompt_result, use_cache=False) + + +def PausePromptResult() -> Any: + """`pause` prompt 发送结果""" + + def _pause_prompt_result(matcher: "Matcher") -> Any: + return matcher.state.get(PAUSE_PROMPT_RESULT_KEY) + + return Depends(_pause_prompt_result, use_cache=False) + + __autodoc__ = { "Arg": True, "ArgStr": True, @@ -265,4 +289,5 @@ def _last_received(matcher: "Matcher") -> Any: "DefaultParam": True, "MatcherParam": True, "ExceptionParam": True, + "ArgPromptResult": True, } diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index 6bf64ded3742..c807228cf789 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -1,7 +1,7 @@ -from typing import Annotated +from typing import Annotated, Any from nonebot.adapters import Message -from nonebot.params import Arg, ArgPlainText, ArgStr +from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr async def arg(key: Message = Arg()) -> Message: @@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: return key +async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any: + return key + + # test dependency priority async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()): return key diff --git a/tests/plugins/param/param_matcher.py b/tests/plugins/param/param_matcher.py index 6e8ec2fcab5e..dd90b1f68f54 100644 --- a/tests/plugins/param/param_matcher.py +++ b/tests/plugins/param/param_matcher.py @@ -1,8 +1,13 @@ -from typing import TypeVar, Union +from typing import Any, TypeVar, Union from nonebot.adapters import Event from nonebot.matcher import Matcher -from nonebot.params import LastReceived, Received +from nonebot.params import ( + LastReceived, + PausePromptResult, + Received, + ReceivePromptResult, +) async def matcher(m: Matcher) -> Matcher: @@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event: async def last_receive(e: Event = LastReceived()) -> Event: return e + + +async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any: + return result + + +async def pause_prompt_result(result: Any = PausePromptResult()) -> Any: + return result diff --git a/tests/test_param.py b/tests/test_param.py index c2001d56a254..bf561c9ecb87 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -1,3 +1,4 @@ +from contextlib import suppress import re from exceptiongroup import BaseExceptionGroup @@ -5,6 +6,7 @@ import pytest from nonebot.consts import ( + ARG_KEY, CMD_ARG_KEY, CMD_KEY, CMD_START_KEY, @@ -14,13 +16,14 @@ KEYWORD_KEY, PREFIX_KEY, RAW_CMD_KEY, + RECEIVE_KEY, REGEX_MATCHED, SHELL_ARGS, SHELL_ARGV, STARTSWITH_KEY, ) from nonebot.dependencies import Dependent -from nonebot.exception import TypeMisMatch +from nonebot.exception import PausedException, RejectedException, TypeMisMatch from nonebot.matcher import Matcher from nonebot.params import ( ArgParam, @@ -469,8 +472,10 @@ async def test_matcher(app: App): matcher, not_legacy_matcher, not_matcher, + pause_prompt_result, postpone_matcher, receive, + receive_prompt_result, sub_matcher, union_matcher, ) @@ -538,12 +543,42 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(event_next) + fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=True, bot=bot) + with fake_matcher.ensure_context(bot, event): + with suppress(RejectedException): + await fake_matcher.reject("test") + + async with app.test_dependent( + receive_prompt_result, allow_types=[MatcherParam, DependParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(True) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result=False, bot=bot) + with fake_matcher.ensure_context(bot, event): + fake_matcher.set_target("test") + with suppress(PausedException): + await fake_matcher.pause("test") + + async with app.test_dependent( + pause_prompt_result, allow_types=[MatcherParam, DependParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(False) + @pytest.mark.anyio async def test_arg(app: App): from plugins.param.param_arg import ( annotated_arg, annotated_arg_plain_text, + annotated_arg_prompt_result, annotated_arg_str, annotated_multi_arg, annotated_prior_arg, @@ -553,6 +588,7 @@ async def test_arg(app: App): ) matcher = Matcher() + event = make_fake_event()() message = FakeMessage("text") matcher.set_arg("key", message) @@ -582,6 +618,21 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + matcher.set_target(ARG_KEY.format(key="key"), cache=False) + + async with app.test_api() as ctx: + bot = ctx.create_bot() + ctx.should_call_send(event, "test", result="arg", bot=bot) + with matcher.ensure_context(bot, event): + with suppress(RejectedException): + await matcher.reject("test") + + async with app.test_dependent( + annotated_arg_prompt_result, allow_types=[ArgParam] + ) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return("arg") + async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) diff --git a/website/docs/advanced/dependency.mdx b/website/docs/advanced/dependency.mdx index 3efc35cd2263..e7946b39c89c 100644 --- a/website/docs/advanced/dependency.mdx +++ b/website/docs/advanced/dependency.mdx @@ -1224,6 +1224,37 @@ async def _(foo: Event = LastReceived()): ... +### ReceivePromptResult + +获取某次 `receive` 发送提示消息的结果。 + + + + +```python {6} +from typing import Any, Annotated + +from nonebot.params import ReceivePromptResult + +@matcher.receive("id", prompt="prompt") +async def _(result: Annotated[Any, ReceivePromptResult("id")]): ... +``` + + + + +```python {6} +from typing import Any + +from nonebot.params import ReceivePromptResult + +@matcher.receive("id", prompt="prompt") +async def _(result: Any = ReceivePromptResult("id")): ... +``` + + + + ### Arg 获取某次 `got` 接收的参数。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。 @@ -1318,3 +1349,75 @@ async def _(foo: str = ArgPlainText("key")): ... + +### ArgPromptResult + +获取某次 `got` 发送提示消息的结果。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。 + + + + +```python {6,7} +from typing import Any, Annotated + +from nonebot.params import ArgPromptResult + +@matcher.got("key", prompt="prompt") +async def _(result: Annotated[Any, ArgPromptResult()]): ... +async def _(result: Annotated[Any, ArgPromptResult("key")]): ... +``` + + + + +```python {6,7} +from typing import Any + +from nonebot.params import ArgPromptResult + +@matcher.got("key", prompt="prompt") +async def _(result: Any = ArgPromptResult()): ... +async def _(result: Any = ArgPromptResult("key")): ... +``` + + + + +### PausePromptResult + +获取最近一次 `pause` 发送提示消息的结果。 + + + + +```python {6} +from typing import Any, Annotated + +from nonebot.params import PausePromptResult + +@matcher.handle() +async def _(): + await matcher.pause(prompt="prompt") + +@matcher.handle() +async def _(result: Annotated[Any, PausePromptResult()]): ... +``` + + + + +```python {6} +from typing import Any + +from nonebot.params import PausePromptResult + +@matcher.handle() +async def _(): + await matcher.pause(prompt="prompt") + +@matcher.handle() +async def _(result: Any = PausePromptResult()): ... +``` + + +