Skip to content

Commit

Permalink
chore: refactor to remove some code duplication; create component con…
Browse files Browse the repository at this point in the history
…tainer ABC
  • Loading branch information
tandemdude committed Aug 11, 2024
1 parent d027521 commit e1d014e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 149 deletions.
90 changes: 88 additions & 2 deletions lightbulb/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
# SOFTWARE.
from __future__ import annotations

__all__ = ["BaseComponent"]
__all__ = ["BaseComponent", "BuildableComponentContainer", "MessageResponseMixinWithEdit"]

import abc
import typing as t
from collections.abc import Sequence

import hikari
from hikari.api import special_endpoints
Expand All @@ -32,9 +33,10 @@
from lightbulb.internal import constants

if t.TYPE_CHECKING:
from collections.abc import Sequence
import typing_extensions as t_ex

RowT = t.TypeVar("RowT", special_endpoints.MessageActionRowBuilder, special_endpoints.ModalActionRowBuilder)
BaseComponentT = t.TypeVar("BaseComponentT", bound="BaseComponent[t.Any]")


class BaseComponent(abc.ABC, t.Generic[RowT]):
Expand Down Expand Up @@ -192,3 +194,87 @@ async def respond(
role_mentions=role_mentions,
)
).id


class BuildableComponentContainer(abc.ABC, Sequence[special_endpoints.ComponentBuilder], t.Generic[RowT]):
__slots__ = ("__current_row", "__rows")

__current_row: int
__rows: list[list[BaseComponent[RowT]]]

@property
def _current_row(self) -> int:
try:
return self.__current_row
except AttributeError:
self.__current_row = 0
return self.__current_row

@property
def _rows(self) -> list[list[BaseComponent[RowT]]]:
try:
return self.__rows
except AttributeError:
self.__rows = [[] for _ in range(self._max_rows)]
return self.__rows

@t.overload
def __getitem__(self, item: int) -> special_endpoints.ComponentBuilder: ...

@t.overload
def __getitem__(self, item: slice) -> Sequence[special_endpoints.ComponentBuilder]: ...

def __getitem__(
self, item: int | slice
) -> special_endpoints.ComponentBuilder | Sequence[special_endpoints.ComponentBuilder]:
return self._build().__getitem__(item)

def __len__(self) -> int:
return sum(1 for row in self._rows if row)

def _build(self) -> Sequence[special_endpoints.ComponentBuilder]:
built_rows: list[special_endpoints.ComponentBuilder] = []
for row in self._rows:
if not row:
continue

bld = self._make_action_row()
for component in row:
bld = component.add_to_row(bld)
built_rows.append(bld)
return built_rows

def clear_rows(self) -> t_ex.Self:
self._rows.clear()
return self

def clear_current_row(self) -> t_ex.Self:
self._rows[self._current_row].clear()
return self

def next_row(self) -> t_ex.Self:
if self._current_row + 1 >= self._max_rows:
raise RuntimeError("the maximum number of rows has been reached")
self.__current_row += 1
return self

def previous_row(self) -> t_ex.Self:
self.__current_row = max(0, self.__current_row - 1)
return self

def add(self, component: BaseComponentT) -> BaseComponentT:
if self._current_row_full():
self.next_row()

self._rows[self._current_row].append(component)
return component

@property
@abc.abstractmethod
def _max_rows(self) -> int: ...

@abc.abstractmethod
def _make_action_row(self) -> RowT: ...

@abc.abstractmethod
def _current_row_full(self) -> bool: ...
76 changes: 5 additions & 71 deletions lightbulb/components/menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@
from collections.abc import Awaitable
from collections.abc import Callable

import typing_extensions as t_ex

from lightbulb import client as client_

ValidSelectOptions: t.TypeAlias = t.Union[Sequence["TextSelectOption"], Sequence[str], Sequence[tuple[str, str]]]
Expand Down Expand Up @@ -375,88 +373,24 @@ async def respond(
)


class Menu(Sequence[special_endpoints.ComponentBuilder]):
class Menu(base.BuildableComponentContainer[special_endpoints.MessageActionRowBuilder]):
__slots__ = ("__current_row", "__rows")

_MAX_ROWS: t.Final[int] = 5
_MAX_BUTTONS_PER_ROW: t.Final[int] = 5

__current_row: int
__rows: list[list[base.BaseComponent[special_endpoints.MessageActionRowBuilder]]]

@t.overload
def __getitem__(self, item: int) -> special_endpoints.ComponentBuilder: ...

@t.overload
def __getitem__(self, item: slice) -> Sequence[special_endpoints.ComponentBuilder]: ...

def __getitem__(
self, item: int | slice
) -> special_endpoints.ComponentBuilder | Sequence[special_endpoints.ComponentBuilder]:
return self._build().__getitem__(item)

def __len__(self) -> int:
return sum(1 for row in self._rows if row)

def _build(self) -> Sequence[special_endpoints.ComponentBuilder]:
built_rows: list[special_endpoints.ComponentBuilder] = []
for row in self._rows:
if not row:
continue

bld = special_endpoints_impl.MessageActionRowBuilder()
for component in row:
bld = component.add_to_row(bld)
built_rows.append(bld)
return built_rows

@property
def _current_row(self) -> int:
try:
return self.__current_row
except AttributeError:
self.__current_row = 0
return self.__current_row
def _max_rows(self) -> int:
return 5

