Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pydantic #2675

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ These changes are available on the `master` branch, but have not yet been releas
- Replaced audioop (deprecated module) implementation of `PCMVolumeTransformer.read`
method with a pure Python equivalent.
([#2176](https://github.com/Pycord-Development/pycord/pull/2176))
- Updated `Guild.filesize_limit` to 10 Mb instead of 25 Mb following Discord's API
changes. ([#2671](https://github.com/Pycord-Development/pycord/pull/2671))

### Deprecated

Expand Down
75 changes: 52 additions & 23 deletions discord/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import TYPE_CHECKING, Any, Literal

import yarl
from typing_extensions import Final, override

from . import utils
from .errors import DiscordException, InvalidArgument
Expand All @@ -39,6 +40,7 @@
if TYPE_CHECKING:
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
from .state import ConnectionState

VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
Expand All @@ -49,7 +51,7 @@

class AssetMixin:
url: str
_state: Any | None
_state: ConnectionState | None

async def read(self) -> bytes:
"""|coro|
Expand Down Expand Up @@ -77,7 +79,9 @@ async def read(self) -> bytes:

async def save(
self,
fp: str | bytes | os.PathLike | io.BufferedIOBase,
fp: (
str | bytes | os.PathLike | io.BufferedIOBase
), # pyright: ignore [reportMissingTypeArgument]
*,
seek_begin: bool = True,
) -> int:
Expand Down Expand Up @@ -117,7 +121,7 @@ async def save(
fp.seek(0)
return written
else:
with open(fp, "wb") as f:
with open(fp, "wb") as f: # pyright: ignore [reportUnknownArgumentType]
return f.write(data)


Expand Down Expand Up @@ -154,16 +158,23 @@ class Asset(AssetMixin):
"_key",
)

BASE = "https://cdn.discordapp.com"
BASE: Final = "https://cdn.discordapp.com"

def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
self._url = url
self._animated = animated
self._key = key
def __init__(
self,
state: ConnectionState | None,
*,
url: str,
key: str,
animated: bool = False,
):
self._state: ConnectionState | None = state
self._url: str = url
self._animated: bool = animated
self._key: str = key

@classmethod
def _from_default_avatar(cls, state, index: int) -> Asset:
def _from_default_avatar(cls, state: ConnectionState, index: int) -> Asset:
return cls(
state,
url=f"{cls.BASE}/embed/avatars/{index}.png",
Expand All @@ -172,7 +183,7 @@ def _from_default_avatar(cls, state, index: int) -> Asset:
)

@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
def _from_avatar(cls, state: ConnectionState, user_id: int, avatar: str) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -184,7 +195,10 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:

@classmethod
def _from_avatar_decoration(
cls, state, user_id: int, avatar_decoration: str
cls,
state: ConnectionState,
user_id: int,
avatar_decoration: str, # pyright: ignore [reportUnusedParameter]
) -> Asset:
animated = avatar_decoration.startswith("a_")
endpoint = (
Expand All @@ -201,7 +215,7 @@ def _from_avatar_decoration(

@classmethod
def _from_guild_avatar(
cls, state, guild_id: int, member_id: int, avatar: str
cls, state: ConnectionState, guild_id: int, member_id: int, avatar: str
) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
Expand All @@ -214,7 +228,7 @@ def _from_guild_avatar(

@classmethod
def _from_guild_banner(
cls, state, guild_id: int, member_id: int, banner: str
cls, state: ConnectionState, guild_id: int, member_id: int, banner: str
) -> Asset:
animated = banner.startswith("a_")
format = "gif" if animated else "png"
Expand All @@ -226,7 +240,9 @@ def _from_guild_banner(
)

@classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
def _from_icon(
cls, state: ConnectionState, object_id: int, icon_hash: str, path: str
) -> Asset:
return cls(
state,
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
Expand All @@ -235,7 +251,9 @@ def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
)

@classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
def _from_cover_image(
cls, state: ConnectionState, object_id: int, cover_image_hash: str
) -> Asset:
return cls(
state,
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
Expand All @@ -244,7 +262,9 @@ def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asse
)

@classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
def _from_guild_image(
cls, state: ConnectionState, guild_id: int, image: str, path: str
) -> Asset:
animated = False
format = "png"
if path == "banners":
Expand All @@ -259,7 +279,9 @@ def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset
)

@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
def _from_guild_icon(
cls, state: ConnectionState, guild_id: int, icon_hash: str
) -> Asset:
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -270,7 +292,7 @@ def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
)

@classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset:
def _from_sticker_banner(cls, state: ConnectionState, banner: int) -> Asset:
return cls(
state,
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
Expand All @@ -279,7 +301,9 @@ def _from_sticker_banner(cls, state, banner: int) -> Asset:
)

@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
def _from_user_banner(
cls, state: ConnectionState, user_id: int, banner_hash: str
) -> Asset:
animated = banner_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -291,7 +315,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:

@classmethod
def _from_scheduled_event_image(
cls, state, event_id: int, cover_hash: str
cls, state: ConnectionState, event_id: int, cover_hash: str
) -> Asset:
return cls(
state,
Expand All @@ -300,24 +324,29 @@ def _from_scheduled_event_image(
animated=False,
)

@override
def __str__(self) -> str:
return self._url

def __len__(self) -> int:
return len(self._url)

@override
def __repr__(self):
shorten = self._url.replace(self.BASE, "")
return f"<Asset url={shorten!r}>"

def __eq__(self, other):
@override
def __eq__(self, other: Any): # pyright: ignore [reportExplicitAny]
return isinstance(other, Asset) and self._url == other._url

@override
def __hash__(self):
return hash(self._url)

@property
def url(self) -> str:
@override
def url(self) -> str: # pyright: ignore [reportIncompatibleVariableOverride]
"""Returns the underlying URL of the asset."""
return self._url

Expand Down
4 changes: 2 additions & 2 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import aiohttp

from . import utils
from . import models, utils
from .activity import ActivityTypes, BaseActivity, create_activity
from .appinfo import AppInfo, PartialAppInfo
from .application_role_connection import ApplicationRoleConnectionMetadata
Expand Down Expand Up @@ -1840,7 +1840,7 @@ async def fetch_user(self, user_id: int, /) -> User:
:exc:`HTTPException`
Fetching the user failed.
"""
data = await self.http.get_user(user_id)
data: models.User = await self.http.get_user(user_id)
return User(state=self._connection, data=data)

async def fetch_channel(
Expand Down
17 changes: 15 additions & 2 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@
import traceback
import zlib
from collections import deque, namedtuple
from typing import Any

import aiohttp
from pydantic import BaseModel

from discord import models

from . import utils
from .activity import BaseActivity
Expand Down Expand Up @@ -548,11 +552,20 @@ async def received_message(self, msg, /):
)

try:
func = self._discord_parsers[event]
func: Any = self._discord_parsers[event]
except KeyError:
_log.debug("Unknown event %s.", event)
else:
func(data)
if hasattr(func, "_supports_model") and issubclass(
func._supports_model, models.gateway.GatewayEvent
):
func(
func._supports_model(
**msg
).d # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
)
else:
func(data)

# remove the dispatched listeners
removed = []
Expand Down
16 changes: 8 additions & 8 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
overload,
)

from . import abc, utils
from . import abc, models, utils
from .asset import Asset
from .automod import AutoModAction, AutoModRule, AutoModTriggerMetadata
from .channel import *
Expand Down Expand Up @@ -289,11 +289,11 @@ class Guild(Hashable):
)

_PREMIUM_GUILD_LIMITS: ClassVar[dict[int | None, _GuildLimit]] = {
None: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400),
0: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=26214400),
1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=26214400),
2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52428800),
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104857600),
None: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=10_485_760),
0: _GuildLimit(emoji=50, stickers=5, bitrate=96e3, filesize=10_485_760),
1: _GuildLimit(emoji=100, stickers=15, bitrate=128e3, filesize=10_485_760),
2: _GuildLimit(emoji=150, stickers=30, bitrate=256e3, filesize=52_428_800),
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104_857_600),
}

def __init__(self, *, data: GuildPayload, state: ConnectionState):
Expand Down Expand Up @@ -2114,9 +2114,9 @@ async def fetch_ban(self, user: Snowflake) -> BanEntry:
HTTPException
An error occurred while fetching the information.
"""
data: BanPayload = await self._state.http.get_ban(user.id, self.id)
data: models.Ban = await self._state.http.get_ban(user.id, self.id)
return BanEntry(
user=User(state=self._state, data=data["user"]), reason=data["reason"]
user=User(state=self._state, data=data.user), reason=data.reason
)

async def fetch_channel(self, channel_id: int, /) -> GuildChannel | Thread:
Expand Down
Loading
Loading