Skip to content

Commit

Permalink
Move PluginBase to abc
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Dec 31, 2023
1 parent d9f3d7c commit 83bf83f
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 261 deletions.
2 changes: 2 additions & 0 deletions arc/abc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .command import CallableCommandBase, CallableCommandProto, CommandBase, CommandProto
from .error_handler import HasErrorHandler
from .option import CommandOptionBase, Option, OptionBase, OptionParams, OptionWithChoices, OptionWithChoicesParams
from .plugin import PluginBase

__all__ = (
"HasErrorHandler",
Expand All @@ -16,4 +17,5 @@
"OptionWithChoices",
"OptionWithChoicesParams",
"Client",
"PluginBase",
)
2 changes: 1 addition & 1 deletion arc/abc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import alluka
import hikari

from arc.abc.plugin import PluginBase
from arc.command.message import MessageCommand
from arc.command.slash import SlashCommand, SlashGroup, SlashSubCommand, SlashSubGroup
from arc.command.user import UserCommand
from arc.context import AutodeferMode, Context
from arc.errors import ExtensionLoadError, ExtensionUnloadError
from arc.internal.sync import _sync_commands
from arc.internal.types import AppT, BuilderT, ResponseBuilderT
from arc.plugin import PluginBase

if t.TYPE_CHECKING:
import typing_extensions as te
Expand Down
4 changes: 2 additions & 2 deletions arc/abc/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from arc.internal.types import BuilderT, ClientT, CommandCallbackT, ResponseBuilderT

if t.TYPE_CHECKING:
from ..context import Context
from ..plugin import PluginBase
from arc.abc.plugin import PluginBase
from arc.context import Context


class CommandProto(t.Protocol):
Expand Down
262 changes: 262 additions & 0 deletions arc/abc/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
from __future__ import annotations

import abc
import functools
import inspect
import itertools
import typing as t

import hikari

from arc.abc.error_handler import HasErrorHandler
from arc.command import MessageCommand, SlashCommand, SlashGroup, UserCommand
from arc.context import AutodeferMode, Context
from arc.internal.types import BuilderT, ClientT, SlashCommandLike

if t.TYPE_CHECKING:
from arc.abc.command import CommandBase
from arc.command import SlashSubCommand, SlashSubGroup

__all__ = ("PluginBase",)

P = t.ParamSpec("P")
T = t.TypeVar("T")


class PluginBase(HasErrorHandler[ClientT], t.Generic[ClientT]):
"""An abstract base class for plugins.
Parameters
----------
name : builtins.str
The name of this plugin. This must be unique across all plugins.
"""

def __init__(
self, name: str, *, default_enabled_guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED
) -> None:
super().__init__()
self._client: ClientT | None = None
self._name = name
self._slash_commands: dict[str, SlashCommandLike[ClientT]] = {}
self._user_commands: dict[str, UserCommand[ClientT]] = {}
self._message_commands: dict[str, MessageCommand[ClientT]] = {}
self._default_enabled_guilds = default_enabled_guilds

@property
@abc.abstractmethod
def is_rest(self) -> bool:
"""Whether or not this plugin is a REST plugin."""

@property
def name(self) -> str:
"""The name of this plugin."""
return self._name

@property
def client(self) -> ClientT:
"""The client this plugin is included in."""
if self._client is None:
raise RuntimeError(
f"Plugin '{self.name}' was not included in a client, '{type(self).__name__}.client' cannot be accessed until it is included in a client."
)
return self._client

@property
def default_enabled_guilds(self) -> hikari.UndefinedOr[t.Sequence[hikari.Snowflake]]:
"""The default guilds to enable commands in."""
return self._default_enabled_guilds

def _client_include_hook(self, client: ClientT) -> None:
if client._plugins.get(self.name) is not None:
raise RuntimeError(f"Plugin '{self.name}' is already included in client.")

self._client = client
self._client._plugins[self.name] = self

for command in itertools.chain(
self._slash_commands.values(), self._user_commands.values(), self._message_commands.values()
):
command._client_include_hook(client)

def _client_remove_hook(self) -> None:
if self._client is None:
raise RuntimeError(f"Plugin '{self.name}' is not included in a client.")

for command in itertools.chain(
self._slash_commands.values(), self._user_commands.values(), self._message_commands.values()
):
self.client._remove_command(command)

self._client._plugins.pop(self.name)
self._client = None

def _add_command(self, command: CommandBase[ClientT, t.Any]) -> None:
if isinstance(command, (SlashCommand, SlashGroup)):
self._slash_commands[command.name] = command
elif isinstance(command, UserCommand):
self._user_commands[command.name] = command
elif isinstance(command, MessageCommand):
self._message_commands[command.name] = command
else:
raise TypeError(f"Unknown command type '{type(command).__name__}'.")

