diff --git a/CHANGELOG.md b/CHANGELOG.md index 861f66ec0f..3c8d779160 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,8 @@ These changes are available on the `master` branch, but have not yet been releas - Replaced audioop (deprecated module) implementation of `PCMVolumeTransformer.read` method with a pure Python equivalent. ([#2176](https://github.com/Pycord-Development/pycord/pull/2176)) +- Updated `Guild.filesize_limit` to 10 Mb instead of 25 Mb following Discord's API + changes. ([#2671](https://github.com/Pycord-Development/pycord/pull/2671)) ### Deprecated diff --git a/discord/asset.py b/discord/asset.py index 07c7ca8e7b..8c7fdf64d5 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Literal import yarl +from typing_extensions import Final, override from . import utils from .errors import DiscordException, InvalidArgument @@ -39,6 +40,7 @@ if TYPE_CHECKING: ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"] ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] + from .state import ConnectionState VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} @@ -49,7 +51,7 @@ class AssetMixin: url: str - _state: Any | None + _state: ConnectionState | None async def read(self) -> bytes: """|coro| @@ -77,7 +79,9 @@ async def read(self) -> bytes: async def save( self, - fp: str | bytes | os.PathLike | io.BufferedIOBase, + fp: ( + str | bytes | os.PathLike | io.BufferedIOBase + ), # pyright: ignore [reportMissingTypeArgument] *, seek_begin: bool = True, ) -> int: @@ -117,7 +121,7 @@ async def save( fp.seek(0) return written else: - with open(fp, "wb") as f: + with open(fp, "wb") as f: # pyright: ignore [reportUnknownArgumentType] return f.write(data) @@ -154,16 +158,23 @@ class Asset(AssetMixin): "_key", ) - BASE = "https://cdn.discordapp.com" + BASE: Final = "https://cdn.discordapp.com" - def __init__(self, state, *, url: str, key: str, animated: bool = False): - self._state = state - self._url = url - self._animated = animated - self._key = key + def __init__( + self, + state: ConnectionState | None, + *, + url: str, + key: str, + animated: bool = False, + ): + self._state: ConnectionState | None = state + self._url: str = url + self._animated: bool = animated + self._key: str = key @classmethod - def _from_default_avatar(cls, state, index: int) -> Asset: + def _from_default_avatar(cls, state: ConnectionState, index: int) -> Asset: return cls( state, url=f"{cls.BASE}/embed/avatars/{index}.png", @@ -172,7 +183,7 @@ def _from_default_avatar(cls, state, index: int) -> Asset: ) @classmethod - def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: + def _from_avatar(cls, state: ConnectionState, user_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -184,7 +195,10 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: @classmethod def _from_avatar_decoration( - cls, state, user_id: int, avatar_decoration: str + cls, + state: ConnectionState, + user_id: int, + avatar_decoration: str, # pyright: ignore [reportUnusedParameter] ) -> Asset: animated = avatar_decoration.startswith("a_") endpoint = ( @@ -201,7 +215,7 @@ def _from_avatar_decoration( @classmethod def _from_guild_avatar( - cls, state, guild_id: int, member_id: int, avatar: str + cls, state: ConnectionState, guild_id: int, member_id: int, avatar: str ) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" @@ -214,7 +228,7 @@ def _from_guild_avatar( @classmethod def _from_guild_banner( - cls, state, guild_id: int, member_id: int, banner: str + cls, state: ConnectionState, guild_id: int, member_id: int, banner: str ) -> Asset: animated = banner.startswith("a_") format = "gif" if animated else "png" @@ -226,7 +240,9 @@ def _from_guild_banner( ) @classmethod - def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: + def _from_icon( + cls, state: ConnectionState, object_id: int, icon_hash: str, path: str + ) -> Asset: return cls( state, url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024", @@ -235,7 +251,9 @@ def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: ) @classmethod - def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: + def _from_cover_image( + cls, state: ConnectionState, object_id: int, cover_image_hash: str + ) -> Asset: return cls( state, url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024", @@ -244,7 +262,9 @@ def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asse ) @classmethod - def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: + def _from_guild_image( + cls, state: ConnectionState, guild_id: int, image: str, path: str + ) -> Asset: animated = False format = "png" if path == "banners": @@ -259,7 +279,9 @@ def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset ) @classmethod - def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: + def _from_guild_icon( + cls, state: ConnectionState, guild_id: int, icon_hash: str + ) -> Asset: animated = icon_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -270,7 +292,7 @@ def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: ) @classmethod - def _from_sticker_banner(cls, state, banner: int) -> Asset: + def _from_sticker_banner(cls, state: ConnectionState, banner: int) -> Asset: return cls( state, url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png", @@ -279,7 +301,9 @@ def _from_sticker_banner(cls, state, banner: int) -> Asset: ) @classmethod - def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: + def _from_user_banner( + cls, state: ConnectionState, user_id: int, banner_hash: str + ) -> Asset: animated = banner_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -291,7 +315,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: @classmethod def _from_scheduled_event_image( - cls, state, event_id: int, cover_hash: str + cls, state: ConnectionState, event_id: int, cover_hash: str ) -> Asset: return cls( state, @@ -300,24 +324,29 @@ def _from_scheduled_event_image( animated=False, ) + @override def __str__(self) -> str: return self._url def __len__(self) -> int: return len(self._url) + @override def __repr__(self): shorten = self._url.replace(self.BASE, "") return f"" - def __eq__(self, other): + @override + def __eq__(self, other: Any): # pyright: ignore [reportExplicitAny] return isinstance(other, Asset) and self._url == other._url + @override def __hash__(self): return hash(self._url) @property - def url(self) -> str: + @override + def url(self) -> str: # pyright: ignore [reportIncompatibleVariableOverride] """Returns the underlying URL of the asset.""" return self._url diff --git a/discord/client.py b/discord/client.py index 4db53e33e3..9fb4e3c183 100644 --- a/discord/client.py +++ b/discord/client.py @@ -35,7 +35,7 @@ import aiohttp -from . import utils +from . import models, utils from .activity import ActivityTypes, BaseActivity, create_activity from .appinfo import AppInfo, PartialAppInfo from .application_role_connection import ApplicationRoleConnectionMetadata @@ -1840,7 +1840,7 @@ async def fetch_user(self, user_id: int, /) -> User: :exc:`HTTPException` Fetching the user failed. """ - data = await self.http.get_user(user_id) + data: models.User = await self.http.get_user(user_id) return User(state=self._connection, data=data) async def fetch_channel( diff --git a/discord/gateway.py b/discord/gateway.py index 4af59f3864..4f95655def 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -35,8 +35,12 @@ import traceback import zlib from collections import deque, namedtuple +from typing import Any import aiohttp +from pydantic import BaseModel + +from discord import models from . import utils from .activity import BaseActivity @@ -548,11 +552,20 @@ async def received_message(self, msg, /): ) try: - func = self._discord_parsers[event] + func: Any = self._discord_parsers[event] except KeyError: _log.debug("Unknown event %s.", event) else: - func(data) + if hasattr(func, "_supports_model") and issubclass( + func._supports_model, models.gateway.GatewayEvent + ): + func( + func._supports_model( + **msg + ).d # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue] + ) + else: + func(data) # remove the dispatched listeners removed = [] diff --git a/discord/guild.py b/discord/guild.py index b1e937d07b..68759ff41c 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -40,7 +40,7 @@ overload, ) -from . import abc, utils +from . import abc, models, utils from .asset import Asset from .automod import AutoModAction, AutoModRule, AutoModTriggerMetadata from .channel import * @@ -289,11 +289,11 @@ class Guild(Hashable): ) _PREMIUM_GUILD_LIMITS: ClassVar[dict[int | None, _GuildLimit]] = { - None: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400), - 0: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400), - 1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=26214400), - 2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52428800), - 3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600), + None: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=10_485_760), + 0: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=10_485_760), + 1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=10_485_760), + 2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52_428_800), + 3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104_857_600), } def __init__(self, *, data: GuildPayload, state: ConnectionState): @@ -2114,9 +2114,9 @@ async def fetch_ban(self, user: Snowflake) -> BanEntry: HTTPException An error occurred while fetching the information. """ - data: BanPayload = await self._state.http.get_ban(user.id, self.id) + data: models.Ban = await self._state.http.get_ban(user.id, self.id) return BanEntry( - user=User(state=self._state, data=data["user"]), reason=data["reason"] + user=User(state=self._state, data=data.user), reason=data.reason ) async def fetch_channel(self, channel_id: int, /) -> GuildChannel | Thread: diff --git a/discord/http.py b/discord/http.py index 464710daac..893e9f798e 100644 --- a/discord/http.py +++ b/discord/http.py @@ -33,8 +33,10 @@ from urllib.parse import quote as _uriquote import aiohttp +from pydantic import BaseModel, TypeAdapter +from typing_extensions import overload, reveal_type -from . import __version__, utils +from . import __version__, models, utils from .errors import ( DiscordServerError, Forbidden, @@ -87,9 +89,12 @@ T = TypeVar("T") BE = TypeVar("BE", bound=BaseException) MU = TypeVar("MU", bound="MaybeUnlock") - Response = Coroutine[Any, Any, T] + + Response = Coroutine[Any, Any, T] # pyright: ignore [reportExplicitAny] API_VERSION: int = 10 +TP = TypeVar("TP") +BM = TypeVar("BM", bound=BaseModel) async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str: @@ -157,7 +162,7 @@ def __exit__( # For some reason, the Discord voice websocket expects this header to be # completely lowercase while aiohttp respects spec and does it as case-insensitive -aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore +aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore # pyright: ignore [reportAttributeAccessIssue] class HTTPClient: @@ -215,14 +220,48 @@ async def ws_connect(self, url: str, *, compress: int = 0) -> Any: return await self.__session.ws_connect(url, **kwargs) + @overload + async def request( + self, + route: Route, + *, + files: Sequence[File] | None = None, + form: Iterable[dict[str, Any]] | None = None, + model: None, + **kwargs: Any, + ) -> Any: ... + + @overload + async def request( + self, + route: Route, + *, + files: None = ..., + form: None = ..., + model: type[BM], + **kwargs: Any, + ) -> BM: ... + + @overload + async def request( + self, + route: Route, + *, + files: None = ..., + form: None = ..., + model: TypeAdapter[TP], + **kwargs: Any, + ) -> TP: ... + async def request( self, route: Route, *, files: Sequence[File] | None = None, form: Iterable[dict[str, Any]] | None = None, + model: type[BM] | TypeAdapter[TP] | None = None, **kwargs: Any, - ) -> Any: + ) -> Any | BM | TP: bucket = route.bucket method = route.method url = route.url @@ -318,6 +357,13 @@ async def request( # the request was successful so just return the text/json if 300 > response.status >= 200: _log.debug("%s %s has received %s", method, url, data) + if model: + if isinstance(model, TypeAdapter): + return model.validate_python( + data + ) # pyright: ignore [reportUnknownVariableType] + return model.model_validate(data) + return data # we are being rate limited @@ -409,7 +455,7 @@ async def close(self) -> None: # login management - async def static_login(self, token: str) -> user.User: + async def static_login(self, token: str) -> models.User: # Necessary to get aiohttp to stop complaining about session creation self.__session = aiohttp.ClientSession( connector=self.connector, ws_response_class=DiscordClientWebSocketResponse @@ -418,7 +464,7 @@ async def static_login(self, token: str) -> user.User: self.token = token try: - data = await self.request(Route("GET", "/users/@me")) + data = await self.request(Route("GET", "/users/@me"), model=models.User) except HTTPException as exc: self.token = old_token if exc.status == 401: @@ -1598,11 +1644,11 @@ def create_from_template( def get_bans( self, - guild_id: Snowflake, + guild_id: models.Snowflake, limit: int | None = None, - before: Snowflake | None = None, - after: Snowflake | None = None, - ) -> Response[list[guild.Ban]]: + before: models.Snowflake | None = None, + after: models.Snowflake | None = None, + ) -> Response[list[models.Ban]]: params: dict[str, int | Snowflake] = {} if limit is not None: @@ -1613,17 +1659,22 @@ def get_bans( params["after"] = after return self.request( - Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id), params=params + Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id), + params=params, + model=TypeAdapter(list[models.Ban]), ) - def get_ban(self, user_id: Snowflake, guild_id: Snowflake) -> Response[guild.Ban]: + def get_ban( + self, user_id: models.Snowflake, guild_id: models.Snowflake + ) -> Response[models.Ban]: return self.request( Route( "GET", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id, - ) + ), + model=models.Ban, ) def get_vanity_code(self, guild_id: Snowflake) -> Response[invite.VanityInvite]: @@ -3173,5 +3224,7 @@ async def get_bot_gateway( value = "{0}?encoding={1}&v={2}" return data["shards"], value.format(data["url"], encoding, API_VERSION) - def get_user(self, user_id: Snowflake) -> Response[user.User]: - return self.request(Route("GET", "/users/{user_id}", user_id=user_id)) + def get_user(self, user_id: Snowflake) -> Response[models.User]: + return self.request( + Route("GET", "/users/{user_id}", user_id=user_id), model=models.User + ) diff --git a/discord/iterators.py b/discord/iterators.py index eca3c72091..0f3cbea4cc 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -38,6 +38,9 @@ Union, ) +from typing_extensions import Final, override, reveal_type + +from . import models from .audit_logs import AuditLogEntry from .errors import NoMoreItems from .monetization import Entitlement @@ -57,9 +60,11 @@ if TYPE_CHECKING: from .abc import Snowflake from .guild import BanEntry, Guild + from .http import HTTPClient from .member import Member from .message import Message from .scheduled_events import ScheduledEvent + from .state import ConnectionState from .threads import Thread from .types.audit_log import AuditLog as AuditLogPayload from .types.guild import Guild as GuildPayload @@ -737,16 +742,26 @@ def create_member(self, data): class BanIterator(_AsyncIterator["BanEntry"]): - def __init__(self, guild, limit=None, before=None, after=None): - self.guild = guild - self.limit = limit - self.after = after - self.before = before - - self.state = self.guild._state - self.get_bans = self.state.http.get_bans - self.bans = asyncio.Queue() - + def __init__( + self, + guild: Guild, + limit: int | None = None, + before: models.Snowflake | None = None, + after: models.Snowflake | None = None, + ): + self.guild: Guild = guild + self.limit: int | None = limit + self.after: models.Snowflake | None = after + self.before: models.Snowflake | None = before + self.retrieve: int = 0 + + self.state: ConnectionState = ( + self.guild._state + ) # pyright: ignore [reportPrivateUsage] + self.get_bans: Final = self.state.http.get_bans + self.bans: asyncio.Queue[BanEntry] = asyncio.Queue() + + @override async def next(self) -> BanEntry: if self.bans.empty(): await self.fill_bans() @@ -757,20 +772,20 @@ async def next(self) -> BanEntry: raise NoMoreItems() def _get_retrieve(self): - l = self.limit - if l is None or l > 1000: - r = 1000 + if self.limit is None or self.limit > 1000: + self.retrieve = 1000 else: - r = l - self.retrieve = r - return r > 0 + self.retrieve = self.limit + return self.retrieve > 0 async def fill_bans(self): if not self._get_retrieve(): return - before = self.before.id if self.before else None - after = self.after.id if self.after else None - data = await self.get_bans(self.guild.id, self.retrieve, before, after) + before: models.Snowflake | None = self.before if self.before else None + after: models.Snowflake | None = self.after if self.after else None + data = await self.get_bans( + models.Snowflake(self.guild.id), self.retrieve, before, after + ) if not data: # no data, terminate return @@ -780,18 +795,16 @@ async def fill_bans(self): if len(data) < 1000: self.limit = 0 # terminate loop - self.after = Object(id=int(data[-1]["user"]["id"])) + self.after = data[-1].user.id for element in reversed(data): await self.bans.put(self.create_ban(element)) - def create_ban(self, data): + def create_ban(self, data: models.Ban) -> BanEntry: from .guild import BanEntry from .user import User - return BanEntry( - reason=data["reason"], user=User(state=self.state, data=data["user"]) - ) + return BanEntry(reason=data.reason, user=User(state=self.state, data=data.user)) class ArchivedThreadIterator(_AsyncIterator["Thread"]): diff --git a/discord/models/README.md b/discord/models/README.md new file mode 100644 index 0000000000..52209c5ecf --- /dev/null +++ b/discord/models/README.md @@ -0,0 +1,72 @@ +# Py-cord Models + +This directory contains the pydantic models and types used by py-cord. + +## Structure + +The models are structured in a way that they mirror the structure of the Discord API. +They are subdivided into the following submodules: + +> [!IMPORTANT] Each of the submodules is defined below in order. Submodules may only +> reference in their code classes from the same or lower level submodules. For example, +> `api` may reference classes from `api`, `base` and `types`, but `base` may not +> reference classes from `api`. This is to prevent circular imports and to keep the +> codebase clean and maintainable. + +### `types` + +Contains python types and dataclasses that are used in the following submodules. These +are used to represent the data in a more pythonic way, and are used to define the +pydantic models. + +### `base` + +Contains the base models defined in the Discord docs. These are the models you will +often find with a heading like "... Template", and hyperlinks linking to it referring to +it as an "object". + +For example, the +[User Template](https://discord.com/developers/docs/resources/user#user-object) is +defined in `base/user.py`. + +### `api` + +Contains the models that are used to represent the data received and set trough discord +API requests. They represent payloads that are sent and received from the Discord API. + +When representing a route, it is preferred to create a single python file for each base +route. If the file may become too large, it is preferred to split it into multiple +files, one for each sub-route. In that case, a submodule with the name of the base route +should be created to hold the sub-routes. + +For example, the +[Modify Guild Template](https://discord.com/developers/docs/resources/guild-template#modify-guild-template) +is defined in `api/guild_template.py`. + +### `gateway` + +Contains the models that are used to represent the data received and sent trough the +Discord Gateway. They represent payloads that are sent and received from the Discord +Gateway. + +For example, the [Ready Event](https://discord.com/developers/docs/topics/gateway#hello) +is defined in `gateway/ready.py`. + +## Naming + +The naming of the models is based on the Discord API documentation. The models are named +after the object they represent in the Discord API documentation. It is generally +preferred to create a new model for each object in the Discord API documentation, even +if the file may only contain a single class, so that the structure keeps a 1:1 mapping +with the Discord API documentation. + +## Exporting strategies + +The models are exported in the following way: + +- The models are exported in the `__init__.py` of their respective submodules. +- Models from the `base` submodule are re-exported in the `__init__.py` of the `modules` + module. +- The other submodules are re-exported in the `__init__.py` of the `models` module as a + single import. +- The `models` module is re-exported in the `discord` module. diff --git a/discord/models/__init__.py b/discord/models/__init__.py new file mode 100644 index 0000000000..0fd411c808 --- /dev/null +++ b/discord/models/__init__.py @@ -0,0 +1,29 @@ +from discord.models.base.role import Role + +from . import gateway, types +from .base import ( + AvatarDecorationData, + Ban, + Emoji, + Guild, + Sticker, + UnavailableGuild, + User, +) +from .types import Snowflake +from .types.utils import MISSING + +__all__ = ( + "Emoji", + "Guild", + "UnavailableGuild", + "Role", + "Sticker", + "types", + "User", + "MISSING", + "AvatarDecorationData", + "gateway", + "Ban", + "Snowflake", +) diff --git a/discord/models/api/__init__.py b/discord/models/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/discord/models/api/ban.py b/discord/models/api/ban.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/discord/models/base/__init__.py b/discord/models/base/__init__.py new file mode 100644 index 0000000000..0ddb2883ce --- /dev/null +++ b/discord/models/base/__init__.py @@ -0,0 +1,17 @@ +from .ban import Ban +from .emoji import Emoji +from .guild import Guild, UnavailableGuild +from .role import Role +from .sticker import Sticker +from .user import AvatarDecorationData, User + +__all__ = ( + "Emoji", + "Guild", + "UnavailableGuild", + "Role", + "Sticker", + "User", + "Ban", + "AvatarDecorationData", +) diff --git a/discord/models/base/ban.py b/discord/models/base/ban.py new file mode 100644 index 0000000000..c831dbbf4e --- /dev/null +++ b/discord/models/base/ban.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from .user import User + + +class Ban(BaseModel): + reason: str + user: User diff --git a/discord/models/base/emoji.py b/discord/models/base/emoji.py new file mode 100644 index 0000000000..e3a0685546 --- /dev/null +++ b/discord/models/base/emoji.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from ..types import MISSING, EmojiID, MissingSentinel, RoleID +from .user import User + + +class Emoji(BaseModel): + id: EmojiID | None + name: str + roles: list[RoleID] | MissingSentinel = MISSING + user: User | MissingSentinel = MISSING + require_colons: bool | MissingSentinel = MISSING + managed: bool | MissingSentinel = MISSING + animated: bool | MissingSentinel = MISSING + available: bool | MissingSentinel = MISSING diff --git a/discord/models/base/guild.py b/discord/models/base/guild.py new file mode 100644 index 0000000000..a9498654f7 --- /dev/null +++ b/discord/models/base/guild.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from enum import IntEnum +from typing import Any + +from pydantic import BaseModel, Field, GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + +from discord.models.types import ( + MISSING, + ChannelID, + EmojiID, + GuildID, + Locale, + MissingSentinel, + Permissions, + SystemChannelFlags, + UserID, +) + +from .emoji import Emoji +from .role import Role +from .sticker import Sticker + + +class VerificationLevel(IntEnum): + NONE = 0 + LOW = 1 + MEDIUM = 2 + HIGH = 3 + VERY_HIGH = 4 + + +class DefaultNotificationLevel(IntEnum): + ALL_MESSAGES = 0 + ONLY_MENTIONS = 1 + + +class ExplicitContentFilterLevel(IntEnum): + DISABLED = 0 + MEMBERS_WITHOUT_ROLES = 1 + ALL_MEMBERS = 2 + + +class GuildFeatures(set[str]): + def __getattr__(self, item: str) -> bool: + return item.lower() in self or item in self + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, # pyright: ignore [reportExplicitAny] + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + def validate_and_create( + v: Any, + ) -> GuildFeatures: # pyright: ignore [reportExplicitAny] + if isinstance(v, cls): + return v + if isinstance(v, (list, set)): + return cls( + str(item) for item in v + ) # pyright: ignore [reportUnknownArgumentType, reportUnknownVariableType] + raise ValueError("Invalid input type for GuildFeatures") + + return core_schema.json_or_python_schema( + # For Python inputs: + python_schema=core_schema.union_schema( + [ + # Accept existing instances + core_schema.is_instance_schema(cls), + # Accept lists or sets + core_schema.no_info_plain_validator_function(validate_and_create), + ] + ), + # For JSON inputs, expecting an array of strings: + json_schema=core_schema.list_schema( + core_schema.str_schema(), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: list(x), + return_schema=core_schema.list_schema(core_schema.str_schema()), + ), + ), + # When serializing to JSON, convert to list + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: list(x), + return_schema=core_schema.list_schema(core_schema.str_schema()), + when_used="json", + ), + ) + + +class MFALevel(IntEnum): + NONE = 0 + ELEVATED = 1 + + +class PremiumTier(IntEnum): + NONE = 0 + TIER_1 = 1 + TIER_2 = 2 + TIER_3 = 3 + + +class NSFWLevel(IntEnum): + DEFAULT = 0 + EXPLICIT = 1 + SAFE = 2 + AGE_RESTRICTED = 3 + + +class BaseGuild(BaseModel): + id: GuildID + + +class UnavailableGuild(BaseGuild): + unavailable: bool = True + + +class Guild(BaseGuild): + name: str = Field(min_length=2, max_length=100) + icon: str | MissingSentinel | None = Field(alias="icon_hash", default=MISSING) + splash: str | None + discovery_splash: str | None + owner: bool | MissingSentinel = Field(default=MISSING) + owner_id: UserID + permissions: Permissions | MissingSentinel = Field(default=MISSING) + afk_channel_id: ChannelID | None + afk_timeout: int + widget_enabled: bool | MissingSentinel = Field(default=MISSING) + widget_channel_id: ChannelID | None | MissingSentinel = Field(default=MISSING) + verification_level: VerificationLevel + default_message_notifications: DefaultNotificationLevel + explicit_content_filter: ExplicitContentFilterLevel + roles: list[Role] + emojis: list[Emoji] + features: GuildFeatures + mfa_level: MFALevel + application_id: UserID | None + system_channel_id: ChannelID | None + system_channel_flags: SystemChannelFlags + rules_channel_id: ChannelID | None + max_presences: int | None | MissingSentinel = Field(default=MISSING) + max_members: int | MissingSentinel = Field(default=MISSING) + vanity_url_code: str | None + description: str | None + banner: str | None + premium_tier: PremiumTier + premium_subscription_count: int | None + preferred_locale: Locale + public_updates_channel_id: ChannelID | None + max_video_channel_users: int | MissingSentinel = Field(default=MISSING) + max_stage_video_channel_users: int | MissingSentinel = Field(default=MISSING) + approximate_member_count: int | MissingSentinel = Field(default=MISSING) + approximate_presence_count: int | MissingSentinel = Field(default=MISSING) + welcome_screen: WelcomeScreen | None + nsfw_level: NSFWLevel + stickers: list[Sticker] | MissingSentinel = Field(default=MISSING) + premium_progress_bar_enabled: bool + safety_alerts_channel_id: ChannelID | None + + +class WelcomeScreen(BaseModel): + description: str | None + welcome_channels: list[WelcomeScreenChannel] + + +class WelcomeScreenChannel(BaseModel): + channel_id: ChannelID + description: str + emoji_id: EmojiID | None + emoji_name: str | None diff --git a/discord/models/base/role.py b/discord/models/base/role.py new file mode 100644 index 0000000000..2b665e677f --- /dev/null +++ b/discord/models/base/role.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import ( + BaseModel, + BeforeValidator, + Field, + SerializerFunctionWrapHandler, + model_serializer, +) +from typing_extensions import Annotated + +from ..types import ( + MISSING, + Color, + MissingSentinel, + Permissions, + RoleFlags, + RoleID, + Snowflake, + UserID, +) + + +def presence_bool(v: Any) -> bool: # pyright: ignore [reportExplicitAny] + return True if v is not None else False + + +PresentBool = Annotated[bool, BeforeValidator(presence_bool)] + + +class RoleTags(BaseModel): + bot_id: UserID | MissingSentinel = Field(default=MISSING) + integration_id: UserID | MissingSentinel = Field(default=MISSING) + premium_subscriber: PresentBool = False + subscription_listing_id: Snowflake | MissingSentinel = Field(default=MISSING) + available_for_purchase: PresentBool = False + guild_connections: PresentBool = False + + @model_serializer(mode="wrap") + def serialize_model( + self, nxt: SerializerFunctionWrapHandler + ) -> dict[Any, Any]: # pyright: ignore [reportExplicitAny] + return {k: v for k, v in nxt(self).items() if v is not False} + + +class Role(BaseModel): + id: RoleID + name: str + color: Color = Field(alias="colour") + hoist: bool + icon: str | None | MissingSentinel = Field(default=MISSING) + unicode_emoji: str | None | MissingSentinel = Field(default=MISSING) + position: int + permissions: Permissions + managed: bool + mentionable: bool + tags: RoleTags | MissingSentinel = Field(default=MISSING) + flags: RoleFlags diff --git a/discord/models/base/sticker.py b/discord/models/base/sticker.py new file mode 100644 index 0000000000..deec3d3fec --- /dev/null +++ b/discord/models/base/sticker.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from enum import IntEnum + +from pydantic import BaseModel + +from ..types import MISSING, GuildID, MissingSentinel, Snowflake, StickerID +from .user import User + + +class StickerType(IntEnum): + STANDARD = 1 + GUILD = 2 + + +class StickerFormatType(IntEnum): + PNG = 1 + APNG = 1 + LOTTIE = 3 + GIF = 4 + + +class Sticker(BaseModel): + id: StickerID | None + pack_id: Snowflake | MissingSentinel = MISSING + name: str + description: str | None + tags: str + type: StickerType + format_type: StickerFormatType + available: bool | MissingSentinel = MISSING + guild_id: GuildID | MissingSentinel = MISSING + user: User | MissingSentinel = MISSING + sort_value: int | MissingSentinel = MISSING + + +class StickerPack(BaseModel): + id: Snowflake + stickers: list[Sticker] + name: str + sku_id: Snowflake + cover_sticker_id: StickerID | MissingSentinel = MISSING + description: str + banner_asset_id: Snowflake | MissingSentinel = MISSING diff --git a/discord/models/base/user.py b/discord/models/base/user.py new file mode 100644 index 0000000000..b27416e571 --- /dev/null +++ b/discord/models/base/user.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from enum import IntEnum + +from pydantic import BaseModel + +from ..types import MISSING, Color, Locale, MissingSentinel, Snowflake, UserID +from ..types.flags import UserFlags + + +class PremiumType(IntEnum): + NONE = 0 + NITRO_CLASSIC = 1 + NITRO = 2 + NITRO_BASIC = 3 + + +class AvatarDecorationData(BaseModel): + asset: str + sku_id: Snowflake + + +class User(BaseModel): + id: UserID + username: str + discriminator: str + global_name: str | None + avatar: str | None + bot: bool | MissingSentinel = MISSING + system: bool | MissingSentinel = MISSING + mfa_enabled: bool | MissingSentinel = MISSING + banner: str | None | MissingSentinel = MISSING + accent_color: Color | None | MissingSentinel = MISSING + locale: Locale | MissingSentinel = MISSING + verified: bool | MissingSentinel = MISSING + email: str | MissingSentinel | None = MISSING + flags: UserFlags | MissingSentinel = MISSING + premium_type: PremiumType | MissingSentinel = MISSING + public_flags: UserFlags | MissingSentinel = MISSING + avatar_decoration_data: AvatarDecorationData | None | MissingSentinel = MISSING diff --git a/discord/models/gateway/__init__.py b/discord/models/gateway/__init__.py new file mode 100644 index 0000000000..50705ddd58 --- /dev/null +++ b/discord/models/gateway/__init__.py @@ -0,0 +1,8 @@ +from .base import GatewayEvent +from .ready import Ready, ReadyData + +__all__ = [ + "GatewayEvent", + "Ready", + "ReadyData", +] diff --git a/discord/models/gateway/base.py b/discord/models/gateway/base.py new file mode 100644 index 0000000000..33fa448420 --- /dev/null +++ b/discord/models/gateway/base.py @@ -0,0 +1,12 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +T = TypeVar("T", bound=BaseModel) + + +class GatewayEvent(BaseModel, Generic[T]): + op: int + d: T + s: int + t: str diff --git a/discord/models/gateway/ready.py b/discord/models/gateway/ready.py new file mode 100644 index 0000000000..9e5c025590 --- /dev/null +++ b/discord/models/gateway/ready.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from ..base import Guild, UnavailableGuild, User +from ..types.utils import MISSING, MissingSentinel +from .base import GatewayEvent + + +class ReadyData(BaseModel): + v: int + user: User + private_channels: list[dict] # TODO: Channel + guilds: list[Guild | UnavailableGuild] + session_id: str + shard: list[int] | MissingSentinel = MISSING + application: dict # TODO: Application + + # @field_validator("guilds", mode='plain') + # def guilds_validator(cls, guilds: list[dict[str, Any]]) -> list[Guild | UnavailableGuild]: # pyright: ignore [reportExplicitAny] + # r: list[Guild | UnavailableGuild] = [] + # for guild in guilds: + # if guild.get("unavailable", False): + # r.append(UnavailableGuild(**guild)) + # else: + # r.append(Guild(**guild)) + # return r + + +class Ready(GatewayEvent[ReadyData]): + pass diff --git a/discord/models/types/__init__.py b/discord/models/types/__init__.py new file mode 100644 index 0000000000..d9e35408db --- /dev/null +++ b/discord/models/types/__init__.py @@ -0,0 +1,30 @@ +from .channel import ChannelID +from .color import Color, Colour +from .emoji import EmojiID +from .flags import Permissions, RoleFlags, SystemChannelFlags, UserFlags +from .guild import GuildID +from .locale import Locale +from .role import RoleID +from .snowflake import Snowflake +from .sticker import StickerID +from .user import UserID +from .utils import MISSING, MissingSentinel + +__all__ = [ + "Snowflake", + "ChannelID", + "GuildID", + "UserID", + "RoleID", + "SystemChannelFlags", + "Permissions", + "MISSING", + "MissingSentinel", + "Locale", + "Color", + "Colour", + "RoleFlags", + "EmojiID", + "StickerID", + "UserFlags", +] diff --git a/discord/models/types/channel.py b/discord/models/types/channel.py new file mode 100644 index 0000000000..f3366bbda4 --- /dev/null +++ b/discord/models/types/channel.py @@ -0,0 +1,5 @@ +from .snowflake import Snowflake + + +class ChannelID(Snowflake): + """Represents a Discord Channel ID.""" diff --git a/discord/models/types/color.py b/discord/models/types/color.py new file mode 100644 index 0000000000..8566216f6f --- /dev/null +++ b/discord/models/types/color.py @@ -0,0 +1,51 @@ +from typing import Any + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema +from typing_extensions import Self, final, override + + +@final +class Color(int): + """Represents a color.""" + + def __new__(cls, value: int) -> Self: + # allow for hex str #...... + if isinstance(value, str) and value.startswith("#") and len(value) == 7: + value = int(value[1:], 16) + return super().__new__(cls, value) + + def _get_byte(self, byte: int) -> int: + return (self >> (8 * byte)) & 0xFF + + @override + def __repr__(self) -> str: + return f"" + + @override + def __str__(self) -> str: + return f"#{self:0>6x}" + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, # pyright: ignore [reportExplicitAny] + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + """Define how Pydantic should handle validation and serialization of Snowflakes""" + return core_schema.json_or_python_schema( + json_schema=core_schema.int_schema(), + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(cls), + core_schema.int_schema(), + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + int, return_schema=core_schema.int_schema(), when_used="json" + ), + ) + + +Colour = Color +__all__ = ["Color", "Colour"] diff --git a/discord/models/types/emoji.py b/discord/models/types/emoji.py new file mode 100644 index 0000000000..c43a1d523b --- /dev/null +++ b/discord/models/types/emoji.py @@ -0,0 +1,5 @@ +from .snowflake import Snowflake + + +class EmojiID(Snowflake): + """Represents a Discord Guild ID.""" diff --git a/discord/models/types/flags/__init__.py b/discord/models/types/flags/__init__.py new file mode 100644 index 0000000000..d2f20ad6a8 --- /dev/null +++ b/discord/models/types/flags/__init__.py @@ -0,0 +1,15 @@ +from .base import BaseFlags, fill_with_flags, flag_value +from .permissions import Permissions +from .role import RoleFlags +from .system_channel import SystemChannelFlags +from .user import UserFlags + +__all__ = [ + "SystemChannelFlags", + "Permissions", + "BaseFlags", + "flag_value", + "fill_with_flags", + "RoleFlags", + "UserFlags", +] diff --git a/discord/models/types/flags/base.py b/discord/models/types/flags/base.py new file mode 100644 index 0000000000..d818701512 --- /dev/null +++ b/discord/models/types/flags/base.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Any, Callable, ClassVar, TypeVar, overload + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema +from typing_extensions import Self, override + +FV = TypeVar("FV", bound="flag_value") +BF = TypeVar("BF", bound="BaseFlags") + + +class flag_value: + def __init__( + self, func: Callable[[Any], int] # pyright: ignore [reportExplicitAny] + ): + self.flag: int = func(None) + self.__doc__ = func.__doc__ + + @overload + def __get__(self: FV, instance: None, owner: type[BF]) -> FV: ... + + @overload + def __get__(self, instance: BF, owner: type[BF]) -> bool: ... + + def __get__( + self, instance: BF | None, owner: type[BF] + ) -> Any: # pyright: ignore [reportExplicitAny] + if instance is None: + return self + return instance._has_flag(self.flag) # pyright: ignore [reportPrivateUsage] + + def __set__(self, instance: BaseFlags, value: bool) -> None: + instance._set_flag(self.flag, value) # pyright: ignore [reportPrivateUsage] + + @override + def __repr__(self): + return f"" + + +class alias_flag_value(flag_value): + pass + + +def fill_with_flags(*, inverted: bool = False): + def decorator(cls: type[BF]): + cls.VALID_FLAGS = { + name: value.flag + for name, value in cls.__dict__.items() + if isinstance(value, flag_value) + } + + if inverted: + max_bits = max(cls.VALID_FLAGS.values()).bit_length() + cls.DEFAULT_VALUE = -1 + (2**max_bits) + else: + cls.DEFAULT_VALUE = 0 + + return cls + + return decorator + + +# n.b. flags must inherit from this and use the decorator above +class BaseFlags: + VALID_FLAGS: ClassVar[dict[str, int]] + DEFAULT_VALUE: ClassVar[int] + + value: int + + __slots__: tuple[str] = ("value",) + + def __init__(self, **kwargs: bool): + self.value = self.DEFAULT_VALUE + for key, value in kwargs.items(): + if key not in self.VALID_FLAGS: + raise TypeError(f"{key!r} is not a valid flag name.") + setattr(self, key, value) + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, # pyright: ignore [reportExplicitAny] + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + return core_schema.json_or_python_schema( + # Both JSON and Python accept plain integers as input + json_schema=core_schema.chain_schema( + [ + core_schema.int_schema(), + core_schema.no_info_plain_validator_function(cls._from_value), + ] + ), + python_schema=core_schema.union_schema( + [ + # Accept existing instances of our class + core_schema.is_instance_schema(cls), + # Also accept plain integers and convert them + core_schema.chain_schema( + [ + core_schema.int_schema(), + core_schema.no_info_plain_validator_function( + cls._from_value + ), + ] + ), + ] + ), + # Convert to plain int when serializing to JSON + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.value, + return_schema=core_schema.int_schema(), + when_used="json", + ), + ) + + @classmethod + def _from_value(cls: type[Self], value: int) -> Self: + self = cls.__new__(cls) + self.value = value + return self + + @override + def __eq__(self, other: Any) -> bool: # pyright: ignore [reportExplicitAny] + if not isinstance(other, self.__class__): + raise TypeError( + f"'==' not supported between instances of {type(self)} and {type(other)}" + ) + return isinstance(other, self.__class__) and self.value == other.value + + @override + def __hash__(self) -> int: + return hash(self.value) + + @override + def __repr__(self) -> str: + return f"<{self.__class__.__name__} value={self.value}>" + + def __iter__(self) -> Iterator[tuple[str, bool]]: + for name, value in self.__class__.__dict__.items(): + if isinstance(value, alias_flag_value): + continue + + if isinstance(value, flag_value): + yield name, self._has_flag(value.flag) + + def __and__(self, other: Self | flag_value) -> Self: + if isinstance(other, self.__class__): + return self.__class__._from_value(self.value & other.value) + elif isinstance(other, flag_value): + return self.__class__._from_value(self.value & other.flag) + else: + raise TypeError( + f"'&' not supported between instances of {type(self)} and {type(other)}" + ) + + def __or__(self, other: Self | flag_value) -> Self: + if isinstance(other, self.__class__): + return self.__class__._from_value(self.value | other.value) + elif isinstance(other, flag_value): + return self.__class__._from_value(self.value | other.flag) + else: + raise TypeError( + f"'|' not supported between instances of {type(self)} and {type(other)}" + ) + + def __add__(self, other: Self | flag_value) -> Self: + try: + return self | other + except TypeError: + raise TypeError( + f"'+' not supported between instances of {type(self)} and {type(other)}" + ) + + def __sub__(self, other: Self | flag_value) -> Self: + if isinstance(other, self.__class__): + return self.__class__._from_value(self.value & ~other.value) + elif isinstance(other, flag_value): + return self.__class__._from_value(self.value & ~other.flag) + else: + raise TypeError( + f"'-' not supported between instances of {type(self)} and {type(other)}" + ) + + def __invert__(self): + return self.__class__._from_value(~self.value) + + __rand__: Callable[[Self, Self | flag_value], Self] = __and__ + __ror__: Callable[[Self, Self | flag_value], Self] = __or__ + __radd__: Callable[[Self, Self | flag_value], Self] = __add__ + __rsub__: Callable[[Self, Self | flag_value], Self] = __sub__ + + def _has_flag(self, o: int) -> bool: + return (self.value & o) == o + + def _set_flag(self, o: int, toggle: bool) -> None: + if toggle: + self.value |= o + else: + self.value &= ~o diff --git a/discord/models/types/flags/permissions.py b/discord/models/types/flags/permissions.py new file mode 100644 index 0000000000..0cee1b45c9 --- /dev/null +++ b/discord/models/types/flags/permissions.py @@ -0,0 +1,657 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from typing_extensions import Self, final + +from .base import BaseFlags, alias_flag_value, flag_value + + +class permission_alias(alias_flag_value): + alias: str # pyright: ignore [reportUninitializedInstanceVariable] + + +def make_permission_alias( + alias: str, +) -> Callable[ + [Callable[[Any], int]], permission_alias +]: # pyright: ignore [reportExplicitAny] + def decorator( + func: Callable[[Any], int] + ) -> permission_alias: # pyright: ignore [reportExplicitAny] + ret = permission_alias(func) + ret.alias = alias + return ret + + return decorator + + +@final +class Permissions(BaseFlags): + """Wraps up the Discord permission value. + + The properties provided are two way. You can set and retrieve individual + bits using the properties as if they were regular bools. This allows + you to edit permissions. + + .. versionchanged:: 1.3 + You can now use keyword arguments to initialize :class:`Permissions` + similar to :meth:`update`. + + .. container:: operations + + .. describe:: x == y + + Checks if two permissions are equal. + .. describe:: x != y + + Checks if two permissions are not equal. + .. describe:: x <= y + + Checks if a permission is a subset of another permission. + .. describe:: x >= y + + Checks if a permission is a superset of another permission. + .. describe:: x < y + + Checks if a permission is a strict subset of another permission. + .. describe:: x > y + + .. describe:: x + y + + Adds two permissions together. Equivalent to ``x | y``. + .. describe:: x - y + + Subtracts two permissions from each other. + .. describe:: x | y + + Returns the union of two permissions. Equivalent to ``x + y``. + .. describe:: x & y + + Returns the intersection of two permissions. + .. describe:: ~x + + Returns the inverse of a permission. + + Checks if a permission is a strict superset of another permission. + .. describe:: hash(x) + + Return the permission's hash. + .. describe:: iter(x) + + Returns an iterator of ``(perm, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + Note that aliases are not shown. + + Attributes + ---------- + value: :class:`int` + The raw value. This value is a bit array field of a 53-bit integer + representing the currently available permissions. You should query + permissions via the properties rather than using this raw value. + """ + + __slots__: tuple[str] = tuple() + + def __init__( + self, permissions: int = 0, **kwargs: bool + ) -> None: # pyright: ignore [reportMissingSuperCall] + if not isinstance( + permissions, int + ): # pyright: ignore [reportUnnecessaryIsInstance] + raise TypeError( # pyright: ignore [reportUnreachable] + f"Expected int parameter, received {permissions.__class__.__name__} instead." + ) + + self.value = permissions + for key, value in kwargs.items(): + if key not in self.VALID_FLAGS: + raise TypeError(f"{key!r} is not a valid permission name.") + setattr(self, key, value) + + def is_subset(self, other: Any) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if self has the same or fewer permissions as other.""" + if isinstance(other, Permissions): + return (self.value & other.value) == self.value + raise TypeError( + f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) + + def is_superset(self, other: Any) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if self has the same or more permissions as other.""" + if isinstance(other, Permissions): + return (self.value | other.value) == self.value + else: + raise TypeError( + f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) + + def is_strict_subset( + self, other: Any + ) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if the permissions on other are a strict subset of those on self.""" + return self.is_subset(other) and self != other + + def is_strict_superset( + self, other: Any + ) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if the permissions on other are a strict superset of those on self.""" + return self.is_superset(other) and self != other + + __le__: Callable[[Self, Any], bool] = ( + is_subset # pyright: ignore [reportExplicitAny] + ) + __ge__: Callable[[Self, Any], bool] = ( + is_superset # pyright: ignore [reportExplicitAny] + ) + __lt__: Callable[[Self, Any], bool] = ( + is_strict_subset # pyright: ignore [reportExplicitAny] + ) + __gt__: Callable[[Self, Any], bool] = ( + is_strict_superset # pyright: ignore [reportExplicitAny] + ) + + @classmethod + def none(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + permissions set to ``False``. + """ + return cls(0) + + @classmethod + def all(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + permissions set to ``True``. + """ + return cls(0b1111111111111111111111111111111111111111111111111) + + @classmethod + def all_channel(cls: type[Self]) -> Self: + """A :class:`Permissions` with all channel-specific permissions set to + ``True`` and the guild-specific ones set to ``False``. The guild-specific + permissions are currently: + + - :attr:`manage_emojis` + - :attr:`view_audit_log` + - :attr:`view_guild_insights` + - :attr:`view_creator_monetization_analytics` + - :attr:`manage_guild` + - :attr:`change_nickname` + - :attr:`manage_nicknames` + - :attr:`kick_members` + - :attr:`ban_members` + - :attr:`administrator` + + .. versionchanged:: 1.7 + Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions. + + .. versionchanged:: 2.0 + Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`, + :attr:`use_external_stickers`, :attr:`send_messages_in_threads` and + :attr:`request_to_speak` permissions. + """ + return cls(0b111110110110011111101111111111101010001) + + @classmethod + def general(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "General" permissions from the official Discord UI set to ``True``. + + .. versionchanged:: 1.7 + Permission :attr:`read_messages` is now included in the general permissions, but + permissions :attr:`administrator`, :attr:`create_instant_invite`, :attr:`kick_members`, + :attr:`ban_members`, :attr:`change_nickname` and :attr:`manage_nicknames` are + no longer part of the general permissions. + .. versionchanged:: 2.7 + Added :attr:`view_creator_monetization_analytics` permission. + """ + return cls(0b100000000001110000000010000000010010110000) + + @classmethod + def membership(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "Membership" permissions from the official Discord UI set to ``True``. + + .. versionadded:: 1.7 + """ + return cls(0b00001100000000000000000000000111) + + @classmethod + def text(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "Text" permissions from the official Discord UI set to ``True``. + + .. versionchanged:: 1.7 + Permission :attr:`read_messages` is no longer part of the text permissions. + Added :attr:`use_slash_commands` permission. + + .. versionchanged:: 2.0 + Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`, + :attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions. + """ + return cls(0b111110010000000000001111111100001000000) + + @classmethod + def voice(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "Voice" permissions from the official Discord UI set to ``True``. + """ + return cls(0b1001001001000000000000011111100000000001100000000) + + @classmethod + def stage(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "Stage Channel" permissions from the official Discord UI set to ``True``. + + .. versionadded:: 1.7 + """ + return cls(1 << 32) + + @classmethod + def stage_moderator(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "Stage Moderator" permissions from the official Discord UI set to ``True``. + + .. versionadded:: 1.7 + """ + return cls(0b100000001010000000000000000000000) + + @classmethod + def advanced(cls: type[Self]) -> Self: + """A factory method that creates a :class:`Permissions` with all + "Advanced" permissions from the official Discord UI set to ``True``. + + .. versionadded:: 1.7 + """ + return cls(1 << 3) + + def update(self, **kwargs: bool) -> None: + r"""Bulk updates this permission object. + + Allows you to set multiple attributes by using keyword + arguments. The names must be equivalent to the properties + listed. Extraneous key/value pairs will be silently ignored. + + Parameters + ------------ + \*\*kwargs + A list of key/value pairs to bulk update permissions with. + """ + for key, value in kwargs.items(): + if key in self.VALID_FLAGS: + setattr(self, key, value) + + def handle_overwrite(self, allow: int, deny: int) -> None: + # Basically this is what's happening here. + # We have an original bit array, e.g. 1010 + # Then we have another bit array that is 'denied', e.g. 1111 + # And then we have the last one which is 'allowed', e.g. 0101 + # We want original OP denied to end up resulting in + # whatever is in denied to be set to 0. + # So 1010 OP 1111 -> 0000 + # Then we take this value and look at the allowed values. + # And whatever is allowed is set to 1. + # So 0000 OP2 0101 -> 0101 + # The OP is base & ~denied. + # The OP2 is base | allowed. + self.value = (self.value & ~deny) | allow + + @flag_value + def create_instant_invite(self) -> int: + """:class:`bool`: Returns ``True`` if the user can create instant invites.""" + return 1 << 0 + + @flag_value + def kick_members(self) -> int: + """:class:`bool`: Returns ``True`` if the user can kick users from the guild.""" + return 1 << 1 + + @flag_value + def ban_members(self) -> int: + """:class:`bool`: Returns ``True`` if a user can ban users from the guild.""" + return 1 << 2 + + @flag_value + def administrator(self) -> int: + """:class:`bool`: Returns ``True`` if a user is an administrator. This role overrides all other permissions. + + This also bypasses all channel-specific overrides. + """ + return 1 << 3 + + @flag_value + def manage_channels(self) -> int: + """:class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild. + + This also corresponds to the "Manage Channel" channel-specific override. + """ + return 1 << 4 + + @flag_value + def manage_guild(self) -> int: + """:class:`bool`: Returns ``True`` if a user can edit guild properties.""" + return 1 << 5 + + @flag_value + def add_reactions(self) -> int: + """:class:`bool`: Returns ``True`` if a user can add reactions to messages.""" + return 1 << 6 + + @flag_value + def view_audit_log(self) -> int: + """:class:`bool`: Returns ``True`` if a user can view the guild's audit log.""" + return 1 << 7 + + @flag_value + def priority_speaker(self) -> int: + """:class:`bool`: Returns ``True`` if a user can be more easily heard while talking.""" + return 1 << 8 + + @flag_value + def stream(self) -> int: + """:class:`bool`: Returns ``True`` if a user can stream in a voice channel.""" + return 1 << 9 + + @flag_value + def view_channel(self) -> int: + """:class:`bool`: Returns ``True`` if a user can view all or specific channels.""" + return 1 << 10 + + @make_permission_alias("view_channel") + def read_messages(self) -> int: + """:class:`bool`: An alias for :attr:`view_channel`. + + .. versionadded:: 1.3 + """ + return 1 << 10 + + @flag_value + def send_messages(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send messages from all or specific text channels.""" + return 1 << 11 + + @flag_value + def send_tts_messages(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send TTS messages from all or specific text channels.""" + return 1 << 12 + + @flag_value + def manage_messages(self) -> int: + """:class:`bool`: Returns ``True`` if a user can delete or pin messages in a text channel. + + .. note:: + + Note that there are currently no ways to edit other people's messages. + """ + return 1 << 13 + + @flag_value + def embed_links(self) -> int: + """:class:`bool`: Returns ``True`` if a user's messages will automatically be embedded by Discord.""" + return 1 << 14 + + @flag_value + def attach_files(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send files in their messages.""" + return 1 << 15 + + @flag_value + def read_message_history(self) -> int: + """:class:`bool`: Returns ``True`` if a user can read a text channel's previous messages.""" + return 1 << 16 + + @flag_value + def mention_everyone(self) -> int: + """:class:`bool`: Returns ``True`` if a user's @everyone or @here will mention everyone in the text channel.""" + return 1 << 17 + + @flag_value + def external_emojis(self) -> int: + """:class:`bool`: Returns ``True`` if a user can use emojis from other guilds.""" + return 1 << 18 + + @make_permission_alias("external_emojis") + def use_external_emojis(self) -> int: + """:class:`bool`: An alias for :attr:`external_emojis`. + + .. versionadded:: 1.3 + """ + return 1 << 18 + + @flag_value + def view_guild_insights(self) -> int: + """:class:`bool`: Returns ``True`` if a user can view the guild's insights. + + .. versionadded:: 1.3 + """ + return 1 << 19 + + @flag_value + def connect(self) -> int: + """:class:`bool`: Returns ``True`` if a user can connect to a voice channel.""" + return 1 << 20 + + @flag_value + def speak(self) -> int: + """:class:`bool`: Returns ``True`` if a user can speak in a voice channel.""" + return 1 << 21 + + @flag_value + def mute_members(self) -> int: + """:class:`bool`: Returns ``True`` if a user can mute other users.""" + return 1 << 22 + + @flag_value + def deafen_members(self) -> int: + """:class:`bool`: Returns ``True`` if a user can deafen other users.""" + return 1 << 23 + + @flag_value + def move_members(self) -> int: + """:class:`bool`: Returns ``True`` if a user can move users between other voice channels.""" + return 1 << 24 + + @flag_value + def use_voice_activation(self) -> int: + """:class:`bool`: Returns ``True`` if a user can use voice activation in voice channels.""" + return 1 << 25 + + @flag_value + def change_nickname(self) -> int: + """:class:`bool`: Returns ``True`` if a user can change their nickname in the guild.""" + return 1 << 26 + + @flag_value + def manage_nicknames(self) -> int: + """:class:`bool`: Returns ``True`` if a user can change other user's nickname in the guild.""" + return 1 << 27 + + @flag_value + def manage_roles(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create or edit roles less than their role's position. + + This also corresponds to the "Manage Permissions" channel-specific override. + """ + return 1 << 28 + + @make_permission_alias("manage_roles") + def manage_permissions(self) -> int: + """:class:`bool`: An alias for :attr:`manage_roles`. + + .. versionadded:: 1.3 + """ + return 1 << 28 + + @flag_value + def manage_webhooks(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create, edit, or delete webhooks.""" + return 1 << 29 + + @flag_value + def manage_emojis(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis.""" + return 1 << 30 + + @make_permission_alias("manage_emojis") + def manage_emojis_and_stickers(self) -> int: + """:class:`bool`: An alias for :attr:`manage_emojis`. + + .. versionadded:: 2.0 + """ + return 1 << 30 + + @flag_value + def use_slash_commands(self) -> int: + """:class:`bool`: Returns ``True`` if a user can use slash commands. + + .. versionadded:: 1.7 + """ + return 1 << 31 + + @make_permission_alias("use_slash_commands") + def use_application_commands(self) -> int: + """:class:`bool`: An alias for :attr:`use_slash_commands`. + + .. versionadded:: 2.0 + """ + return 1 << 31 + + @flag_value + def request_to_speak(self) -> int: + """:class:`bool`: Returns ``True`` if a user can request to speak in a stage channel. + + .. versionadded:: 1.7 + """ + return 1 << 32 + + @flag_value + def manage_events(self) -> int: + """:class:`bool`: Returns ``True`` if a user can manage guild events. + + .. versionadded:: 2.0 + """ + return 1 << 33 + + @flag_value + def manage_threads(self) -> int: + """:class:`bool`: Returns ``True`` if a user can manage threads. + + .. versionadded:: 2.0 + """ + return 1 << 34 + + @flag_value + def create_public_threads(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create public threads. + + .. versionadded:: 2.0 + """ + return 1 << 35 + + @flag_value + def create_private_threads(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create private threads. + + .. versionadded:: 2.0 + """ + return 1 << 36 + + @flag_value + def external_stickers(self) -> int: + """:class:`bool`: Returns ``True`` if a user can use stickers from other guilds. + + .. versionadded:: 2.0 + """ + return 1 << 37 + + @make_permission_alias("external_stickers") + def use_external_stickers(self) -> int: + """:class:`bool`: An alias for :attr:`external_stickers`. + + .. versionadded:: 2.0 + """ + return 1 << 37 + + @flag_value + def send_messages_in_threads(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send messages in threads. + + .. versionadded:: 2.0 + """ + return 1 << 38 + + @flag_value + def start_embedded_activities(self) -> int: + """:class:`bool`: Returns ``True`` if a user can launch an activity flagged 'EMBEDDED' in a voice channel. + + .. versionadded:: 2.0 + """ + return 1 << 39 + + @flag_value + def moderate_members(self) -> int: + """:class:`bool`: Returns ``True`` if a user can moderate members (timeout). + + .. versionadded:: 2.0 + """ + return 1 << 40 + + @flag_value + def view_creator_monetization_analytics(self) -> int: + """:class:`bool`: Returns ``True`` if a user can view creator monetization (role subscription) analytics. + + .. versionadded:: 2.7 + """ + return 1 << 41 + + @flag_value + def use_soundboard(self) -> int: + """:class:`bool`: Returns ``True`` if a user can use the soundboard in a voice channel. + + .. versionadded:: 2.7 + """ + return 1 << 42 + + @flag_value + def use_external_sounds(self) -> int: + """:class:`bool`: Returns ``True`` if a user can use external soundboard sounds in a voice channel. + + .. versionadded:: 2.7 + """ + return 1 << 45 + + @flag_value + def send_voice_messages(self) -> int: + """:class:`bool`: Returns ``True`` if a member can send voice messages. + + .. versionadded:: 2.5 + """ + return 1 << 46 + + @flag_value + def set_voice_channel_status(self) -> int: + """:class:`bool`: Returns ``True`` if a member can set voice channel status. + + .. versionadded:: 2.5 + """ + return 1 << 48 + + @flag_value + def send_polls(self) -> int: + """:class:`bool`: Returns ``True`` if a member can send polls. + + .. versionadded:: 2.6 + """ + return 1 << 49 + + @flag_value + def use_external_apps(self) -> int: + """:class:`bool`: Returns ``True`` if a member's user-installed apps can show public responses. + Users will still be able to use user-installed apps, but responses will be ephemeral. + + This only applies to apps that are also not installed to the guild. + + .. versionadded:: 2.6 + """ + return 1 << 50 diff --git a/discord/models/types/flags/role.py b/discord/models/types/flags/role.py new file mode 100644 index 0000000000..70dc0df32d --- /dev/null +++ b/discord/models/types/flags/role.py @@ -0,0 +1,58 @@ +from typing_extensions import final + +from .base import BaseFlags, fill_with_flags, flag_value + + +@final +@fill_with_flags() +class RoleFlags(BaseFlags): + r"""Wraps up the Discord Role flags. + + .. container:: operations + + .. describe:: x == y + + Checks if two RoleFlags are equal. + .. describe:: x != y + + Checks if two RoleFlags are not equal. + .. describe:: x + y + + Adds two flags together. Equivalent to ``x | y``. + .. describe:: x - y + + Subtracts two flags from each other. + .. describe:: x | y + + Returns the union of two flags. Equivalent to ``x + y``. + .. describe:: x & y + + Returns the intersection of two flags. + .. describe:: ~x + + Returns the inverse of a flag. + .. describe:: hash(x) + + Return the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + Note that aliases are not shown. + + .. versionadded:: 2.6 + + Attributes + ----------- + value: :class:`int` + The raw value. This value is a bit array field of a 53-bit integer + representing the currently available flags. You should query + flags via the properties rather than using this raw value. + """ + + __slots__ = tuple() + + @flag_value + def in_prompt(self): + """:class:`bool`: Returns ``True`` if the role is selectable in one of the guild's :class:`~discord.OnboardingPrompt`.""" + return 1 << 0 diff --git a/discord/models/types/flags/system_channel.py b/discord/models/types/flags/system_channel.py new file mode 100644 index 0000000000..60a509b55c --- /dev/null +++ b/discord/models/types/flags/system_channel.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing_extensions import final, override + +from .base import BaseFlags, fill_with_flags, flag_value + + +@final +@fill_with_flags(inverted=True) +class SystemChannelFlags(BaseFlags): + r"""Wraps up a Discord system channel flag value. + + Similar to :class:`Permissions`\, the properties provided are two way. + You can set and retrieve individual bits using the properties as if they + were regular bools. This allows you to edit the system flags easily. + + To construct an object you can pass keyword arguments denoting the flags + to enable or disable. + + .. container:: operations + + .. describe:: x == y + + Checks if two flags are equal. + .. describe:: x != y + + Checks if two flags are not equal. + .. describe:: x + y + + Adds two flags together. Equivalent to ``x | y``. + .. describe:: x - y + + Subtracts two flags from each other. + .. describe:: x | y + + Returns the union of two flags. Equivalent to ``x + y``. + .. describe:: x & y + + Returns the intersection of two flags. + .. describe:: ~x + + Returns the inverse of a flag. + .. describe:: hash(x) + + Return the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + + Attributes + ----------- + value: :class:`int` + The raw value. This value is a bit array field of a 53-bit integer + representing the currently available flags. You should query + flags via the properties rather than using this raw value. + """ + + __slots__ = tuple() + + # For some reason the flags for system channels are "inverted" + # ergo, if they're set then it means "suppress" (off in the GUI toggle) + # Since this is counter-intuitive from an API perspective and annoying + # these will be inverted automatically + + @override + def _has_flag(self, o: int) -> bool: + return (self.value & o) != o + + @override + def _set_flag(self, o: int, toggle: bool) -> None: + if toggle: + self.value &= ~o + else: + self.value |= o + + @flag_value + def join_notifications(self): + """:class:`bool`: Returns ``True`` if the system channel is used for member join notifications.""" + return 1 + + @flag_value + def premium_subscriptions(self): + """:class:`bool`: Returns ``True`` if the system channel is used for "Nitro boosting" notifications.""" + return 2 + + @flag_value + def guild_reminder_notifications(self): + """:class:`bool`: Returns ``True`` if the system channel is used for server setup helpful tips notifications. + + .. versionadded:: 2.0 + """ + return 4 + + @flag_value + def join_notification_replies(self): + """:class:`bool`: Returns ``True`` if the system channel is allowing member join sticker replies. + + .. versionadded:: 2.0 + """ + return 8 diff --git a/discord/models/types/flags/user.py b/discord/models/types/flags/user.py new file mode 100644 index 0000000000..0622f148c7 --- /dev/null +++ b/discord/models/types/flags/user.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from typing import Any, ClassVar + +from typing_extensions import Self, final + +from .base import BaseFlags, flag_value + + +@final +class UserFlags(BaseFlags): + """Wraps up the Discord user flags value. + + The properties provided are two way. You can set and retrieve individual + bits using the properties as if they were regular bools. + + .. container:: operations + + .. describe:: x == y + + Checks if two flags are equal. + .. describe:: x != y + + Checks if two flags are not equal. + .. describe:: x <= y + + Checks if a flag is a subset of another flag. + .. describe:: x >= y + + Checks if a flag is a superset of another flag. + .. describe:: x < y + + Checks if a flag is a strict subset of another flag. + .. describe:: x > y + + Checks if a flag is a strict superset of another flag. + .. describe:: hash(x) + + Return the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(flag, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + + Attributes + ---------- + value: :class:`int` + The raw value. This value is a bit array field of a 23-bit integer + representing the currently available flags. You should query + flags via the properties rather than using this raw value. + """ + + DEFAULT_VALUE: ClassVar[int] = 0 + + def __init__(self, value: int = 0, **kwargs: bool) -> None: + super().__init__(**kwargs) + self.value: int = value + + @classmethod + def none(cls: type[Self]) -> Self: + """A factory method that creates a :class:`UserFlags` with no flags set.""" + return cls(0) + + @classmethod + def all(cls: type[Self]) -> Self: + """A factory method that creates a :class:`UserFlags` with all flags set.""" + return cls(0b11111111111111111111111) + + @flag_value + def staff(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a Discord Employee.""" + return 1 << 0 + + @flag_value + def partner(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a Discord Partner.""" + return 1 << 1 + + @flag_value + def hypesquad(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a HypeSquad Events member.""" + return 1 << 2 + + @flag_value + def bug_hunter_level_1(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a Bug Hunter Level 1.""" + return 1 << 3 + + @flag_value + def hypesquad_house_bravery(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a member of House Bravery.""" + return 1 << 6 + + @flag_value + def hypesquad_house_brilliance(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a member of House Brilliance.""" + return 1 << 7 + + @flag_value + def hypesquad_house_balance(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a member of House Balance.""" + return 1 << 8 + + @flag_value + def premium_early_supporter(self) -> int: + """:class:`bool`: Returns ``True`` if the user is an Early Nitro Supporter.""" + return 1 << 9 + + @flag_value + def team_pseudo_user(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a team.""" + return 1 << 10 + + @flag_value + def bug_hunter_level_2(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a Bug Hunter Level 2.""" + return 1 << 14 + + @flag_value + def verified_bot(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a Verified Bot.""" + return 1 << 16 + + @flag_value + def verified_developer(self) -> int: + """:class:`bool`: Returns ``True`` if the user is an Early Verified Bot Developer.""" + return 1 << 17 + + @flag_value + def certified_moderator(self) -> int: + """:class:`bool`: Returns ``True`` if the user is a Discord Certified Moderator.""" + return 1 << 18 + + @flag_value + def bot_http_interactions(self) -> int: + """:class:`bool`: Returns ``True`` if the bot uses only HTTP interactions and is shown in the online member list.""" + return 1 << 19 + + @flag_value + def active_developer(self) -> int: + """:class:`bool`: Returns ``True`` if the user is an Active Developer.""" + return 1 << 22 + + def is_subset(self, other: Any) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if self has the same or fewer flags as other.""" + if isinstance(other, UserFlags): + return (self.value & other.value) == self.value + raise TypeError( + f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) + + def is_superset(self, other: Any) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if self has the same or more flags as other.""" + if isinstance(other, UserFlags): + return (self.value | other.value) == self.value + raise TypeError( + f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) + + def is_strict_subset( + self, other: Any + ) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if the flags on other are a strict subset of those on self.""" + return self.is_subset(other) and self != other + + def is_strict_superset( + self, other: Any + ) -> bool: # pyright: ignore [reportExplicitAny] + """Returns ``True`` if the flags on other are a strict superset of those on self.""" + return self.is_superset(other) and self != other + + __le__ = is_subset + __ge__ = is_superset + __lt__ = is_strict_subset + __gt__ = is_strict_superset diff --git a/discord/models/types/guild.py b/discord/models/types/guild.py new file mode 100644 index 0000000000..6518fb9938 --- /dev/null +++ b/discord/models/types/guild.py @@ -0,0 +1,5 @@ +from .snowflake import Snowflake + + +class GuildID(Snowflake): + """Represents a Discord Guild ID.""" diff --git a/discord/models/types/locale.py b/discord/models/types/locale.py new file mode 100644 index 0000000000..e25aeef690 --- /dev/null +++ b/discord/models/types/locale.py @@ -0,0 +1,39 @@ +from enum import Enum + +from typing_extensions import final + + +@final +class Locale(str, Enum): + Indonesian = "id" + Danish = "da" + German = "de" + EnglishUK = "en-GB" + EnglishUS = "en-US" + Spanish = "es-ES" + SpanishLATAM = "es-419" + French = "fr" + Croatian = "hr" + Italian = "it" + Lithuanian = "lt" + Hungarian = "hu" + Dutch = "nl" + Norwegian = "no" + Polish = "pl" + PortugueseBrazilian = "pt-BR" + Romanian = "ro" + Finnish = "fi" + Swedish = "sv-SE" + Vietnamese = "vi" + Turkish = "tr" + Czech = "cs" + Greek = "el" + Bulgarian = "bg" + Russian = "ru" + Ukrainian = "uk" + Hindi = "hi" + Thai = "th" + ChineseChina = "zh-CN" + Japanese = "ja" + ChineseTaiwan = "zh-TW" + Korean = "ko" diff --git a/discord/models/types/role.py b/discord/models/types/role.py new file mode 100644 index 0000000000..76aca6220a --- /dev/null +++ b/discord/models/types/role.py @@ -0,0 +1,5 @@ +from .snowflake import Snowflake + + +class RoleID(Snowflake): + """Represents a Discord Role ID.""" diff --git a/discord/models/types/snowflake.py b/discord/models/types/snowflake.py new file mode 100644 index 0000000000..1585557260 --- /dev/null +++ b/discord/models/types/snowflake.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema +from typing_extensions import override + +DISCORD_EPOCH = 1420070400000 # First second of 2015 + + +class Snowflake(int): + """Represents a Discord snowflake ID.""" + + @classmethod + def from_datetime(cls, dt: datetime) -> Snowflake: + """Creates a Snowflake from a datetime object.""" + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + unix_ms = int(dt.timestamp() * 1000) + discord_ms = unix_ms - DISCORD_EPOCH + return cls(discord_ms << 22) + + @property + def timestamp(self) -> datetime: + """Returns the creation time of this snowflake.""" + ms = (self >> 22) + DISCORD_EPOCH + return datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc) + + @property + def worker_id(self) -> int: + """Returns the internal worker ID.""" + return (self & 0x3E0000) >> 17 + + @property + def process_id(self) -> int: + """Returns the internal process ID.""" + return (self & 0x1F000) >> 12 + + @property + def increment(self) -> int: + """Returns the increment count.""" + return self & 0xFFF + + @property + def id(self) -> int: + return int(self) + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, # pyright: ignore [reportExplicitAny] + _handler: GetCoreSchemaHandler, + ) -> CoreSchema: + """Define how Pydantic should handle validation and serialization of Snowflakes""" + + def validate_from_int(value: int) -> Snowflake: + if value < 0: + raise ValueError("Snowflake cannot be negative") + return cls(value) + + from_int_schema = core_schema.chain_schema( + [ + core_schema.int_schema(), + core_schema.no_info_plain_validator_function(validate_from_int), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_int_schema, + python_schema=core_schema.union_schema( + [ + # Check if it's already a Snowflake instance + core_schema.is_instance_schema(cls), + from_int_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + str, return_schema=core_schema.str_schema(), when_used="json" + ), + ) + + @override + def __repr__(self) -> str: + return f"" + + @override + def __str__(self) -> str: + return str(int(self)) + + +__all__ = ["Snowflake"] diff --git a/discord/models/types/sticker.py b/discord/models/types/sticker.py new file mode 100644 index 0000000000..5cfbe7fc8e --- /dev/null +++ b/discord/models/types/sticker.py @@ -0,0 +1,5 @@ +from .snowflake import Snowflake + + +class StickerID(Snowflake): + """Represents a Discord Guild ID.""" diff --git a/discord/models/types/user.py b/discord/models/types/user.py new file mode 100644 index 0000000000..13bd8d5d5f --- /dev/null +++ b/discord/models/types/user.py @@ -0,0 +1,5 @@ +from .snowflake import Snowflake + + +class UserID(Snowflake): + """Represents a Discord Channel ID.""" diff --git a/discord/models/types/utils.py b/discord/models/types/utils.py new file mode 100644 index 0000000000..cc4e3b1ae5 --- /dev/null +++ b/discord/models/types/utils.py @@ -0,0 +1,11 @@ +from typing_extensions import TypeAlias, final + + +@final +class MISSING: + pass + + +MissingSentinel: TypeAlias = type[MISSING] + +__all__ = ["MISSING", "MissingSentinel"] diff --git a/discord/state.py b/discord/state.py index cf74d99285..1d669931ed 100644 --- a/discord/state.py +++ b/discord/state.py @@ -43,7 +43,7 @@ Union, ) -from . import utils +from . import models, utils from .activity import BaseActivity from .audit_logs import AuditLogEntry from .automod import AutoModRule @@ -657,18 +657,18 @@ async def _delay_ready(self) -> None: finally: self._ready_task = None - def parse_ready(self, data) -> None: + def parse_ready(self, data: models.gateway.ReadyData) -> None: if self._ready_task is not None: self._ready_task.cancel() self._ready_state = asyncio.Queue() self.clear(views=False) - self.user = ClientUser(state=self, data=data["user"]) - self.store_user(data["user"]) + self.user = ClientUser(state=self, data=data.user) + self.store_user(data.user) # TODO Rewrite cache to use model_dump if self.application_id is None: try: - application = data["application"] + application = data.application except KeyError: pass else: @@ -676,12 +676,18 @@ def parse_ready(self, data) -> None: # flags will always be present here self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore - for guild_data in data["guilds"]: - self._add_guild_from_data(guild_data) + for guild_data in data.guilds: + self._add_guild_from_data( + guild_data.model_dump() + ) # TODO Rewrite cache to support Pydantic self.dispatch("connect") self._ready_task = asyncio.create_task(self._delay_ready()) + parse_ready._supports_model = ( # pyright: ignore [reportFunctionMemberAccess] + models.gateway.Ready + ) + def parse_resumed(self, data) -> None: self.dispatch("resumed") diff --git a/discord/user.py b/discord/user.py index 9fa995cf66..432fe47182 100644 --- a/discord/user.py +++ b/discord/user.py @@ -25,10 +25,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, TypeVar +from typing_extensions import override + import discord.abc +from . import models from .asset import Asset from .colour import Colour from .flags import PublicUserFlags @@ -58,12 +62,12 @@ class _UserTag: - __slots__ = () + __slots__ = () # pyright: ignore [reportUnannotatedClassAttribute] id: int class BaseUser(_UserTag): - __slots__ = ( + __slots__ = ( # pyright: ignore [reportUnannotatedClassAttribute] "name", "id", "discriminator", @@ -88,16 +92,15 @@ class BaseUser(_UserTag): _state: ConnectionState _avatar: str | None _banner: str | None - _accent_colour: int | None - _avatar_decoration: dict | None - _public_flags: int + _accent_colour: models.types.Color | None + _avatar_decoration: models.AvatarDecorationData | None + _public_flags: models.types.UserFlags | None - def __init__( - self, *, state: ConnectionState, data: UserPayload | PartialUserPayload - ) -> None: + def __init__(self, *, state: ConnectionState, data: models.User) -> None: self._state = state self._update(data) + @override def __repr__(self) -> str: if self.is_migrated: if self.global_name is not None: @@ -117,6 +120,7 @@ def __repr__(self) -> str: f" bot={self.bot} system={self.system}>" ) + @override def __str__(self) -> str: return ( f"{self.name}#{self.discriminator}" @@ -128,24 +132,36 @@ def __str__(self) -> str: ) ) + @override def __eq__(self, other: Any) -> bool: return isinstance(other, _UserTag) and other.id == self.id + @override def __hash__(self) -> int: return self.id >> 22 - def _update(self, data: UserPayload) -> None: - self.name = data["username"] - self.id = int(data["id"]) - self.discriminator = data["discriminator"] - self.global_name = data.get("global_name", None) or None - self._avatar = data["avatar"] - self._banner = data.get("banner", None) - self._accent_colour = data.get("accent_color", None) - self._avatar_decoration = data.get("avatar_decoration_data", None) - self._public_flags = data.get("public_flags", 0) - self.bot = data.get("bot", False) - self.system = data.get("system", False) + def _update(self, data: models.User) -> None: + self.name = data.username + self.id = data.id + self.discriminator = data.discriminator + self.global_name = data.global_name + self._avatar = data.avatar + self._banner = data.banner if data.banner is not models.MISSING else None + self._accent_colour = ( + data.accent_color if data.accent_color is not models.MISSING else None + ) + self._avatar_decoration = ( + data.avatar_decoration_data + if data.avatar_decoration_data is not models.MISSING + else None + ) + self._public_flags = ( + data.public_flags + if data.public_flags is not models.MISSING + else models.types.UserFlags.none() + ) + self.bot = data.bot if data.bot is not models.MISSING else False + self.system = data.system if data.system is not models.MISSING else False @classmethod def _copy(cls: type[BU], user: BU) -> BU: @@ -206,7 +222,9 @@ def default_avatar(self) -> Asset: """ eq = (self.id >> 22) if self.is_migrated else int(self.discriminator) perc = 6 if self.is_migrated else 5 - return Asset._from_default_avatar(self._state, eq % perc) + return Asset._from_default_avatar( + self._state, eq % perc + ) # pyright: ignore [reportPrivateUsage] @property def display_avatar(self) -> Asset: @@ -229,7 +247,9 @@ def banner(self) -> Asset | None: """ if self._banner is None: return None - return Asset._from_user_banner(self._state, self.id, self._banner) + return Asset._from_user_banner( + self._state, self.id, self._banner + ) # pyright: ignore [reportPrivateUsage] @property def avatar_decoration(self) -> Asset | None: @@ -239,8 +259,8 @@ def avatar_decoration(self) -> Asset | None: """ if self._avatar_decoration is None: return None - return Asset._from_avatar_decoration( - self._state, self.id, self._avatar_decoration.get("asset") + return Asset._from_avatar_decoration( # pyright: ignore [reportPrivateUsage] + self._state, self.id, self._avatar_decoration.asset ) @property @@ -387,17 +407,24 @@ class ClientUser(BaseUser): Specifies if the user has MFA turned on and working. """ - __slots__ = ("locale", "_flags", "verified", "mfa_enabled", "__weakref__") + __slots__ = ( + "locale", + "_flags", + "verified", + "mfa_enabled", + "__weakref__", + ) # pyright: ignore [reportUnannotatedClassAttribute] if TYPE_CHECKING: verified: bool - locale: str | None + locale: models.types.Locale | None mfa_enabled: bool - _flags: int + _flags: models.types.UserFlags - def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: + def __init__(self, *, state: ConnectionState, data: models.User) -> None: super().__init__(state=state, data=data) + @override def __repr__(self) -> str: if self.is_migrated: if self.global_name is not None: @@ -417,13 +444,20 @@ def __repr__(self) -> str: f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>" ) - def _update(self, data: UserPayload) -> None: + @override + def _update(self, data: models.User) -> None: super()._update(data) # There's actually an Optional[str] phone field as well, but I won't use it - self.verified = data.get("verified", False) - self.locale = data.get("locale") - self._flags = data.get("flags", 0) - self.mfa_enabled = data.get("mfa_enabled", False) + self.verified = data.verified if data.verified is not models.MISSING else False + self.locale = data.locale if data.locale is not models.MISSING else None + self._flags = ( + data.flags + if data.flags is not models.MISSING + else models.types.UserFlags.none() + ) + self.mfa_enabled = ( + data.mfa_enabled if data.mfa_enabled is not models.MISSING else False + ) # TODO: Username might not be able to edit anymore. async def edit( @@ -538,7 +572,15 @@ class User(BaseUser, discord.abc.Messageable): __slots__ = ("_stored",) - def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: + def __init__(self, *, state: ConnectionState, data: models.User) -> None: + if isinstance(data, dict): + data = models.User(**data) + warnings.warn( + "Passing a dict to User is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(state=state, data=data) self._stored: bool = False diff --git a/requirements/_.txt b/requirements/_.txt index 5305a96bd1..851fe7caff 100644 --- a/requirements/_.txt +++ b/requirements/_.txt @@ -1,2 +1,3 @@ aiohttp>=3.6.0,<4.0 -typing_extensions>=4,<5; python_version < "3.11" +typing_extensions>=4.5.0,<5 +pydantic>=2.10.4,<3.0