@property
def _rows(self) -> list[list[base.BaseComponent[special_endpoints.MessageActionRowBuilder]]]:
try:
return self.__rows
except AttributeError:
self.__rows = [[] for _ in range(self._MAX_ROWS)]
return self.__rows
def _make_action_row(self) -> special_endpoints.MessageActionRowBuilder:
return special_endpoints_impl.MessageActionRowBuilder()

def _current_row_full(self) -> bool:
return bool(
len(self._rows[self._current_row]) >= self._MAX_BUTTONS_PER_ROW
or ((r := self._rows[self._current_row]) and isinstance(r[0], Select))
)

def clear_rows(self) -> t_ex.Self:
self._rows.clear()
return self

def clear_current_row(self) -> t_ex.Self:
self._rows[self._current_row].clear()
return self

def next_row(self) -> t_ex.Self:
if self._current_row + 1 >= self._MAX_ROWS:
raise RuntimeError("the maximum number of rows has been reached")
self.__current_row += 1
return self

def previous_row(self) -> t_ex.Self:
self.__current_row = max(0, self.__current_row - 1)
return self

def add(self, component: MessageComponentT) -> MessageComponentT:
if self._current_row_full():
self.next_row()

self._rows[self._current_row].append(component)
return component

def add_interactive_button(
self,
style: hikari.ButtonStyle,
Expand Down
87 changes: 11 additions & 76 deletions lightbulb/components/modals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# -*- coding: utf-8 -*-
#
# api_ref_gen::add_autodoc_option::inherited-members
#
# Copyright (c) 2023-present tandemdude
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -27,7 +30,6 @@
import dataclasses
import typing as t
import uuid
from collections.abc import Sequence

import async_timeout
import hikari
Expand All @@ -37,8 +39,6 @@
from lightbulb.components import base

if t.TYPE_CHECKING:
import typing_extensions as t_ex

from lightbulb import client as client_

ModalComponentT = t.TypeVar("ModalComponentT", bound=base.BaseComponent[special_endpoints.ModalActionRowBuilder])
Expand Down Expand Up @@ -101,87 +101,19 @@ def value_for(self, input: TextInput) -> str | None:
return None


class Modal(abc.ABC, Sequence[special_endpoints.ComponentBuilder]):
_MAX_ROWS: t.Final[int] = 5

__current_row: int
__rows: list[list[base.BaseComponent[special_endpoints.ModalActionRowBuilder]]]

@abc.abstractmethod
async def on_submit(self, ctx: ModalContext) -> None: ...

@t.overload
def __getitem__(self, item: int) -> special_endpoints.ComponentBuilder: ...

@t.overload
def __getitem__(self, item: slice) -> Sequence[special_endpoints.ComponentBuilder]: ...

def __getitem__(
self, item: int | slice
) -> special_endpoints.ComponentBuilder | Sequence[special_endpoints.ComponentBuilder]:
return self._build().__getitem__(item)

def __len__(self) -> int:
return sum(1 for row in self._rows if row)

def _build(self) -> Sequence[special_endpoints.ComponentBuilder]:
built_rows: list[special_endpoints.ComponentBuilder] = []
for row in self._rows:
if not row:
continue

bld = special_endpoints_impl.ModalActionRowBuilder()
for component in row:
bld = component.add_to_row(bld)
built_rows.append(bld)
return built_rows

class Modal(base.BuildableComponentContainer[special_endpoints.ModalActionRowBuilder], abc.ABC):
@property
def _current_row(self) -> int:
try:
return self.__current_row
except AttributeError:
self.__current_row = 0
return self.__current_row
def _max_rows(self) -> int:
return 5

@property
def _rows(self) -> list[list[base.BaseComponent[special_endpoints.ModalActionRowBuilder]]]:
try:
return self.__rows
except AttributeError:
self.__rows = [[] for _ in range(self._MAX_ROWS)]
return self.__rows
def _make_action_row(self) -> special_endpoints.ModalActionRowBuilder:
return special_endpoints_impl.ModalActionRowBuilder()

def _current_row_full(self) -> bool:
# Currently, you are only allowed a single component within each row
# Maybe Discord will change this in the future
return bool(self._rows[self._current_row])

def clear_rows(self) -> t_ex.Self:
self._rows.clear()
return self

def clear_current_row(self) -> t_ex.Self:
self._rows[self._current_row].clear()
return self

def next_row(self) -> t_ex.Self:
if self._current_row + 1 >= self._MAX_ROWS:
raise RuntimeError("the maximum number of rows has been reached")
self.__current_row += 1
return self

def previous_row(self) -> t_ex.Self:
self.__current_row = max(0, self.__current_row - 1)
return self

def add(self, component: ModalComponentT) -> ModalComponentT:
if self._current_row_full():
self.next_row()

self._rows[self._current_row].append(component)
return component

def add_short_text_input(
self,
label: str,
Expand Down Expand Up @@ -248,3 +180,6 @@ async def attach(self, client: client_.Client, custom_id: str, *, timeout: float
finally:
# Unregister queue from client
client._modal_queues.remove(queue)

@abc.abstractmethod
async def on_submit(self, ctx: ModalContext) -> None: ...

0 comments on commit e1d014e

Please sign in to comment.