Skip to content

Commit

Permalink
Refactor sigparse
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Jan 19, 2024
1 parent b415a07 commit c82dc1a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 93 deletions.
115 changes: 44 additions & 71 deletions arc/internal/sigparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,21 @@
from arc.context import Context
from arc.internal.types import ClientT, EventT

# pyright: reportUnnecessaryTypeIgnoreComment=false

__all__ = ("parse_command_signature",)

# Potential pyright bug? This wasn't reported in 1.1.345
TYPE_TO_OPTION_MAPPING: dict[type[t.Any], type[CommandOptionBase[t.Any, t.Any, t.Any]]] = { # pyright: ignore reportGeneralTypeIssues
#
# This doesn't include some special cases, for the complete resolution logic see: _get_option_type()
TYPE_TO_OPTION_MAPPING: dict[type[t.Any], type[CommandOptionBase[t.Any, t.Any, t.Any]]] = {
bool: BoolOption,
int: IntOption,
str: StrOption,
float: FloatOption,
hikari.Role: RoleOption,
hikari.User | hikari.Role: MentionableOption,
t.Union[hikari.User, hikari.Role]: MentionableOption,
hikari.Attachment: AttachmentOption,
hikari.User: UserOption,
hikari.GuildTextChannel: ChannelOption,
hikari.GuildVoiceChannel: ChannelOption,
hikari.GuildCategory: ChannelOption,
hikari.GuildNewsChannel: ChannelOption,
hikari.GuildPrivateThread: ChannelOption,
hikari.GuildPublicThread: ChannelOption,
hikari.GuildNewsThread: ChannelOption,
hikari.GuildForumChannel: ChannelOption,
hikari.GuildThreadChannel: ChannelOption,
hikari.DMChannel: ChannelOption,
hikari.GroupDMChannel: ChannelOption,
hikari.GuildStageChannel: ChannelOption,
hikari.PartialChannel: ChannelOption,
hikari.InteractionChannel: ChannelOption,
hikari.TextableChannel: ChannelOption,
hikari.GuildChannel: ChannelOption,
hikari.PrivateChannel: ChannelOption,
hikari.PermissibleGuildChannel: ChannelOption,
hikari.TextableGuildChannel: ChannelOption,
}

OPT_TO_PARAMS_MAPPING: dict[type[CommandOptionBase[t.Any, t.Any, t.Any]], type[t.Any]] = {
Expand Down Expand Up @@ -110,7 +92,7 @@ def _get_channel_type(channel: type[hikari.PartialChannel]) -> set[hikari.Channe


def _get_all_channel_types() -> dict[type[hikari.PartialChannel], set[hikari.ChannelType]]:
"""Get all channel types."""
"""Get all channels and their corresponding channel types."""
mapping: dict[type[hikari.PartialChannel], set[hikari.ChannelType]] = {}

for _, attribute in inspect.getmembers(
Expand All @@ -125,34 +107,45 @@ def _get_all_channel_types() -> dict[type[hikari.PartialChannel], set[hikari.Cha
CHANNEL_TYPES_MAPPING = _get_all_channel_types()


def _get_option_type(hint: t.Any) -> type[CommandOptionBase[t.Any, t.Any, t.Any]] | None:
"""Get the option type from a type hint."""
if _is_mentionable_union(hint):
return MentionableOption # pyright: ignore reportGeneralTypeIssues

elif _is_union(hint):
hints = [arg for arg in t.get_args(hint) if arg is not type(None)]
first = _get_option_type(hints[0])
# Check if it is a uniform union recursively
if all(_get_option_type(arg) is first for arg in hints):
return first

elif hint in CHANNEL_TYPES_MAPPING:
return ChannelOption # pyright: ignore reportGeneralTypeIssues

else:
return TYPE_TO_OPTION_MAPPING.get(hint)


def _is_param(meta: t.Any) -> bool:
"""Return True if the metadata is a command option parameter object."""
return isinstance(meta, OptionParams)


def _is_union(hint: t.Any) -> bool:
"""Return True if the type hint is a typing.Union. or Python 3.10's types.UnionType."""
return t.get_origin(hint) is t.Union or t.get_origin(hint) is types.UnionType


def _is_optional_union(hint: t.Any) -> bool:
"""Return True if the type hint is a typing.Union[T, None], also known as typing.Optional[T]."""
return t.get_origin(hint) is t.Union and len(t.get_args(hint)) == 2 and type(None) in t.get_args(hint)


def _extract_optional_type(hint: t.Any) -> type[t.Any]:
"""Convert typing.Optional[T] to T."""
return next(arg for arg in t.get_args(hint) if arg is not type(None))


def _get_supported_types() -> list[str]:
"""Get a list of supported types.
Used in error messages.
Returns
-------
list[str]
The list of supported types
"""
return [type_.__name__ if type(type_) is type else repr(type_) for type_ in TYPE_TO_OPTION_MAPPING]


def _is_mentionable_union(hint: t.Any) -> bool:
"""Check if a type hint is a union that represents a MentionableOption.
Expand Down Expand Up @@ -226,7 +219,7 @@ def _parse_channel_union_type_hint(hint: t.Any) -> list[hikari.ChannelType]:
args = t.get_args(hint)

if not all((issubclass(arg, hikari.PartialChannel)) or arg is type(None) for arg in args):
raise TypeError(f"Union expressions are only supported for channels, not '{hint!r}'")
raise TypeError(f"Union of channels is not uniform: '{hint!r}'")

return _channels_to_channel_types(arg for arg in args if arg is not type(None))

Expand Down Expand Up @@ -294,55 +287,33 @@ def parse_command_signature( # noqa: C901
if not _is_param(params):
continue

# If it's a union, verify all types are supported
# If it's a union, update is_optional
if _is_union(type_):
union = type_
union_args = t.get_args(union)

if not _is_mentionable_union(union) and not all(
arg is type(None) or arg in TYPE_TO_OPTION_MAPPING for arg in union_args
):
raise TypeError(
f"Unsupported option type: '{union!r}'\nSupported option types: {_get_supported_types()}"
)
type_ = next((arg for arg in union_args if arg in TYPE_TO_OPTION_MAPPING))
is_optional = is_optional or type(None) in union_args

# Verify if it's a supported type
elif type_ not in TYPE_TO_OPTION_MAPPING:
raise TypeError(f"Unsupported option type: '{type_!r}'\nSupported option types: {_get_supported_types()}")

# Get the corresponding option type
if union is not None and _is_mentionable_union(union):
opt_type = MentionableOption
else:
opt_type = TYPE_TO_OPTION_MAPPING[type_]
is_optional = is_optional or type(None) in t.get_args(union)

opt_type = _get_option_type(type_)

# If the opt_type is None, it failed to resolve
if opt_type is None:
raise TypeError(f"Unsupported option type: '{type_!r}'")

# Verify the params type matches the option type
if not isinstance(params, OPT_TO_PARAMS_MAPPING[opt_type]):
raise TypeError(
f"Expected params object to be of type {OPT_TO_PARAMS_MAPPING[opt_type].__name__}, got '{type(params).__name__}'"
f"Expected params object to be of type '{OPT_TO_PARAMS_MAPPING[opt_type].__name__}', got '{type(params).__name__}'"
)

# If it's a union of channel types, we need to parse all channel types
if union is not None and type_ in CHANNEL_TYPES_MAPPING:
if union is not None and any(arg in CHANNEL_TYPES_MAPPING for arg in t.get_args(union)):
channel_types = _parse_channel_union_type_hint(union)
options[arg_name] = ChannelOption._from_params(
name=params.name or arg_name, is_required=not is_optional, params=params, channel_types=channel_types
)
continue

# Parse mentionable unions
if union is not None and {arg for arg in t.get_args(union) if arg is not type(None)} == {
hikari.User,
hikari.Role,
}:
options[arg_name] = MentionableOption._from_params(
name=params.name or arg_name, is_required=not is_optional, params=params
)
continue

# If it's a single channel type, just pass the channel type
if type_ in CHANNEL_TYPES_MAPPING:
elif type_ in CHANNEL_TYPES_MAPPING:
options[arg_name] = ChannelOption._from_params(
name=params.name or arg_name,
is_required=not is_optional,
Expand All @@ -351,8 +322,10 @@ def parse_command_signature( # noqa: C901
)
continue

print(type_)

# Otherwise just build the option
options[arg_name] = TYPE_TO_OPTION_MAPPING[type_]._from_params(
options[arg_name] = opt_type._from_params(
name=params.name or arg_name, is_required=not is_optional, params=params
)

Expand Down
23 changes: 1 addition & 22 deletions tests/test_sigparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
import pytest

import arc
from arc.internal.sigparse import (
BASE_CHANNEL_TYPE_MAP,
CHANNEL_TYPES_MAPPING,
OPT_TO_PARAMS_MAPPING,
TYPE_TO_OPTION_MAPPING,
parse_command_signature,
)
from arc.internal.sigparse import BASE_CHANNEL_TYPE_MAP, CHANNEL_TYPES_MAPPING, parse_command_signature


async def correct_command(
Expand Down Expand Up @@ -165,27 +159,12 @@ def test_ensure_parse_channel_types_has_every_channel_class() -> None:

assert result is not None, f"Missing channel type for {attribute} in CHANNEL_TYPES_MAPPING"

result = attribute in TYPE_TO_OPTION_MAPPING

assert result is True, f"Missing channel type for {attribute} in TYPE_TO_OPTION_MAPPING"


def test_ensure_base_channels_has_every_channel_type() -> None:
for channel_type in hikari.ChannelType:
assert channel_type in BASE_CHANNEL_TYPE_MAP.values()


def test_ensure_option_types_has_every_option() -> None:
for _, attribute in inspect.getmembers(
arc.command.option, lambda a: isinstance(a, type) and issubclass(a, arc.abc.option.CommandOptionBase)
):
assert (
attribute in TYPE_TO_OPTION_MAPPING.values()
), f"Missing option type for {attribute} in TYPE_TO_OPTION_MAPPING"

assert attribute in OPT_TO_PARAMS_MAPPING, f"Missing option type for {attribute} in OPT_TO_PARAMS_MAPPING"


# MIT License
#
# Copyright (c) 2023-present hypergonial
Expand Down

0 comments on commit c82dc1a

Please sign in to comment.