From e1d014ea4a89bdb0b7d1c7dfefe2ad81cc7d2461 Mon Sep 17 00:00:00 2001 From: tandemdude <43570299+tandemdude@users.noreply.github.com> Date: Sun, 11 Aug 2024 21:56:17 +0100 Subject: [PATCH] chore: refactor to remove some code duplication; create component container ABC --- lightbulb/components/base.py | 90 +++++++++++++++++++++++++++++++++- lightbulb/components/menus.py | 76 ++-------------------------- lightbulb/components/modals.py | 87 +++++--------------------------- 3 files changed, 104 insertions(+), 149 deletions(-) diff --git a/lightbulb/components/base.py b/lightbulb/components/base.py index 4c8e0e70..50ccadc3 100644 --- a/lightbulb/components/base.py +++ b/lightbulb/components/base.py @@ -20,10 +20,11 @@ # SOFTWARE. from __future__ import annotations -__all__ = ["BaseComponent"] +__all__ = ["BaseComponent", "BuildableComponentContainer", "MessageResponseMixinWithEdit"] import abc import typing as t +from collections.abc import Sequence import hikari from hikari.api import special_endpoints @@ -32,9 +33,10 @@ from lightbulb.internal import constants if t.TYPE_CHECKING: - from collections.abc import Sequence + import typing_extensions as t_ex RowT = t.TypeVar("RowT", special_endpoints.MessageActionRowBuilder, special_endpoints.ModalActionRowBuilder) +BaseComponentT = t.TypeVar("BaseComponentT", bound="BaseComponent[t.Any]") class BaseComponent(abc.ABC, t.Generic[RowT]): @@ -192,3 +194,87 @@ async def respond( role_mentions=role_mentions, ) ).id + + +class BuildableComponentContainer(abc.ABC, Sequence[special_endpoints.ComponentBuilder], t.Generic[RowT]): + __slots__ = ("__current_row", "__rows") + + __current_row: int + __rows: list[list[BaseComponent[RowT]]] + + @property + def _current_row(self) -> int: + try: + return self.__current_row + except AttributeError: + self.__current_row = 0 + return self.__current_row + + @property + def _rows(self) -> list[list[BaseComponent[RowT]]]: + try: + return self.__rows + except AttributeError: + self.__rows = [[] for _ in range(self._max_rows)] + return self.__rows + + @t.overload + def __getitem__(self, item: int) -> special_endpoints.ComponentBuilder: ... + + @t.overload + def __getitem__(self, item: slice) -> Sequence[special_endpoints.ComponentBuilder]: ... + + def __getitem__( + self, item: int | slice + ) -> special_endpoints.ComponentBuilder | Sequence[special_endpoints.ComponentBuilder]: + return self._build().__getitem__(item) + + def __len__(self) -> int: + return sum(1 for row in self._rows if row) + + def _build(self) -> Sequence[special_endpoints.ComponentBuilder]: + built_rows: list[special_endpoints.ComponentBuilder] = [] + for row in self._rows: + if not row: + continue + + bld = self._make_action_row() + for component in row: + bld = component.add_to_row(bld) + built_rows.append(bld) + return built_rows + + def clear_rows(self) -> t_ex.Self: + self._rows.clear() + return self + + def clear_current_row(self) -> t_ex.Self: + self._rows[self._current_row].clear() + return self + + def next_row(self) -> t_ex.Self: + if self._current_row + 1 >= self._max_rows: + raise RuntimeError("the maximum number of rows has been reached") + self.__current_row += 1 + return self + + def previous_row(self) -> t_ex.Self: + self.__current_row = max(0, self.__current_row - 1) + return self + + def add(self, component: BaseComponentT) -> BaseComponentT: + if self._current_row_full(): + self.next_row() + + self._rows[self._current_row].append(component) + return component + + @property + @abc.abstractmethod + def _max_rows(self) -> int: ... + + @abc.abstractmethod + def _make_action_row(self) -> RowT: ... + + @abc.abstractmethod + def _current_row_full(self) -> bool: ... diff --git a/lightbulb/components/menus.py b/lightbulb/components/menus.py index 30ae8e7f..c71b51df 100644 --- a/lightbulb/components/menus.py +++ b/lightbulb/components/menus.py @@ -54,8 +54,6 @@ from collections.abc import Awaitable from collections.abc import Callable - import typing_extensions as t_ex - from lightbulb import client as client_ ValidSelectOptions: t.TypeAlias = t.Union[Sequence["TextSelectOption"], Sequence[str], Sequence[tuple[str, str]]] @@ -375,56 +373,17 @@ async def respond( ) -class Menu(Sequence[special_endpoints.ComponentBuilder]): +class Menu(base.BuildableComponentContainer[special_endpoints.MessageActionRowBuilder]): __slots__ = ("__current_row", "__rows") - _MAX_ROWS: t.Final[int] = 5 _MAX_BUTTONS_PER_ROW: t.Final[int] = 5 - __current_row: int - __rows: list[list[base.BaseComponent[special_endpoints.MessageActionRowBuilder]]] - - @t.overload - def __getitem__(self, item: int) -> special_endpoints.ComponentBuilder: ... - - @t.overload - def __getitem__(self, item: slice) -> Sequence[special_endpoints.ComponentBuilder]: ... - - def __getitem__( - self, item: int | slice - ) -> special_endpoints.ComponentBuilder | Sequence[special_endpoints.ComponentBuilder]: - return self._build().__getitem__(item) - - def __len__(self) -> int: - return sum(1 for row in self._rows if row) - - def _build(self) -> Sequence[special_endpoints.ComponentBuilder]: - built_rows: list[special_endpoints.ComponentBuilder] = [] - for row in self._rows: - if not row: - continue - - bld = special_endpoints_impl.MessageActionRowBuilder() - for component in row: - bld = component.add_to_row(bld) - built_rows.append(bld) - return built_rows - @property - def _current_row(self) -> int: - try: - return self.__current_row - except AttributeError: - self.__current_row = 0 - return self.__current_row + def _max_rows(self) -> int: + return 5 - @property - def _rows(self) -> list[list[base.BaseComponent[special_endpoints.MessageActionRowBuilder]]]: - try: - return self.__rows - except AttributeError: - self.__rows = [[] for _ in range(self._MAX_ROWS)] - return self.__rows + def _make_action_row(self) -> special_endpoints.MessageActionRowBuilder: + return special_endpoints_impl.MessageActionRowBuilder() def _current_row_full(self) -> bool: return bool( @@ -432,31 +391,6 @@ def _current_row_full(self) -> bool: or ((r := self._rows[self._current_row]) and isinstance(r[0], Select)) ) - def clear_rows(self) -> t_ex.Self: - self._rows.clear() - return self - - def clear_current_row(self) -> t_ex.Self: - self._rows[self._current_row].clear() - return self - - def next_row(self) -> t_ex.Self: - if self._current_row + 1 >= self._MAX_ROWS: - raise RuntimeError("the maximum number of rows has been reached") - self.__current_row += 1 - return self - - def previous_row(self) -> t_ex.Self: - self.__current_row = max(0, self.__current_row - 1) - return self - - def add(self, component: MessageComponentT) -> MessageComponentT: - if self._current_row_full(): - self.next_row() - - self._rows[self._current_row].append(component) - return component - def add_interactive_button( self, style: hikari.ButtonStyle, diff --git a/lightbulb/components/modals.py b/lightbulb/components/modals.py index 3439d951..3f59c310 100644 --- a/lightbulb/components/modals.py +++ b/lightbulb/components/modals.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +# +# api_ref_gen::add_autodoc_option::inherited-members +# # Copyright (c) 2023-present tandemdude # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -27,7 +30,6 @@ import dataclasses import typing as t import uuid -from collections.abc import Sequence import async_timeout import hikari @@ -37,8 +39,6 @@ from lightbulb.components import base if t.TYPE_CHECKING: - import typing_extensions as t_ex - from lightbulb import client as client_ ModalComponentT = t.TypeVar("ModalComponentT", bound=base.BaseComponent[special_endpoints.ModalActionRowBuilder]) @@ -101,87 +101,19 @@ def value_for(self, input: TextInput) -> str | None: return None -class Modal(abc.ABC, Sequence[special_endpoints.ComponentBuilder]): - _MAX_ROWS: t.Final[int] = 5 - - __current_row: int - __rows: list[list[base.BaseComponent[special_endpoints.ModalActionRowBuilder]]] - - @abc.abstractmethod - async def on_submit(self, ctx: ModalContext) -> None: ... - - @t.overload - def __getitem__(self, item: int) -> special_endpoints.ComponentBuilder: ... - - @t.overload - def __getitem__(self, item: slice) -> Sequence[special_endpoints.ComponentBuilder]: ... - - def __getitem__( - self, item: int | slice - ) -> special_endpoints.ComponentBuilder | Sequence[special_endpoints.ComponentBuilder]: - return self._build().__getitem__(item) - - def __len__(self) -> int: - return sum(1 for row in self._rows if row) - - def _build(self) -> Sequence[special_endpoints.ComponentBuilder]: - built_rows: list[special_endpoints.ComponentBuilder] = [] - for row in self._rows: - if not row: - continue - - bld = special_endpoints_impl.ModalActionRowBuilder() - for component in row: - bld = component.add_to_row(bld) - built_rows.append(bld) - return built_rows - +class Modal(base.BuildableComponentContainer[special_endpoints.ModalActionRowBuilder], abc.ABC): @property - def _current_row(self) -> int: - try: - return self.__current_row - except AttributeError: - self.__current_row = 0 - return self.__current_row + def _max_rows(self) -> int: + return 5 - @property - def _rows(self) -> list[list[base.BaseComponent[special_endpoints.ModalActionRowBuilder]]]: - try: - return self.__rows - except AttributeError: - self.__rows = [[] for _ in range(self._MAX_ROWS)] - return self.__rows + def _make_action_row(self) -> special_endpoints.ModalActionRowBuilder: + return special_endpoints_impl.ModalActionRowBuilder() def _current_row_full(self) -> bool: # Currently, you are only allowed a single component within each row # Maybe Discord will change this in the future return bool(self._rows[self._current_row]) - def clear_rows(self) -> t_ex.Self: - self._rows.clear() - return self - - def clear_current_row(self) -> t_ex.Self: - self._rows[self._current_row].clear() - return self - - def next_row(self) -> t_ex.Self: - if self._current_row + 1 >= self._MAX_ROWS: - raise RuntimeError("the maximum number of rows has been reached") - self.__current_row += 1 - return self - - def previous_row(self) -> t_ex.Self: - self.__current_row = max(0, self.__current_row - 1) - return self - - def add(self, component: ModalComponentT) -> ModalComponentT: - if self._current_row_full(): - self.next_row() - - self._rows[self._current_row].append(component) - return component - def add_short_text_input( self, label: str, @@ -248,3 +180,6 @@ async def attach(self, client: client_.Client, custom_id: str, *, timeout: float finally: # Unregister queue from client client._modal_queues.remove(queue) + + @abc.abstractmethod + async def on_submit(self, ctx: ModalContext) -> None: ...