Skip to content

Commit

Permalink
Overhaul type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Mar 10, 2024
1 parent 360ae30 commit 37adc57
Show file tree
Hide file tree
Showing 13 changed files with 166 additions and 160 deletions.
68 changes: 41 additions & 27 deletions lightbulb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,20 @@

if t.TYPE_CHECKING:
from lightbulb.commands import options
from lightbulb.internal.types import MaybeAwaitable

T = t.TypeVar("T")
CommandOrGroupT = t.TypeVar("CommandOrGroupT", bound=t.Union[groups.Group, t.Type[commands.CommandBase]])
CommandOrGroup: t.TypeAlias = t.Union[groups.Group, type[commands.CommandBase]]
CommandOrGroupT = t.TypeVar("CommandOrGroupT", bound=CommandOrGroup)
CommandMapT = t.MutableMapping[hikari.Snowflakeish, t.MutableMapping[str, utils.CommandCollection]]
OptionT = t.TypeVar("OptionT", bound=hikari.CommandInteractionOption)

LOGGER = logging.getLogger("lightbulb.client")
DEFAULT_EXECUTION_STEP_ORDER = (
execution.ExecutionSteps.MAX_CONCURRENCY,
execution.ExecutionSteps.CHECKS,
execution.ExecutionSteps.COOLDOWNS,
)
LOGGER = logging.getLogger("lightbulb.client")


@t.runtime_checkable
Expand Down Expand Up @@ -82,6 +84,7 @@ class Client:
to support localizations.
delete_unknown_commands (:obj:`bool`): Whether to delete existing commands that the client does not have
an implementation for during command syncing.
deferred_registration_callback (:obj:`~typing.Optional` [ )
""" # noqa: E501

__slots__ = (
Expand All @@ -91,8 +94,10 @@ class Client:
"default_locale",
"localization_provider",
"delete_unknown_commands",
"deferred_registration_callback",
"_di",
"_localization",
"_deferred_commands",
"_commands",
"_application",
)
Expand All @@ -105,6 +110,10 @@ def __init__(
default_locale: hikari.Locale,
localization_provider: localization.LocalizationProviderT,
delete_unknown_commands: bool,
deferred_registration_callback: t.Callable[
[CommandOrGroup], MaybeAwaitable[t.Union[hikari.Snowflakeish, t.Sequence[hikari.Snowflakeish]]]
]
| None,
) -> None:
super().__init__()

Expand All @@ -114,9 +123,11 @@ def __init__(
self.default_locale = default_locale
self.localization_provider = localization_provider
self.delete_unknown_commands = delete_unknown_commands
self.deferred_registration_callback = deferred_registration_callback

self._di = di_.DependencyInjectionManager()

self._deferred_commands: list[CommandOrGroup] = []
self._commands: CommandMapT = collections.defaultdict(lambda: collections.defaultdict(utils.CommandCollection))
self._application: t.Optional[hikari.PartialApplication] = None

Expand All @@ -126,20 +137,20 @@ def di(self) -> di_.DependencyInjectionManager:

@t.overload
def register(
self, *, guilds: t.Optional[t.Sequence[hikari.Snowflakeish]] = None
self, *, guilds: t.Sequence[hikari.Snowflakeish] | None = None
) -> t.Callable[[CommandOrGroupT], CommandOrGroupT]: ...

@t.overload
def register(
self, command: CommandOrGroupT, *, guilds: t.Optional[t.Sequence[hikari.Snowflakeish]] = None
self, command: CommandOrGroupT, *, guilds: t.Sequence[hikari.Snowflakeish] | None = None
) -> CommandOrGroupT: ...

def register(
self,
command: t.Optional[CommandOrGroupT] = None,
command: CommandOrGroupT | None = None,
*,
guilds: t.Optional[t.Sequence[hikari.Snowflakeish]] = None,
) -> t.Union[CommandOrGroupT, t.Callable[[CommandOrGroupT], CommandOrGroupT]]:
guilds: t.Sequence[hikari.Snowflakeish] | None = None,
) -> CommandOrGroupT | t.Callable[[CommandOrGroupT], CommandOrGroupT]:
"""
Register a command or group with this client instance. Optionally, a sequence of guild ids can
be provided to make the commands created in specific guilds only - overriding the value for
Expand Down Expand Up @@ -204,6 +215,10 @@ def _inner(command_: CommandOrGroupT) -> CommandOrGroupT:

return _inner

def register_deferred(self, command: CommandOrGroupT) -> CommandOrGroupT:
self._deferred_commands.append(command)
return command

async def _ensure_application(self) -> hikari.PartialApplication:
if self._application is not None:
return self._application
Expand All @@ -223,7 +238,7 @@ async def sync_application_commands(self) -> None:
@staticmethod
def _get_subcommand(
options: t.Sequence[OptionT],
) -> t.Optional[OptionT]:
) -> OptionT | None:
subcommand = filter(
lambda o: o.type in (hikari.OptionType.SUB_COMMAND, hikari.OptionType.SUB_COMMAND_GROUP), options
)
Expand All @@ -232,24 +247,25 @@ def _get_subcommand(
@t.overload
def _resolve_options_and_command(
self, interaction: hikari.AutocompleteInteraction
) -> t.Optional[t.Tuple[t.Sequence[hikari.AutocompleteInteractionOption], t.Type[commands.CommandBase]]]: ...
) -> tuple[t.Sequence[hikari.AutocompleteInteractionOption], type[commands.CommandBase]] | None: ...

