Skip to content

Commit

Permalink
Initial command syncing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Mar 9, 2024
1 parent 5fa66e3 commit 64e5b44
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 33 deletions.
25 changes: 2 additions & 23 deletions lightbulb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from lightbulb.commands import groups
from lightbulb.internal import constants
from lightbulb.internal import di as di_
from lightbulb.internal import sync
from lightbulb.internal import utils

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -212,29 +213,7 @@ async def sync_application_commands(self) -> None:
Returns:
:obj:`None`
"""
# TODO - implement syncing logic - for now just do create
LOGGER.info("syncing commands with discord")
application = await self._ensure_application()

for guild_id, guild_commands in self._commands.items():
if guild_id == constants.GLOBAL_COMMAND_KEY:
LOGGER.debug("processing global commands")
# TODO - Do global command syncing
continue

LOGGER.debug("processing guild - %s", guild_id)

builders: t.List[hikari.api.CommandBuilder] = []
for cmds in guild_commands.values():
builders.extend(
c.as_command_builder(self.default_locale, self.localization_provider)
for c in [cmds.slash, cmds.user, cmds.message]
if c is not None
)

await self.rest.set_application_commands(application, builders, guild_id)

LOGGER.info("finished syncing commands with discord")
await sync.sync_application_commands(self)

@staticmethod
def _get_subcommand(
Expand Down
17 changes: 7 additions & 10 deletions lightbulb/commands/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from lightbulb import utils
from lightbulb.commands import utils as cmd_utils
from lightbulb.internal.utils import non_undefined_or

if t.TYPE_CHECKING:
from lightbulb import commands
Expand All @@ -55,10 +56,6 @@
CtxMenuOptionReturnT = t.Union[hikari.User, hikari.Message]


def _non_undefined_or(item: hikari.UndefinedOr[T], default: D) -> t.Union[T, D]:
return item if item is not hikari.UNDEFINED else default


@dataclasses.dataclass(slots=True)
class OptionData(t.Generic[D]):
"""
Expand Down Expand Up @@ -139,12 +136,12 @@ def to_command_option(
description=description,
description_localizations=description_localizations, # type: ignore[reportArgumentType]
is_required=self.default is not hikari.UNDEFINED,
choices=_non_undefined_or(self.choices, None),
channel_types=_non_undefined_or(self.channel_types, None),
min_value=_non_undefined_or(self.min_value, None),
max_value=_non_undefined_or(self.max_value, None),
min_length=_non_undefined_or(self.min_length, None),
max_length=_non_undefined_or(self.max_length, None),
choices=non_undefined_or(self.choices, None),
channel_types=non_undefined_or(self.channel_types, None),
min_value=non_undefined_or(self.min_value, None),
max_value=non_undefined_or(self.max_value, None),
min_length=non_undefined_or(self.min_length, None),
max_length=non_undefined_or(self.max_length, None),
autocomplete=self.autocomplete,
)

Expand Down
188 changes: 188 additions & 0 deletions lightbulb/internal/sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# -*- coding: utf-8 -*-
# Copyright © tandemdude 2023-present
#
# This file is part of Lightbulb.
#
# Lightbulb is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Lightbulb is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Lightbulb. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = []

import collections
import dataclasses
import logging
import typing as t

import hikari

from lightbulb.internal import constants
from lightbulb.internal.utils import non_undefined_or

if t.TYPE_CHECKING:
from lightbulb import client as client_

LOGGER = logging.getLogger("lightbulb.internal.sync")


@dataclasses.dataclass(slots=True)
class CommandBuilderCollection:
slash: t.Optional[hikari.api.SlashCommandBuilder] = None
user: t.Optional[hikari.api.ContextMenuCommandBuilder] = None
message: t.Optional[hikari.api.ContextMenuCommandBuilder] = None

def put(self, bld: hikari.api.CommandBuilder) -> None:
if isinstance(bld, hikari.api.SlashCommandBuilder):
self.slash = bld
elif isinstance(bld, hikari.api.ContextMenuCommandBuilder):
if bld.type is hikari.CommandType.USER:
self.user = bld
else:
self.message = bld
else:
raise TypeError("unrecognised builder type")


def hikari_command_to_builder(
cmd: hikari.PartialCommand,
) -> t.Union[hikari.api.SlashCommandBuilder, hikari.api.ContextMenuCommandBuilder]:
bld: t.Union[hikari.api.SlashCommandBuilder, hikari.api.ContextMenuCommandBuilder]
if desc := getattr(cmd, "description", None):
bld = hikari.impl.SlashCommandBuilder(cmd.name, description=desc)
for option in getattr(cmd, "options", []) or []:
bld.add_option(option)
else:
bld = hikari.impl.ContextMenuCommandBuilder(type=cmd.type, name=cmd.name)

return (
bld.set_default_member_permissions(cmd.default_member_permissions)
.set_is_dm_enabled(cmd.is_dm_enabled)
.set_is_nsfw(cmd.is_nsfw)
.set_name_localizations(cmd.name_localizations)
.set_id(cmd.id)
)


