Skip to content

Commit

Permalink
✨ Feature: 迁移至结构化并发框架 AnyIO (#3053)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Oct 26, 2024
1 parent bd9befb commit ff21ceb
Show file tree
Hide file tree
Showing 39 changed files with 5,417 additions and 4,075 deletions.
2,029 changes: 1,110 additions & 919 deletions envs/pydantic-v1/poetry.lock

Large diffs are not rendered by default.

2,132 changes: 1,162 additions & 970 deletions envs/pydantic-v2/poetry.lock

Large diffs are not rendered by default.

1,394 changes: 832 additions & 562 deletions envs/test/poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions envs/test/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ packages = [{ include = "nonebot-test.py" }]

[tool.poetry.dependencies]
python = "^3.9"
nonebug = "^0.3.7"
trio = "^0.27.0"
nonebug = "^0.4.1"
wsproto = "^1.2.0"
pytest-cov = "^5.0.0"
pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.23.2"
werkzeug = ">=2.3.6,<4.0.0"
coverage-conditional-plugin = "^0.9.0"

Expand Down
49 changes: 35 additions & 14 deletions nonebot/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
"""

import abc
import asyncio
import inspect
from functools import partial
from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast

import anyio
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group

from .utils import check_field_type, get_typed_signature

Expand Down Expand Up @@ -84,7 +87,16 @@ def __repr__(self) -> str:
)

async def __call__(self, **kwargs: Any) -> R:
try:
exception: Optional[BaseExceptionGroup[SkippedException]] = None

def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")

with catch({SkippedException: _handle_skipped}):
# do pre-check
await self.check(**kwargs)

Expand All @@ -96,9 +108,8 @@ async def __call__(self, **kwargs: Any) -> R:
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise

raise exception

@staticmethod
def parse_params(
Expand Down Expand Up @@ -166,10 +177,13 @@ def parse(
return cls(call, params, parameterless_params)

async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))

async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))

async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
Expand All @@ -185,10 +199,17 @@ async def solve(self, **params: Any) -> dict[str, Any]:
await param._solve(**params)

# solve param values
values = await asyncio.gather(
*(self._solve_field(field, params) for field in self.params)
)
return {field.name: value for field, value in zip(self.params, values)}
result: dict[str, Any] = {}

async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
result[field.name] = value

async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)

return result


__autodoc__ = {"CustomConfig": False}
143 changes: 80 additions & 63 deletions nonebot/drivers/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
"""

import signal
import asyncio
import threading
from typing import Optional
from typing_extensions import override

import anyio
from anyio.abc import TaskGroup
from exceptiongroup import BaseExceptionGroup, catch

from nonebot.log import logger
from nonebot.consts import WINDOWS
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import flatten_exception_group

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
Expand All @@ -35,8 +39,8 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
self.should_exit: anyio.Event = anyio.Event()
self.force_exit: anyio.Event = anyio.Event()

@property
@override
Expand All @@ -54,84 +58,97 @@ def logger(self):
def run(self, *args, **kwargs):
"""启动 none driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self._serve())
anyio.run(self._serve)

async def _serve(self):
self._install_signal_handlers()
await self._startup()
if self.should_exit.is_set():
return
await self._main_loop()
await self._shutdown()
async with anyio.create_task_group() as driver_tg:
driver_tg.start_soon(self._handle_signals)
driver_tg.start_soon(self._listen_force_exit, driver_tg)
driver_tg.start_soon(self._handle_lifespan, driver_tg)

async def _startup(self):
async def _handle_signals(self):
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signal_receiver:
async for sig in signal_receiver:
self.exit(force=self.should_exit.is_set())
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_legacy_signal)

# backport for Windows signal handling
def _handle_legacy_signal(self, sig, frame):
self.exit(force=self.should_exit.is_set())

async def _handle_lifespan(self, tg: TaskGroup):
try:
await self._startup()

if self.should_exit.is_set():
return

await self._listen_exit()

await self._shutdown()
finally:
tg.cancel_scope.cancel()

async def _startup(self):
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
self.should_exit.set()

for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running startup hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
self.should_exit.set()
return

logger.info("Application startup completed.")
with catch({Exception: handle_exception}):
await self._lifespan.startup()

if not self.should_exit.is_set():
logger.info("Application startup completed.")

async def _main_loop(self):
async def _listen_exit(self, tg: Optional[TaskGroup] = None):
await self.should_exit.wait()

if tg is not None:
tg.cancel_scope.cancel()

async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown. (CTRL+C to force quit)")

logger.info("Waiting for application shutdown.")

try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)

for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
task.cancel()
await asyncio.sleep(0.1)

tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]

for task in tasks:
task.cancel()
error_occurred: bool = False

await asyncio.gather(*tasks, return_exceptions=True)
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
nonlocal error_occurred

logger.info("Application shutdown complete.")
loop = asyncio.get_event_loop()
loop.stop()
error_occurred = True

def _install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running shutdown hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application shutdown failed. "
"Exiting.</bg #f8bbd0></r>"
)

loop = asyncio.get_event_loop()
with catch({Exception: handle_exception}):
await self._lifespan.shutdown()

try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self._handle_exit, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit)
if not error_occurred:
logger.info("Application shutdown complete.")

def _handle_exit(self, sig, frame):
self.exit(force=self.should_exit.is_set())
async def _listen_force_exit(self, tg: TaskGroup):
await self.force_exit.wait()
tg.cancel_scope.cancel()

def exit(self, force: bool = False):
"""退出 none driver
Expand All @@ -142,4 +159,4 @@ def exit(self, force: bool = False):
if not self.should_exit.is_set():
self.should_exit.set()
if force:
self.force_exit = True
self.force_exit.set()
Loading

0 comments on commit ff21ceb

Please sign in to comment.