diff --git a/arc/abc/__init__.py b/arc/abc/__init__.py index cb053f6..fb6d67a 100644 --- a/arc/abc/__init__.py +++ b/arc/abc/__init__.py @@ -2,6 +2,7 @@ from .command import CallableCommandBase, CallableCommandProto, CommandBase, CommandProto from .error_handler import HasErrorHandler from .option import CommandOptionBase, Option, OptionBase, OptionParams, OptionWithChoices, OptionWithChoicesParams +from .plugin import PluginBase __all__ = ( "HasErrorHandler", @@ -16,4 +17,5 @@ "OptionWithChoices", "OptionWithChoicesParams", "Client", + "PluginBase", ) diff --git a/arc/abc/client.py b/arc/abc/client.py index 42091f4..7dfcbbe 100644 --- a/arc/abc/client.py +++ b/arc/abc/client.py @@ -15,6 +15,7 @@ import alluka import hikari +from arc.abc.plugin import PluginBase from arc.command.message import MessageCommand from arc.command.slash import SlashCommand, SlashGroup, SlashSubCommand, SlashSubGroup from arc.command.user import UserCommand @@ -22,7 +23,6 @@ from arc.errors import ExtensionLoadError, ExtensionUnloadError from arc.internal.sync import _sync_commands from arc.internal.types import AppT, BuilderT, ResponseBuilderT -from arc.plugin import PluginBase if t.TYPE_CHECKING: import typing_extensions as te diff --git a/arc/abc/command.py b/arc/abc/command.py index bd0d635..aec2929 100644 --- a/arc/abc/command.py +++ b/arc/abc/command.py @@ -12,8 +12,8 @@ from arc.internal.types import BuilderT, ClientT, CommandCallbackT, ResponseBuilderT if t.TYPE_CHECKING: - from ..context import Context - from ..plugin import PluginBase + from arc.abc.plugin import PluginBase + from arc.context import Context class CommandProto(t.Protocol): diff --git a/arc/abc/plugin.py b/arc/abc/plugin.py new file mode 100644 index 0000000..513fb1b --- /dev/null +++ b/arc/abc/plugin.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import abc +import functools +import inspect +import itertools +import typing as t + +import hikari + +from arc.abc.error_handler import HasErrorHandler +from arc.command import MessageCommand, SlashCommand, SlashGroup, UserCommand +from arc.context import AutodeferMode, Context +from arc.internal.types import BuilderT, ClientT, SlashCommandLike + +if t.TYPE_CHECKING: + from arc.abc.command import CommandBase + from arc.command import SlashSubCommand, SlashSubGroup + +__all__ = ("PluginBase",) + +P = t.ParamSpec("P") +T = t.TypeVar("T") + + +class PluginBase(HasErrorHandler[ClientT], t.Generic[ClientT]): + """An abstract base class for plugins. + + Parameters + ---------- + name : builtins.str + The name of this plugin. This must be unique across all plugins. + """ + + def __init__( + self, name: str, *, default_enabled_guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED + ) -> None: + super().__init__() + self._client: ClientT | None = None + self._name = name + self._slash_commands: dict[str, SlashCommandLike[ClientT]] = {} + self._user_commands: dict[str, UserCommand[ClientT]] = {} + self._message_commands: dict[str, MessageCommand[ClientT]] = {} + self._default_enabled_guilds = default_enabled_guilds + + @property + @abc.abstractmethod + def is_rest(self) -> bool: + """Whether or not this plugin is a REST plugin.""" + + @property + def name(self) -> str: + """The name of this plugin.""" + return self._name + + @property + def client(self) -> ClientT: + """The client this plugin is included in.""" + if self._client is None: + raise RuntimeError( + f"Plugin '{self.name}' was not included in a client, '{type(self).__name__}.client' cannot be accessed until it is included in a client." + ) + return self._client + + @property + def default_enabled_guilds(self) -> hikari.UndefinedOr[t.Sequence[hikari.Snowflake]]: + """The default guilds to enable commands in.""" + return self._default_enabled_guilds + + def _client_include_hook(self, client: ClientT) -> None: + if client._plugins.get(self.name) is not None: + raise RuntimeError(f"Plugin '{self.name}' is already included in client.") + + self._client = client + self._client._plugins[self.name] = self + + for command in itertools.chain( + self._slash_commands.values(), self._user_commands.values(), self._message_commands.values() + ): + command._client_include_hook(client) + + def _client_remove_hook(self) -> None: + if self._client is None: + raise RuntimeError(f"Plugin '{self.name}' is not included in a client.") + + for command in itertools.chain( + self._slash_commands.values(), self._user_commands.values(), self._message_commands.values() + ): + self.client._remove_command(command) + + self._client._plugins.pop(self.name) + self._client = None + + def _add_command(self, command: CommandBase[ClientT, t.Any]) -> None: + if isinstance(command, (SlashCommand, SlashGroup)): + self._slash_commands[command.name] = command + elif isinstance(command, UserCommand): + self._user_commands[command.name] = command + elif isinstance(command, MessageCommand): + self._message_commands[command.name] = command + else: + raise TypeError(f"Unknown command type '{type(command).__name__}'.") + + def include(self, command: CommandBase[ClientT, BuilderT]) -> CommandBase[ClientT, BuilderT]: + """Include a command in this plugin. + + Parameters + ---------- + command : arc.CommandBase[ClientT, BuilderT] + The command to include in this plugin. + + Raises + ------ + RuntimeError + If the command is already included in this plugin. + """ + if command.plugin is not None: + raise ValueError(f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'.") + + command._plugin_include_hook(self) + return command + + async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: + try: + if self.error_handler is not None: + await self.error_handler(ctx, exc) + else: + raise exc + except Exception as exc: + await self.client._on_error(ctx, exc) + + def include_slash_group( + self, + name: str, + description: str = "No description provided.", + *, + guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED, + autodefer: bool | AutodeferMode = True, + is_dm_enabled: bool = True, + default_permissions: hikari.UndefinedOr[hikari.Permissions] = hikari.UNDEFINED, + name_localizations: dict[hikari.Locale, str] | None = None, + description_localizations: dict[hikari.Locale, str] | None = None, + is_nsfw: bool = False, + ) -> SlashGroup[ClientT]: + """Add a new slash command group to this client. + + Parameters + ---------- + name : str + The name of the slash command group. + description : str + The description of the slash command group. + guilds : hikari.UndefinedOr[t.Sequence[hikari.Snowflake]], optional + The guilds to register the slash command group in, by default hikari.UNDEFINED + autodefer : bool | AutodeferMode, optional + If True, all commands in this group will automatically defer if it is taking longer than 2 seconds to respond. + This can be overridden on a per-subcommand basis. + is_dm_enabled : bool, optional + Whether the slash command group is enabled in DMs, by default True + default_permissions : hikari.UndefinedOr[hikari.Permissions], optional + The default permissions for the slash command group, by default hikari.UNDEFINED + name_localizations : dict[hikari.Locale, str], optional + The name of the slash command group in different locales, by default None + description_localizations : dict[hikari.Locale, str], optional + The description of the slash command group in different locales, by default None + is_nsfw : bool, optional + Whether the slash command group is only usable in NSFW channels, by default False + + Returns + ------- + SlashGroup[te.Self] + The slash command group that was created. + + Usage + ----- + ```py + group = client.include_slash_group("Group", "A group of commands.") + + @group.include + @arc.slash_subcommand(name="Command", description="A command.") + async def cmd(ctx: arc.GatewayContext) -> None: + await ctx.respond("Hello!") + ``` + """ + children: dict[str, SlashSubCommand[ClientT] | SlashSubGroup[ClientT]] = {} + + group = SlashGroup( + name=name, + description=description, + children=children, + guilds=guilds, + autodefer=AutodeferMode(autodefer), + is_dm_enabled=is_dm_enabled, + default_permissions=default_permissions, + name_localizations=name_localizations or {}, + description_localizations=description_localizations or {}, + is_nsfw=is_nsfw, + ) + group._plugin_include_hook(self) + return group + + def inject_dependencies(self, func: t.Callable[P, T]) -> t.Callable[P, T]: + """First order decorator to inject dependencies into the decorated function. + + !!! warning + This makes functions uncallable if the plugin is not added to a client. + + !!! note + Command callbacks are automatically injected with dependencies, + thus this decorator is not needed for them. + + Usage + ----- + ```py + class MyDependency: + def __init__(self, value: str): + self.value = value + + client = arc.GatewayClient(...) + client.set_type_dependency(MyDependency, MyDependency("Hello!")) + client.load_extension("foo") + + # In 'foo': + + plugin = arc.GatewayPlugin("My Plugin") + + @plugin.inject_dependencies + def my_func(dep: MyDependency = arc.inject()) -> None: + print(dep.value) # Prints "Hello!" + + @arc.loader + def load(client: arc.GatewayClient) -> None: + client.add_plugin(plugin) + ``` + + See Also + -------- + - [`Client.set_type_dependency`][arc.client.Client.set_type_dependency] + A method to set dependencies for the client. + """ + if inspect.iscoroutinefunction(func): + + @functools.wraps(func) + async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T: + if self._client is None: + raise RuntimeError( + f"Cannot inject dependencies into '{func.__name__}' before plugin '{self.name}' is included in a client." + ) + return await self._client.injector.call_with_async_di(func, *args, **kwargs) + + return decorator_async # pyright: ignore reportGeneralTypeIssues + else: + + @functools.wraps(func) + def decorator(*args: P.args, **kwargs: P.kwargs) -> T: + if self._client is None: + raise RuntimeError( + f"Cannot inject dependencies into '{func.__name__}' before plugin '{self.name}' is included in a client." + ) + return self._client.injector.call_with_di(func, *args, **kwargs) + + return decorator diff --git a/arc/client.py b/arc/client.py index 2390472..f76e411 100644 --- a/arc/client.py +++ b/arc/client.py @@ -27,7 +27,7 @@ class GatewayClient(Client[hikari.GatewayBotAware]): """The default implementation for an arc client with `hikari.GatewayBotAware` support. - If you want to use a `hikari.RESTBotAware`, use `RESTClient` instead. + If you want to use a `hikari.RESTBotAware`, use [`RESTClient`][arc.client.RESTClient] instead. Parameters ---------- @@ -119,7 +119,7 @@ def listen(self, *event_types: t.Type[EventT]) -> t.Callable[[EventCallbackT[Eve class RESTClient(Client[hikari.RESTBotAware]): """The default implementation for an arc client with `hikari.RESTBotAware` support. - If you want to use `hikari.GatewayBotAware`, use `GatewayClient` instead. + If you want to use `hikari.GatewayBotAware`, use [`GatewayClient`][arc.client.GatewayClient] instead. Parameters ---------- diff --git a/arc/command/slash.py b/arc/command/slash.py index a588f89..341ca85 100644 --- a/arc/command/slash.py +++ b/arc/command/slash.py @@ -19,7 +19,7 @@ from arc.abc.command import CallableCommandProto from arc.abc.option import CommandOptionBase - from arc.plugin import PluginBase + from arc.abc.plugin import PluginBase __all__ = ( "SlashCommandLike", diff --git a/arc/plugin.py b/arc/plugin.py index 26576fd..04a888e 100644 --- a/arc/plugin.py +++ b/arc/plugin.py @@ -1,269 +1,22 @@ from __future__ import annotations -import abc -import functools -import inspect -import itertools import typing as t -import hikari - -from arc.abc.error_handler import HasErrorHandler -from arc.command import MessageCommand, SlashCommand, SlashGroup, UserCommand -from arc.context import AutodeferMode, Context -from arc.internal.types import BuilderT, ClientT, EventCallbackT, GatewayClientT, RESTClientT, SlashCommandLike +from arc.abc.plugin import PluginBase +from arc.internal.types import EventCallbackT, GatewayClientT, RESTClientT if t.TYPE_CHECKING: - from arc.abc.command import CommandBase - from arc.command import SlashSubCommand, SlashSubGroup + import hikari -__all__ = ("PluginBase", "RESTPluginBase", "GatewayPluginBase") +__all__ = ("RESTPluginBase", "GatewayPluginBase") P = t.ParamSpec("P") T = t.TypeVar("T") -class PluginBase(HasErrorHandler[ClientT], t.Generic[ClientT]): - """An abstract base class for plugins. - - Parameters - ---------- - name : builtins.str - The name of this plugin. This must be unique across all plugins. - """ - - def __init__( - self, name: str, *, default_enabled_guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED - ) -> None: - super().__init__() - self._client: ClientT | None = None - self._name = name - self._slash_commands: dict[str, SlashCommandLike[ClientT]] = {} - self._user_commands: dict[str, UserCommand[ClientT]] = {} - self._message_commands: dict[str, MessageCommand[ClientT]] = {} - self._default_enabled_guilds = default_enabled_guilds - - @property - @abc.abstractmethod - def is_rest(self) -> bool: - """Whether or not this plugin is a REST plugin.""" - - @property - def name(self) -> str: - """The name of this plugin.""" - return self._name - - @property - def client(self) -> ClientT: - """The client this plugin is included in.""" - if self._client is None: - raise RuntimeError( - f"Plugin '{self.name}' was not included in a client, '{type(self).__name__}.client' cannot be accessed until it is included in a client." - ) - return self._client - - @property - def default_enabled_guilds(self) -> hikari.UndefinedOr[t.Sequence[hikari.Snowflake]]: - """The default guilds to enable commands in.""" - return self._default_enabled_guilds - - def _client_include_hook(self, client: ClientT) -> None: - if client._plugins.get(self.name) is not None: - raise RuntimeError(f"Plugin '{self.name}' is already included in client.") - - self._client = client - self._client._plugins[self.name] = self - - for command in itertools.chain( - self._slash_commands.values(), self._user_commands.values(), self._message_commands.values() - ): - command._client_include_hook(client) - - def _client_remove_hook(self) -> None: - if self._client is None: - raise RuntimeError(f"Plugin '{self.name}' is not included in a client.") - - for command in itertools.chain( - self._slash_commands.values(), self._user_commands.values(), self._message_commands.values() - ): - self.client._remove_command(command) - - self._client._plugins.pop(self.name) - self._client = None - - def _add_command(self, command: CommandBase[ClientT, t.Any]) -> None: - if isinstance(command, (SlashCommand, SlashGroup)): - self._slash_commands[command.name] = command - elif isinstance(command, UserCommand): - self._user_commands[command.name] = command - elif isinstance(command, MessageCommand): - self._message_commands[command.name] = command - else: - raise TypeError(f"Unknown command type '{type(command).__name__}'.") - - def include(self, command: CommandBase[ClientT, BuilderT]) -> CommandBase[ClientT, BuilderT]: - """Include a command in this plugin. - - Parameters - ---------- - command : arc.CommandBase[ClientT, BuilderT] - The command to include in this plugin. - - Raises - ------ - RuntimeError - If the command is already included in this plugin. - """ - if command.plugin is not None: - raise ValueError(f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'.") - - command._plugin_include_hook(self) - return command - - async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: - try: - if self.error_handler is not None: - await self.error_handler(ctx, exc) - else: - raise exc - except Exception as exc: - await self.client._on_error(ctx, exc) - - def include_slash_group( - self, - name: str, - description: str = "No description provided.", - *, - guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED, - autodefer: bool | AutodeferMode = True, - is_dm_enabled: bool = True, - default_permissions: hikari.UndefinedOr[hikari.Permissions] = hikari.UNDEFINED, - name_localizations: dict[hikari.Locale, str] | None = None, - description_localizations: dict[hikari.Locale, str] | None = None, - is_nsfw: bool = False, - ) -> SlashGroup[ClientT]: - """Add a new slash command group to this client. - - Parameters - ---------- - name : str - The name of the slash command group. - description : str - The description of the slash command group. - guilds : hikari.UndefinedOr[t.Sequence[hikari.Snowflake]], optional - The guilds to register the slash command group in, by default hikari.UNDEFINED - autodefer : bool | AutodeferMode, optional - If True, all commands in this group will automatically defer if it is taking longer than 2 seconds to respond. - This can be overridden on a per-subcommand basis. - is_dm_enabled : bool, optional - Whether the slash command group is enabled in DMs, by default True - default_permissions : hikari.UndefinedOr[hikari.Permissions], optional - The default permissions for the slash command group, by default hikari.UNDEFINED - name_localizations : dict[hikari.Locale, str], optional - The name of the slash command group in different locales, by default None - description_localizations : dict[hikari.Locale, str], optional - The description of the slash command group in different locales, by default None - is_nsfw : bool, optional - Whether the slash command group is only usable in NSFW channels, by default False - - Returns - ------- - SlashGroup[te.Self] - The slash command group that was created. - - Usage - ----- - ```py - group = client.include_slash_group("Group", "A group of commands.") - - @group.include - @arc.slash_subcommand(name="Command", description="A command.") - async def cmd(ctx: arc.GatewayContext) -> None: - await ctx.respond("Hello!") - ``` - """ - children: dict[str, SlashSubCommand[ClientT] | SlashSubGroup[ClientT]] = {} - - group = SlashGroup( - name=name, - description=description, - children=children, - guilds=guilds, - autodefer=AutodeferMode(autodefer), - is_dm_enabled=is_dm_enabled, - default_permissions=default_permissions, - name_localizations=name_localizations or {}, - description_localizations=description_localizations or {}, - is_nsfw=is_nsfw, - ) - group._plugin_include_hook(self) - return group - - def inject_dependencies(self, func: t.Callable[P, T]) -> t.Callable[P, T]: - """First order decorator to inject dependencies into the decorated function. - - !!! warning - This makes functions uncallable if the plugin is not added to a client. - - !!! note - Command callbacks are automatically injected with dependencies, - thus this decorator is not needed for them. - - Usage - ----- - ```py - class MyDependency: - def __init__(self, value: str): - self.value = value - - client = arc.GatewayClient(...) - client.set_type_dependency(MyDependency, MyDependency("Hello!")) - client.load_extension("foo") - - # In 'foo': - - plugin = arc.GatewayPlugin("My Plugin") - - @plugin.inject_dependencies - def my_func(dep: MyDependency = arc.inject()) -> None: - print(dep.value) # Prints "Hello!" - - @arc.loader - def load(client: arc.GatewayClient) -> None: - client.add_plugin(plugin) - ``` - - See Also - -------- - - [`Client.set_type_dependency`][arc.client.Client.set_type_dependency] - A method to set dependencies for the client. - """ - if inspect.iscoroutinefunction(func): - - @functools.wraps(func) - async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T: - if self._client is None: - raise RuntimeError( - f"Cannot inject dependencies into '{func.__name__}' before plugin '{self.name}' is included in a client." - ) - return await self._client.injector.call_with_async_di(func, *args, **kwargs) - - return decorator_async # pyright: ignore reportGeneralTypeIssues - else: - - @functools.wraps(func) - def decorator(*args: P.args, **kwargs: P.kwargs) -> T: - if self._client is None: - raise RuntimeError( - f"Cannot inject dependencies into '{func.__name__}' before plugin '{self.name}' is included in a client." - ) - return self._client.injector.call_with_di(func, *args, **kwargs) - - return decorator - - class RESTPluginBase(PluginBase[RESTClientT]): """The default implementation of a REST plugin. + To use this with the default [`RESTClient`][arc.client.RESTClient] implementation, see [`RESTPlugin`][arc.client.RESTPlugin]. Parameters ---------- @@ -278,6 +31,7 @@ def is_rest(self) -> bool: class GatewayPluginBase(PluginBase[GatewayClientT]): """The default implementation of a gateway plugin. + To use this with the default [`GatewayClient`][arc.client.GatewayClient] implementation, see [`GatewayPlugin`][arc.client.GatewayPlugin]. Parameters ---------- diff --git a/docs/api_reference/abc/plugin.md b/docs/api_reference/abc/plugin.md new file mode 100644 index 0000000..799f41c --- /dev/null +++ b/docs/api_reference/abc/plugin.md @@ -0,0 +1,8 @@ +--- +title: Plugin ABC +description: Abstract Base Classes API reference +--- + +# Plugin ABC + +::: arc.abc.plugin diff --git a/docs/guides/dependency_injection.md b/docs/guides/dependency_injection.md index 899c526..142ed1c 100644 --- a/docs/guides/dependency_injection.md +++ b/docs/guides/dependency_injection.md @@ -57,7 +57,7 @@ description: A guide on dependency injection & arc client.set_type_dependency(MyDatabase, database) ``` -In the above example, we asked `arc` that every time we ask for a dependency of type `MyDatabase`, it should return the specific instance we gave it as the second parameter to [`Client.set_type_dependency`][arc.client.Client.set_type_dependency] +In the above example, we asked `arc` that every time we ask for a dependency of type `MyDatabase`, it should return the specific instance we gave it as the second parameter to [`Client.set_type_dependency`][arc.abc.client.Client.set_type_dependency] ## Injecting dependencies @@ -93,7 +93,7 @@ And here we request that `arc` injects the dependency we declared earlier into o ### Injecting other functions -By default, only command callbacks are injected with dependencies, but you might want to inject other functions too. This can be done via the [`@Client.inject_dependencies`][arc.client.Client.inject_dependencies] decorator (or [`@Plugin.inject_dependencies`][arc.plugin.Plugin.inject_dependencies] if working in a plugin). +By default, only command callbacks are injected with dependencies, but you might want to inject other functions too. This can be done via the [`@Client.inject_dependencies`][arc.abc.client.Client.inject_dependencies] decorator (or [`@Plugin.inject_dependencies`][arc.abc.plugin.PluginBase.inject_dependencies] if working in a plugin). ```py @client.inject_dependencies @@ -105,7 +105,7 @@ def compare_counter(value: int, db: MyDatabase = arc.inject()) -> None: ``` !!! warning - Trying to use [`arc.inject()`][alluka.inject] outside a command or a function decorated with [`@Client.inject_dependencies`][arc.client.Client.inject_dependencies] will lead to unexpected results. + Trying to use [`arc.inject()`][alluka.inject] outside a command or a function decorated with [`@Client.inject_dependencies`][arc.abc.client.Client.inject_dependencies] will lead to unexpected results. ## The benefits of dependency injection diff --git a/mkdocs.yml b/mkdocs.yml index 1da88b7..bdc56bb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -82,6 +82,7 @@ nav: - api_reference/abc/client.md - api_reference/abc/command.md - api_reference/abc/option.md + - api_reference/abc/plugin.md - api_reference/abc/error_handler.md - Changelog: changelog.md