async def get_existing_and_registered_commands(
client: client_.Client, application: hikari.PartialApplication, guild: hikari.UndefinedOr[hikari.Snowflakeish]
) -> t.Tuple[t.Dict[str, CommandBuilderCollection], t.Dict[str, CommandBuilderCollection]]:
existing: t.Dict[str, CommandBuilderCollection] = collections.defaultdict(CommandBuilderCollection)
registered: t.Dict[str, CommandBuilderCollection] = collections.defaultdict(CommandBuilderCollection)

for existing_command in await client.rest.fetch_application_commands(application, guild=guild):
existing[existing_command.name].put(hikari_command_to_builder(existing_command))
for name, collection in client._commands.get(
constants.GLOBAL_COMMAND_KEY if guild is hikari.UNDEFINED else guild, {}
).items():
for item in [collection.slash, collection.user, collection.message]:
if item is None:
continue

registered[name].put(item.as_command_builder(client.default_locale, client.localization_provider))

return existing, registered


def serialize_builder(bld: hikari.api.CommandBuilder) -> t.Dict[str, t.Any]:
def serialize_option(opt: hikari.CommandOption) -> t.Dict[str, t.Any]:
return {
"type": opt.type,
"name": opt.name,
"description": opt.description,
"is_required": opt.is_required,
"choices": opt.choices or [],
"options": [serialize_option(o) for o in (opt.options or [])],
"channel_types": opt.channel_types or [],
"autocomplete": opt.autocomplete,
"min_value": opt.min_value,
"max_value": opt.max_value,
"name_localizations": opt.name_localizations,
"description_localizations": opt.description_localizations,
"min_length": opt.min_length,
"max_length": opt.max_length,
}

out: t.Dict[str, t.Any] = {
"name": bld.name,
"is_dm_enabled": non_undefined_or(bld.is_dm_enabled, True),
"is_nsfw": non_undefined_or(bld.is_nsfw, False),
"name_localizations": bld.name_localizations,
}

if isinstance(bld, hikari.api.SlashCommandBuilder):
out["description"] = bld.description
out["description_localizations"] = bld.description_localizations
out["options"] = [serialize_option(opt) for opt in bld.options]

return out


def get_commands_to_set(
existing: t.Dict[str, CommandBuilderCollection], registered: t.Dict[str, CommandBuilderCollection]
) -> t.Optional[t.Sequence[hikari.api.CommandBuilder]]:
created, deleted, updated, unchanged = 0, 0, 0, 0

commands_to_set: t.List[hikari.api.CommandBuilder] = []
for name in {*existing.keys(), *registered.keys()}:
existing_cmds, registered_cmds = existing[name], registered[name]
for existing_bld, registered_bld in zip(
[existing_cmds.slash, existing_cmds.user, existing_cmds.message],
[registered_cmds.slash, registered_cmds.user, registered_cmds.message],
):
if existing_bld is None and registered_bld is None:
continue

if existing_bld is None:
assert registered_bld is not None

commands_to_set.append(registered_bld)
created += 1
elif registered_bld is None:
# TODO - Check if user wants to remove unknown commands
commands_to_set.append(existing_bld)
else:
if serialize_builder(existing_bld) != serialize_builder(registered_bld):
commands_to_set.append(registered_bld)
updated += 1
else:
commands_to_set.append(existing_bld)
unchanged += 1

LOGGER.debug("created: %s, deleted: %s, updated: %s, unchanged: %s", created, deleted, updated, unchanged)
return commands_to_set if any([created, deleted, updated]) else None


async def sync_application_commands(client: client_.Client) -> None:
application = await client._ensure_application()

LOGGER.info("syncing global commands")
existing_global_commands, registered_global_commands = await get_existing_and_registered_commands(
client, application, hikari.UNDEFINED
)
global_commands_to_set = get_commands_to_set(existing_global_commands, registered_global_commands)
if global_commands_to_set is not None:
await client.rest.set_application_commands(application, global_commands_to_set)
LOGGER.info("finished syncing global commands")

for guild in client._commands:
if guild == constants.GLOBAL_COMMAND_KEY:
continue

LOGGER.info("syncing commands for guild '%s'", guild)
existing_guild_commands, registered_guild_commands = await get_existing_and_registered_commands(
client, application, guild
)
guild_commands_to_set = get_commands_to_set(existing_guild_commands, registered_guild_commands)
if guild_commands_to_set is not None:
await client.rest.set_application_commands(application, guild_commands_to_set)
LOGGER.info("finished syncing commands for guild '%s'", guild)
9 changes: 9 additions & 0 deletions lightbulb/internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@
import dataclasses
import typing as t

import hikari

from lightbulb.commands import commands
from lightbulb.commands import groups

T = t.TypeVar("T")
D = t.TypeVar("D")


@dataclasses.dataclass(slots=True)
class CommandCollection:
Expand All @@ -42,3 +47,7 @@ def put(
self.message = command
else:
raise TypeError("unsupported command passed")


def non_undefined_or(item: hikari.UndefinedOr[T], default: D) -> t.Union[T, D]:
return item if item is not hikari.UNDEFINED else default

0 comments on commit 64e5b44

Please sign in to comment.