@t.overload
def _resolve_options_and_command(
self, interaction: hikari.CommandInteraction
) -> t.Optional[t.Tuple[t.Sequence[hikari.CommandInteractionOption], t.Type[commands.CommandBase]]]: ...
) -> tuple[t.Sequence[hikari.CommandInteractionOption], type[commands.CommandBase]] | None: ...

def _resolve_options_and_command(
self, interaction: t.Union[hikari.AutocompleteInteraction, hikari.CommandInteraction]
) -> t.Optional[
t.Tuple[
t.Union[t.Sequence[hikari.AutocompleteInteractionOption], t.Sequence[hikari.CommandInteractionOption]],
t.Type[commands.CommandBase],
self, interaction: hikari.AutocompleteInteraction | hikari.CommandInteraction
) -> (
tuple[
t.Sequence[hikari.AutocompleteInteractionOption] | t.Sequence[hikari.CommandInteractionOption],
type[commands.CommandBase],
]
]:
| None
):
command_path = [interaction.command_name]

subcommand: t.Union[hikari.CommandInteractionOption, hikari.AutocompleteInteractionOption, None]
subcommand: hikari.CommandInteractionOption | hikari.AutocompleteInteractionOption | None
options = interaction.options or [] # TODO - check if this is hikari bug with interaction server
while (subcommand := self._get_subcommand(options or [])) is not None:
command_path.append(subcommand.name)
Expand Down Expand Up @@ -287,7 +303,7 @@ def build_autocomplete_context(
self,
interaction: hikari.AutocompleteInteraction,
options: t.Sequence[hikari.AutocompleteInteractionOption],
command_cls: t.Type[commands.CommandBase],
command_cls: type[commands.CommandBase],
) -> context_.AutocompleteContext:
return context_.AutocompleteContext(self, interaction, options, command_cls)

