Skip to content

Commit

Permalink
✨ Feature: 存储 matcher 发送 prompt 的结果 (#3155)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Dec 5, 2024
1 parent ab8dea5 commit 32bc2c3
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 22 deletions.
4 changes: 4 additions & 0 deletions nonebot/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
"""当前 `reject` 目标存储 key"""
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
"""下一个 `reject` 目标存储 key"""
PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result"
"""`pause` prompt 发送结果存储 key"""
REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
"""`reject` prompt 发送结果存储 key"""

# used by Rule
PREFIX_KEY: Literal["_prefix"] = "_prefix"
Expand Down
40 changes: 33 additions & 7 deletions nonebot/internal/matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
from nonebot.consts import (
ARG_KEY,
LAST_RECEIVE_KEY,
PAUSE_PROMPT_RESULT_KEY,
RECEIVE_KEY,
REJECT_CACHE_TARGET,
REJECT_PROMPT_RESULT_KEY,
REJECT_TARGET,
)
from nonebot.dependencies import Dependent, Param
Expand Down Expand Up @@ -560,8 +562,8 @@ async def send(
"""
bot = current_bot.get()
event = current_event.get()
state = current_matcher.get().state
if isinstance(message, MessageTemplate):
state = current_matcher.get().state
_message = message.format(**state)
else:
_message = message
Expand Down Expand Up @@ -597,8 +599,15 @@ async def pause(
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
try:
matcher = current_matcher.get()
except Exception:
matcher = None

if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
if matcher is not None:
matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
raise PausedException

@classmethod
Expand All @@ -615,8 +624,19 @@ async def reject(
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数,
请参考对应 adapter 的 bot 对象 api
"""
try:
matcher = current_matcher.get()
key = matcher.get_target()
except Exception:
matcher = None
key = None

key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None

if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
if key is not None and matcher:
matcher.state[key] = result
raise RejectedException

@classmethod
Expand All @@ -636,9 +656,12 @@ async def reject_arg(
请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
matcher.set_target(ARG_KEY.format(key=key))
arg_key = ARG_KEY.format(key=key)
matcher.set_target(arg_key)

if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result
raise RejectedException

@classmethod
Expand All @@ -658,9 +681,12 @@ async def reject_receive(
请参考对应 adapter 的 bot 对象 api
"""
matcher = current_matcher.get()
matcher.set_target(RECEIVE_KEY.format(id=id))
receive_key = RECEIVE_KEY.format(id=id)
matcher.set_target(receive_key)

if prompt is not None:
await cls.send(prompt, **kwargs)
result = await cls.send(prompt, **kwargs)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result
raise RejectedException

@classmethod
Expand Down
43 changes: 33 additions & 10 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic.fields import FieldInfo as PydanticFieldInfo

from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
from nonebot.dependencies import Dependent, Param
from nonebot.dependencies.utils import check_field_type
from nonebot.exception import SkippedException
Expand All @@ -39,7 +40,7 @@
)

if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher


Expand Down Expand Up @@ -522,10 +523,10 @@ async def _check( # pyright: ignore[reportIncompatibleMethodOverride]

class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
) -> None:
self.key: Optional[str] = key
self.type: Literal["message", "str", "plaintext"] = type
self.type: Literal["message", "str", "plaintext", "prompt"] = type

def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
Expand All @@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
return ArgInner(key, "plaintext") # type: ignore


def ArgPromptResult(key: Optional[str] = None) -> Any:
"""`arg` prompt 发送结果"""
return ArgInner(key, "prompt")


class ArgParam(Param):
"""Arg 注入参数
Expand All @@ -559,7 +565,7 @@ def __init__(
self,
*args,
key: str,
type: Literal["message", "str", "plaintext"],
type: Literal["message", "str", "plaintext", "prompt"],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -584,15 +590,32 @@ def _check_param(
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
self, matcher: "Matcher", **kwargs: Any
) -> Any:
message = matcher.get_arg(self.key)
if message is None:
return message
if self.type == "message":
return message
return self._solve_message(matcher)
elif self.type == "str":
return str(message)
return self._solve_str(matcher)
elif self.type == "plaintext":
return self._solve_plaintext(matcher)
elif self.type == "prompt":
return self._solve_prompt(matcher)
else:
return message.extract_plain_text()
raise ValueError(f"Unknown Arg type: {self.type}")

def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
return matcher.get_arg(self.key)

def _solve_str(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return str(message) if message is not None else None

def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return message.extract_plain_text() if message is not None else None

def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
)


class ExceptionParam(Param):
Expand Down
25 changes: 25 additions & 0 deletions nonebot/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@
ENDSWITH_KEY,
FULLMATCH_KEY,
KEYWORD_KEY,
PAUSE_PROMPT_RESULT_KEY,
PREFIX_KEY,
RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED,
REJECT_PROMPT_RESULT_KEY,
SHELL_ARGS,
SHELL_ARGV,
STARTSWITH_KEY,
)
from nonebot.internal.params import Arg as Arg
from nonebot.internal.params import ArgParam as ArgParam
from nonebot.internal.params import ArgPlainText as ArgPlainText
from nonebot.internal.params import ArgPromptResult as ArgPromptResult
from nonebot.internal.params import ArgStr as ArgStr
from nonebot.internal.params import BotParam as BotParam
from nonebot.internal.params import DefaultParam as DefaultParam
Expand Down Expand Up @@ -252,6 +256,26 @@ def _last_received(matcher: "Matcher") -> Any:
return Depends(_last_received, use_cache=False)


def ReceivePromptResult(id: Optional[str] = None) -> Any:
"""`receive` prompt 发送结果"""

def _receive_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
)

return Depends(_receive_prompt_result, use_cache=False)


def PausePromptResult() -> Any:
"""`pause` prompt 发送结果"""

def _pause_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)

return Depends(_pause_prompt_result, use_cache=False)


__autodoc__ = {
"Arg": True,
"ArgStr": True,
Expand All @@ -265,4 +289,5 @@ def _last_received(matcher: "Matcher") -> Any:
"DefaultParam": True,
"MatcherParam": True,
"ExceptionParam": True,
"ArgPromptResult": True,
}
8 changes: 6 additions & 2 deletions tests/plugins/param/param_arg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from typing import Annotated, Any

from nonebot.adapters import Message
from nonebot.params import Arg, ArgPlainText, ArgStr
from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr


async def arg(key: Message = Arg()) -> Message:
Expand All @@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key


async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
return key


# test dependency priority
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
return key
Expand Down
17 changes: 15 additions & 2 deletions tests/plugins/param/param_matcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import TypeVar, Union
from typing import Any, TypeVar, Union

from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.params import LastReceived, Received
from nonebot.params import (
LastReceived,
PausePromptResult,
Received,
ReceivePromptResult,
)


async def matcher(m: Matcher) -> Matcher:
Expand Down Expand Up @@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:

async def last_receive(e: Event = LastReceived()) -> Event:
return e


async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
return result


async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
return result
Loading

0 comments on commit 32bc2c3

Please sign in to comment.