Skip to content

Commit

Permalink
Add support for Range[LargeInt, ...]
Browse files Browse the repository at this point in the history
  • Loading branch information
Enegg committed Jun 2, 2024
1 parent e701d61 commit 9e67086
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 50 deletions.
4 changes: 2 additions & 2 deletions disnake/app_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ def __init__(
self.required: bool = required
self.options: List[Option] = options or []

if min_value and self.type is OptionType.integer:
if min_value is not None and self.type is OptionType.integer:
min_value = math.ceil(min_value)
if max_value and self.type is OptionType.integer:
if max_value is not None and self.type is OptionType.integer:
max_value = math.floor(max_value)

self.min_value: Optional[float] = min_value
Expand Down
31 changes: 31 additions & 0 deletions disnake/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,37 @@ def __init__(self, argument: str) -> None:
super().__init__(f"{argument} is not able to be converted to an integer")


class LargeIntOutOfRange(BadArgument):
"""Exception raised when an argument to a large integer option exceeds given range.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.11
Attributes
----------
argument: :class:`str`
The argument that exceeded the defined range.
min_value: Optional[Union[:class:`int`, :class:`float`]]
The minimum allowed value.
max_value: Optional[Union[:class:`int`, :class:`float`]]
The maximum allowed value.
"""

def __init__(
self,
argument: str,
min_value: Union[int, float, None],
max_value: Union[int, float, None],
) -> None:
self.argument: str = argument
self.min_value: Union[int, float, None] = min_value
self.max_value: Union[int, float, None] = max_value
a = "..." if min_value is None else min_value
b = "..." if max_value is None else max_value
super().__init__(f"{argument} is not in range [{a}, {b}]")


class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled.
Expand Down
136 changes: 93 additions & 43 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[Type


def remove_optionals(annotation: Any) -> Any:
"""Remove unwanted optionals from an annotation"""
"""Remove unwanted optionals from an annotation."""
if get_origin(annotation) in (Union, UnionType):
args = tuple(i for i in annotation.__args__ if i not in (None, type(None)))
if len(args) == 1:
Expand Down Expand Up @@ -163,6 +163,29 @@ def _xt_to_xe(xe: Optional[float], xt: Optional[float], direction: float = 1) ->
return None


def _int_to_str_len(number: int) -> int:
"""Returns `len(str(number))`, i.e. character count of base 10 signed repr of `number`."""
# Desmos equivalent: floor(log(max(abs(x), 1))) + 1 + max(-sign(x), 0)
return (
int(math.log10(abs(number) or 1))
# 0 -> 0, 1 -> 0, 9 -> 0, 10 -> 1
+ 1
+ (number < 0)
)


def _range_to_str_len(min_value: int, max_value: int) -> Tuple[int, int]:
min_ = _int_to_str_len(min_value)
max_ = _int_to_str_len(max_value)
opposite_sign = (min_value < 0) ^ (max_value < 0)
# both bounds positive: len(str(min_value)) <= len(str(max_value))
# smaller bound negative: the range includes 0, which sets the minimum length to 1
# both bounds negative: len(str(min_value)) >= len(str(max_value))
if opposite_sign:
return 1, max(min_, max_)
return min(min_, max_), max(min_, max_)


class Injection(Generic[P, T_]):
"""Represents a slash command injection.
Expand Down Expand Up @@ -262,17 +285,24 @@ def decorator(func: CallableT) -> CallableT:
return decorator


NumberT = TypeVar("NumberT", bound=Union[int, float])


@dataclass(frozen=True)
class _BaseRange(ABC):
class _BaseRange(ABC, Generic[NumberT]):
"""Internal base type for supporting ``Range[...]`` and ``String[...]``."""

_allowed_types: ClassVar[Tuple[Type[Any], ...]]
_allowed_types: ClassVar[Tuple[type, ...]]

underlying_type: Type[Any]
min_value: Optional[Union[int, float]]
max_value: Optional[Union[int, float]]
underlying_type: type
min_value: Optional[NumberT]
max_value: Optional[NumberT]

def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self:
if cls is _BaseRange:
# needed since made generic
return super().__class_getitem__(params) # pyright: ignore[reportAttributeAccessIssue]

# deconstruct type arguments
if not isinstance(params, tuple):
params = (params,)
Expand All @@ -290,13 +320,12 @@ def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self:
f"Use `{name}[<type>, <min>, <max>]` instead.",
stacklevel=2,
)

# infer type from min/max values
params = (cls._infer_type(params),) + params

if len(params) != 3:
raise TypeError(
f"`{name}` expects 3 type arguments ({name}[<type>, <min>, <max>]), got {len(params)}"
f"`{name}` expects 3 arguments ({name}[<type>, <min>, <max>]), got {len(params)}"
)

underlying_type, min_value, max_value = params
Expand All @@ -305,7 +334,7 @@ def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self:
if not isinstance(underlying_type, type):
raise TypeError(f"First `{name}` argument must be a type, not `{underlying_type!r}`")

if not issubclass(underlying_type, cls._allowed_types):
if not issubclass_(underlying_type, cls._allowed_types):
allowed = "/".join(t.__name__ for t in cls._allowed_types)
raise TypeError(f"First `{name}` argument must be {allowed}, not `{underlying_type!r}`")

Expand All @@ -326,7 +355,7 @@ def __class_getitem__(cls, params: Tuple[Any, ...]) -> Self:
return cls(underlying_type=underlying_type, min_value=min_value, max_value=max_value)

@staticmethod
def _coerce_bound(value: Any, name: str) -> Optional[Union[int, float]]:
def _coerce_bound(value: Union[NumberT, None, EllipsisType], name: str) -> Optional[NumberT]:
if value is None or isinstance(value, EllipsisType):
return None
elif isinstance(value, (int, float)):
Expand All @@ -341,9 +370,9 @@ def __repr__(self) -> str:
b = "..." if self.max_value is None else self.max_value
return f"{type(self).__name__}[{self.underlying_type.__name__}, {a}, {b}]"

@classmethod
@staticmethod
@abstractmethod
def _infer_type(cls, params: Tuple[Any, ...]) -> Type[Any]:
def _infer_type(params: Tuple[Any, ...]) -> Type[Any]:
raise NotImplementedError

# hack to get `typing._type_check` to pass, e.g. when using `Range` as a generic parameter
Expand All @@ -353,8 +382,8 @@ def __call__(self) -> NoReturn:
# support new union syntax for `Range[int, 1, 2] | None`
if sys.version_info >= (3, 10):

def __or__(self, other):
return Union[self, other] # type: ignore
def __or__(self, other: type) -> UnionType:
return Union[self, other] # pyright: ignore


if TYPE_CHECKING:
Expand All @@ -363,7 +392,7 @@ def __or__(self, other):
else:

@dataclass(frozen=True, repr=False)
class Range(_BaseRange):
class Range(_BaseRange[Union[int, float]]):
"""Type representing a number with a limited range of allowed values.
See :ref:`param_ranges` for more information.
Expand All @@ -377,22 +406,30 @@ class Range(_BaseRange):

_allowed_types = (int, float)

def __post_init__(self):
def __post_init__(self) -> None:
for value in (self.min_value, self.max_value):
if value is None:
continue

if self.underlying_type is int and not isinstance(value, int):
if self.underlying_type is not float and not isinstance(value, int):
raise TypeError("Range[int, ...] bounds must be int, not float")

@classmethod
def _infer_type(cls, params: Tuple[Any, ...]) -> Type[Any]:
if self.underlying_type is int and abs(value) >= 2**53:
raise ValueError(
"Discord has upper input limit on integer input type of +/-2**53.\n"
" For larger values, use Range[commands.LargeInt, ...], which will use"
" string input type with length limited to the minimum and maximum string"
" representations of the range bounds."
)

@staticmethod
def _infer_type(params: Tuple[Any, ...]) -> Type[Any]:
if any(isinstance(p, float) for p in params):
return float
return int

@dataclass(frozen=True, repr=False)
class String(_BaseRange):
class String(_BaseRange[int]):
"""Type representing a string option with a limited length.
See :ref:`string_lengths` for more information.
Expand All @@ -406,7 +443,7 @@ class String(_BaseRange):

_allowed_types = (str,)

def __post_init__(self):
def __post_init__(self) -> None:
for value in (self.min_value, self.max_value):
if value is None:
continue
Expand All @@ -416,13 +453,13 @@ def __post_init__(self):
if value < 0:
raise ValueError("String bounds may not be negative")

@classmethod
def _infer_type(cls, params: Tuple[Any, ...]) -> Type[Any]:
@staticmethod
def _infer_type(params: Tuple[Any, ...]) -> Type[Any]:
return str


class LargeInt(int):
"""Type for large integers in slash commands."""
"""Type representing integers `=<-2**53`, `>=2**53` in slash commands."""


# option types that require additional handling in verify_type
Expand Down Expand Up @@ -478,25 +515,23 @@ class ParamInfo:
"""

TYPES: ClassVar[Dict[type, int]] = {
# fmt: off
str: OptionType.string.value,
int: OptionType.integer.value,
bool: OptionType.boolean.value,
float: OptionType.number.value,
disnake.abc.User: OptionType.user.value,
disnake.User: OptionType.user.value,
disnake.Member: OptionType.user.value,
Union[disnake.User, disnake.Member]: OptionType.user.value,
# channels handled separately
disnake.abc.GuildChannel: OptionType.channel.value,
disnake.Role: OptionType.role.value,
disnake.abc.Snowflake: OptionType.mentionable.value,
Union[disnake.Member, disnake.Role]: OptionType.mentionable.value,
Union[disnake.User, disnake.Role]: OptionType.mentionable.value,
Union[disnake.User, disnake.Member, disnake.Role]: OptionType.mentionable.value,
float: OptionType.number.value,
disnake.Attachment: OptionType.attachment.value,
# fmt: on
}
# channels handled separately
disnake.abc.GuildChannel: OptionType.channel.value,
} # fmt: skip
_registered_converters: ClassVar[Dict[type, Callable]] = {}

def __init__(
Expand All @@ -511,10 +546,10 @@ def __init__(
choices: Optional[Choices] = None,
type: Optional[type] = None,
channel_types: Optional[List[ChannelType]] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Union[int, float, None] = None,
le: Union[int, float, None] = None,
gt: Union[int, float, None] = None,
ge: Union[int, float, None] = None,
large: bool = False,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
Expand All @@ -535,10 +570,10 @@ def __init__(
self.choices = choices or []
self.type = type or str
self.channel_types = channel_types or []
self.max_value = _xt_to_xe(le, lt, -1)
self.min_value = _xt_to_xe(ge, gt, 1)
self.min_length = min_length
self.max_length = max_length
self.min_value: Union[int, float, None] = _xt_to_xe(ge, gt, 1)
self.max_value: Union[int, float, None] = _xt_to_xe(le, lt, -1)
self.min_length: Optional[int] = min_length
self.max_length: Optional[int] = max_length
self.large = large

def copy(self) -> Self:
Expand Down Expand Up @@ -619,7 +654,7 @@ def __repr__(self) -> str:
return f"{type(self).__name__}({args})"

async def get_default(self, inter: ApplicationCommandInteraction) -> Any:
"""Gets the default for an interaction"""
"""Gets the default for an interaction."""
default = self.default
if callable(self.default):
default = self.default(inter)
Expand Down Expand Up @@ -651,13 +686,19 @@ async def verify_type(self, inter: ApplicationCommandInteraction, argument: Any)
return argument

async def convert_argument(self, inter: ApplicationCommandInteraction, argument: Any) -> Any:
"""Convert a value if a converter is given"""
"""Convert a value if a converter is given."""
if self.large:
try:
argument = int(argument)
except ValueError:
raise errors.LargeIntConversionFailure(argument) from None

min_value = -math.inf if self.min_value is None else self.min_value
max_value = math.inf if self.max_value is None else self.max_value

if not min_value <= argument <= max_value:
raise errors.LargeIntOutOfRange(argument, self.min_value, self.max_value) from None

if self.converter is None:
# TODO: Custom validators
return await self.verify_type(inter, argument)
Expand Down Expand Up @@ -717,10 +758,12 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo
self.min_value = annotation.min_value
self.max_value = annotation.max_value
annotation = annotation.underlying_type
if isinstance(annotation, String):

elif isinstance(annotation, String):
self.min_length = annotation.min_value
self.max_length = annotation.max_value
annotation = annotation.underlying_type

if issubclass_(annotation, LargeInt):
self.large = True
annotation = int
Expand All @@ -729,6 +772,13 @@ def parse_annotation(self, annotation: Any, converter_mode: bool = False) -> boo
self.type = str
if annotation is not int:
raise TypeError("Large integers must be annotated with int or LargeInt")

# if either bound is None or ..., we cannot restrict the length
if self.min_value is not None and self.max_value is not None:
self.min_length, self.max_length = _range_to_str_len(
self.min_value, self.max_value # pyright: ignore
)

elif annotation in self.TYPES:
self.type = annotation
elif (
Expand Down Expand Up @@ -827,8 +877,8 @@ def to_option(self) -> Option:
choices=self.choices or None,
channel_types=self.channel_types,
autocomplete=self.autocomplete is not None,
min_value=self.min_value,
max_value=self.max_value,
min_value=None if self.large else self.min_value,
max_value=None if self.large else self.max_value,
min_length=self.min_length,
max_length=self.max_length,
)
Expand Down
Loading

0 comments on commit 9e67086

Please sign in to comment.