Skip to content

Commit

Permalink
✨ Feature: 跳过部分非必要的 task group 创建 (#3095)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Oct 31, 2024
1 parent 7b13654 commit 15c5464
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
20 changes: 13 additions & 7 deletions nonebot/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,17 @@ def parse(
return cls(call, params, parameterless_params)

async def check(self, **params: Any) -> None:
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))

async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
if self.parameterless:
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))

if self.params:
async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(
partial(cast(Param, param.field_info)._check, **params)
)

async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
Expand All @@ -205,6 +209,8 @@ async def solve(self, **params: Any) -> dict[str, Any]:

# solve param values
result: dict[str, Any] = {}
if not self.params:
return result

async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
Expand Down
6 changes: 6 additions & 0 deletions nonebot/internal/driver/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def _bot_connect(self, bot: "Bot") -> None:
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot

if not self._bot_connection_hook:
return

def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
Expand Down Expand Up @@ -186,6 +189,9 @@ def _bot_disconnect(self, bot: "Bot") -> None:
if bot.self_id in self._bots:
del self._bots[bot.self_id]

if not self._bot_disconnection_hook:
return

def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
Expand Down

0 comments on commit 15c5464

Please sign in to comment.