Expand Down Expand Up @@ -331,7 +347,7 @@ def build_command_context(
self,
interaction: hikari.CommandInteraction,
options: t.Sequence[hikari.CommandInteractionOption],
command_cls: t.Type[commands.CommandBase],
command_cls: type[commands.CommandBase],
) -> context_.Context:
"""
Build a context object from the given parameters.
Expand Down Expand Up @@ -454,7 +470,7 @@ def build_rest_autocomplete_context(
self,
interaction: hikari.AutocompleteInteraction,
options: t.Sequence[hikari.AutocompleteInteractionOption],
command_cls: t.Type[commands.CommandBase],
command_cls: type[commands.CommandBase],
response_callback: t.Callable[[hikari.api.InteractionResponseBuilder], None],
) -> context_.AutocompleteContext:
return context_.RestAutocompleteContext(self, interaction, options, command_cls, response_callback)
Expand Down Expand Up @@ -509,19 +525,17 @@ def build_rest_command_context(
self,
interaction: hikari.CommandInteraction,
options: t.Sequence[hikari.CommandInteractionOption],
command_cls: t.Type[commands.CommandBase],
command_cls: type[commands.CommandBase],
response_callback: t.Callable[[hikari.api.InteractionResponseBuilder], None],
) -> context_.Context:
return context_.RestContext(self, interaction, options, command_cls(), response_callback)

async def handle_rest_application_command_interaction(
self, interaction: hikari.CommandInteraction
) -> t.AsyncGenerator[
t.Union[
hikari.api.InteractionDeferredBuilder,
hikari.api.InteractionMessageBuilder,
hikari.api.InteractionModalBuilder,
],
hikari.api.InteractionDeferredBuilder
| hikari.api.InteractionMessageBuilder
| hikari.api.InteractionModalBuilder,
t.Any,
]:
out = self._resolve_options_and_command(interaction)
Expand Down Expand Up @@ -560,7 +574,7 @@ def set_response(response: hikari.api.InteractionResponseBuilder) -> None:


def client_from_app(
app: t.Union[GatewayClientAppT, RestClientAppT],
app: GatewayClientAppT | RestClientAppT,
default_enabled_guilds: t.Sequence[hikari.Snowflakeish] = (constants.GLOBAL_COMMAND_KEY,),
execution_step_order: t.Sequence[execution.ExecutionStep] = DEFAULT_EXECUTION_STEP_ORDER,
default_locale: hikari.Locale = hikari.Locale.EN_US,
Expand Down
12 changes: 6 additions & 6 deletions lightbulb/commands/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class CommandData:
invoke_method: str
"""The attribute name of the invoke method for the command."""

parent: t.Optional[t.Union[groups.Group, groups.SubGroup]] = dataclasses.field(init=False, default=None)
parent: groups.Group | groups.SubGroup | None = dataclasses.field(init=False, default=None)
"""The group that the command belongs to, or :obj:`None` if not applicable."""

def __post_init__(self) -> None:
Expand Down Expand Up @@ -201,13 +201,13 @@ class CommandMeta(type):
run before the command invocation function is executed. Defaults to an empty set.
"""

__command_types: t.ClassVar[t.Dict[type, hikari.CommandType]] = {}
__command_types: t.ClassVar[dict[type, hikari.CommandType]] = {}

@staticmethod
def _is_option(item: t.Any) -> bool:
return isinstance(item, options_.Option)