def include(self, command: CommandBase[ClientT, BuilderT]) -> CommandBase[ClientT, BuilderT]:
"""Include a command in this plugin.
Parameters
----------
command : arc.CommandBase[ClientT, BuilderT]
The command to include in this plugin.
Raises
------
RuntimeError
If the command is already included in this plugin.
"""
if command.plugin is not None:
raise ValueError(f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'.")

command._plugin_include_hook(self)
return command

async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None:
try:
if self.error_handler is not None:
await self.error_handler(ctx, exc)
else:
raise exc
except Exception as exc:
await self.client._on_error(ctx, exc)

def include_slash_group(
self,
name: str,
description: str = "No description provided.",
*,
guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED,
autodefer: bool | AutodeferMode = True,
is_dm_enabled: bool = True,
default_permissions: hikari.UndefinedOr[hikari.Permissions] = hikari.UNDEFINED,
name_localizations: dict[hikari.Locale, str] | None = None,
description_localizations: dict[hikari.Locale, str] | None = None,
is_nsfw: bool = False,
) -> SlashGroup[ClientT]:
"""Add a new slash command group to this client.
Parameters
----------
name : str
The name of the slash command group.
description : str
The description of the slash command group.
guilds : hikari.UndefinedOr[t.Sequence[hikari.Snowflake]], optional
The guilds to register the slash command group in, by default hikari.UNDEFINED
autodefer : bool | AutodeferMode, optional
If True, all commands in this group will automatically defer if it is taking longer than 2 seconds to respond.
This can be overridden on a per-subcommand basis.
is_dm_enabled : bool, optional
Whether the slash command group is enabled in DMs, by default True
default_permissions : hikari.UndefinedOr[hikari.Permissions], optional
The default permissions for the slash command group, by default hikari.UNDEFINED
name_localizations : dict[hikari.Locale, str], optional
The name of the slash command group in different locales, by default None
description_localizations : dict[hikari.Locale, str], optional
The description of the slash command group in different locales, by default None
is_nsfw : bool, optional
Whether the slash command group is only usable in NSFW channels, by default False
Returns
-------
SlashGroup[te.Self]
The slash command group that was created.
Usage
-----
```py
group = client.include_slash_group("Group", "A group of commands.")
@group.include
@arc.slash_subcommand(name="Command", description="A command.")
async def cmd(ctx: arc.GatewayContext) -> None:
await ctx.respond("Hello!")
```
"""
children: dict[str, SlashSubCommand[ClientT] | SlashSubGroup[ClientT]] = {}

group = SlashGroup(
name=name,
description=description,
children=children,
guilds=guilds,
autodefer=AutodeferMode(autodefer),
is_dm_enabled=is_dm_enabled,
default_permissions=default_permissions,
name_localizations=name_localizations or {},
description_localizations=description_localizations or {},
is_nsfw=is_nsfw,
)
group._plugin_include_hook(self)
return group

def inject_dependencies(self, func: t.Callable[P, T]) -> t.Callable[P, T]:
"""First order decorator to inject dependencies into the decorated function.
!!! warning
This makes functions uncallable if the plugin is not added to a client.
!!! note
Command callbacks are automatically injected with dependencies,
thus this decorator is not needed for them.
Usage
-----
```py
class MyDependency:
def __init__(self, value: str):
self.value = value
client = arc.GatewayClient(...)
client.set_type_dependency(MyDependency, MyDependency("Hello!"))
client.load_extension("foo")
# In 'foo':
plugin = arc.GatewayPlugin("My Plugin")
@plugin.inject_dependencies
def my_func(dep: MyDependency = arc.inject()) -> None:
print(dep.value) # Prints "Hello!"
@arc.loader
def load(client: arc.GatewayClient) -> None:
client.add_plugin(plugin)
```
See Also
--------
- [`Client.set_type_dependency`][arc.client.Client.set_type_dependency]
A method to set dependencies for the client.
"""
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T:
if self._client is None:
raise RuntimeError(
f"Cannot inject dependencies into '{func.__name__}' before plugin '{self.name}' is included in a client."
)
return await self._client.injector.call_with_async_di(func, *args, **kwargs)

return decorator_async # pyright: ignore reportGeneralTypeIssues
else:

@functools.wraps(func)
def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
if self._client is None:
raise RuntimeError(
f"Cannot inject dependencies into '{func.__name__}' before plugin '{self.name}' is included in a client."
)
return self._client.injector.call_with_di(func, *args, **kwargs)

return decorator
4 changes: 2 additions & 2 deletions arc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

class GatewayClient(Client[hikari.GatewayBotAware]):
"""The default implementation for an arc client with `hikari.GatewayBotAware` support.
If you want to use a `hikari.RESTBotAware`, use `RESTClient` instead.
If you want to use a `hikari.RESTBotAware`, use [`RESTClient`][arc.client.RESTClient] instead.
Parameters
----------
Expand Down Expand Up @@ -119,7 +119,7 @@ def listen(self, *event_types: t.Type[EventT]) -> t.Callable[[EventCallbackT[Eve

class RESTClient(Client[hikari.RESTBotAware]):
"""The default implementation for an arc client with `hikari.RESTBotAware` support.
If you want to use `hikari.GatewayBotAware`, use `GatewayClient` instead.
If you want to use `hikari.GatewayBotAware`, use [`GatewayClient`][arc.client.GatewayClient] instead.
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion arc/command/slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from arc.abc.command import CallableCommandProto
from arc.abc.option import CommandOptionBase
from arc.plugin import PluginBase
from arc.abc.plugin import PluginBase

__all__ = (
"SlashCommandLike",
Expand Down
Loading

0 comments on commit 83bf83f

Please sign in to comment.