def __new__(cls, cls_name: str, bases: t.Tuple[type, ...], attrs: t.Dict[str, t.Any], **kwargs: t.Any) -> type:
def __new__(cls, cls_name: str, bases: tuple[type, ...], attrs: dict[str, t.Any], **kwargs: t.Any) -> type:
cmd_type: hikari.CommandType
# Bodge because I cannot figure out how to avoid initialising all the kwargs in our
# own convenience classes any other way
Expand Down Expand Up @@ -245,7 +245,7 @@ def __new__(cls, cls_name: str, bases: t.Tuple[type, ...], attrs: t.Dict[str, t.
raise TypeError("all hooks must be an instance of ExecutionHook")

options: t.Dict[str, options_.OptionData[t.Any]] = {}
invoke_method: t.Optional[str] = None
invoke_method: str | None = None
# Iterate through new class attributes to find options and invoke method
for name, item in attrs.items():
if cls._is_option(item):
Expand Down Expand Up @@ -282,7 +282,7 @@ class CommandBase:
__slots__ = ("_current_context", "_resolved_option_cache")

_command_data: t.ClassVar[CommandData]
_current_context: t.Optional[context_.Context]
_current_context: context_.Context | None
_resolved_option_cache: t.MutableMapping[str, t.Any]

def __new__(cls, *args: t.Any, **kwargs: t.Any) -> CommandBase:
Expand All @@ -304,7 +304,7 @@ def _set_context(self, context: context_.Context) -> None:
self._current_context = context
self._resolved_option_cache = {}

def _resolve_option(self, option: options_.Option[T, D]) -> t.Union[T, D]:
def _resolve_option(self, option: options_.Option[T, D]) -> T | D:
"""
Resolves the actual value for the given option from the command's current
execution context. If the value has been resolved before and is available in the cache then
Expand Down
17 changes: 8 additions & 9 deletions lightbulb/commands/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@
from lightbulb import exceptions
from lightbulb.internal import constants
from lightbulb.internal import di
from lightbulb.internal.types import MaybeAwaitable

if t.TYPE_CHECKING:
from lightbulb import context as context_

__all__ = ["ExecutionStep", "ExecutionSteps", "ExecutionHook", "ExecutionPipeline", "hook", "invoke"]

ExecutionHookFuncT: t.TypeAlias = t.Callable[
["ExecutionPipeline", "context_.Context"], t.Union[t.Awaitable[None], None]
]
ExecutionHookFuncT: t.TypeAlias = t.Callable[["ExecutionPipeline", "context_.Context"], MaybeAwaitable[None]]


@dataclasses.dataclass(frozen=True, slots=True, eq=True)
Expand Down Expand Up @@ -110,14 +109,14 @@ def __init__(self, context: context_.Context, order: t.Sequence[ExecutionStep])
self._context = context
self._remaining = list(order)

self._hooks: t.Dict[ExecutionStep, t.List[ExecutionHook]] = collections.defaultdict(list)
self._hooks: dict[ExecutionStep, list[ExecutionHook]] = collections.defaultdict(list)
for hook in context.command_data.hooks:
self._hooks[hook.step].append(hook)

self._current_step: t.Optional[ExecutionStep] = None
self._current_hook: t.Optional[ExecutionHook] = None
self._current_step: ExecutionStep | None = None
self._current_hook: ExecutionHook | None = None

self._failure: t.Optional[exceptions.HookFailedException] = None
self._failure: exceptions.HookFailedException | None = None

@property
def failed(self) -> bool:
Expand All @@ -128,7 +127,7 @@ def failed(self) -> bool:
"""
return self._failure is not None

def _next_step(self) -> t.Optional[ExecutionStep]:
def _next_step(self) -> ExecutionStep | None:
"""
Return the next execution step to run, or :obj:`None` if the remaining execution steps
have been exhausted.
Expand Down Expand Up @@ -178,7 +177,7 @@ async def _run(self) -> None:
except Exception as e:
raise exceptions.InvocationFailedException(e, self._context)

def fail(self, exc: t.Union[str, Exception]) -> None:
def fail(self, exc: str | Exception) -> None:
"""
Notify the pipeline of a failure in an execution hook.
Expand Down
12 changes: 6 additions & 6 deletions lightbulb/commands/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,25 @@
from lightbulb import localization
from lightbulb.commands import commands

CommandT = t.TypeVar("CommandT", bound=t.Type["commands.CommandBase"])
SubGroupCommandMappingT = t.Dict[str, t.Type["commands.CommandBase"]]
GroupCommandMappingT = t.Dict[str, t.Union["SubGroup", t.Type["commands.CommandBase"]]]
CommandT = t.TypeVar("CommandT", bound=type["commands.CommandBase"])
SubGroupCommandMappingT = dict[str, type["commands.CommandBase"]]
GroupCommandMappingT = dict[str, t.Union["SubGroup", type["commands.CommandBase"]]]


class GroupMixin(abc.ABC):
"""Base class for application command groups."""

__slots__ = ()

_commands: t.Union[SubGroupCommandMappingT, GroupCommandMappingT]
_commands: SubGroupCommandMappingT | GroupCommandMappingT

@t.overload
def register(self) -> t.Callable[[CommandT], CommandT]: ...

@t.overload
def register(self, command: CommandT) -> CommandT: ...

def register(self, command: t.Optional[CommandT] = None) -> t.Union[CommandT, t.Callable[[CommandT], CommandT]]:
def register(self, command: CommandT | None = None) -> CommandT | t.Callable[[CommandT], CommandT]:
"""
Register a command as a subcommand for this group. Can be used as a first or second order decorator,
or called with the command to register.
Expand Down Expand Up @@ -91,7 +91,7 @@ def _inner(_command: CommandT) -> CommandT:

return _inner

def resolve_subcommand(self, path: t.List[str]) -> t.Optional[t.Type[commands.CommandBase]]:
def resolve_subcommand(self, path: list[str]) -> type[commands.CommandBase] | None:
"""
Resolve the subcommand for the given path - fully qualified command name.
Expand Down
Loading

0 comments on commit 37adc57

Please sign in to comment.