diff --git a/docs/source/atproto/atproto_client.client.rst b/docs/source/atproto/atproto_client.client.rst index 3981b05b..99154ba8 100644 --- a/docs/source/atproto/atproto_client.client.rst +++ b/docs/source/atproto/atproto_client.client.rst @@ -22,7 +22,6 @@ Submodules atproto_client.client.async_client atproto_client.client.async_raw - atproto_client.client.auth atproto_client.client.base atproto_client.client.client atproto_client.client.raw diff --git a/docs/source/atproto/atproto_crypto.rst b/docs/source/atproto/atproto_crypto.rst new file mode 100644 index 00000000..28ae4580 --- /dev/null +++ b/docs/source/atproto/atproto_crypto.rst @@ -0,0 +1,15 @@ +atproto\_crypto +=============== + +.. automodule:: atproto_crypto + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + atproto_crypto.verify diff --git a/docs/source/atproto/atproto_crypto.verify.rst b/docs/source/atproto/atproto_crypto.verify.rst new file mode 100644 index 00000000..13e17038 --- /dev/null +++ b/docs/source/atproto/atproto_crypto.verify.rst @@ -0,0 +1,7 @@ +atproto\_crypto.verify +====================== + +.. automodule:: atproto_crypto.verify + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/atproto/atproto_server.auth.jwt.rst b/docs/source/atproto/atproto_server.auth.jwt.rst new file mode 100644 index 00000000..40d75b60 --- /dev/null +++ b/docs/source/atproto/atproto_server.auth.jwt.rst @@ -0,0 +1,7 @@ +atproto\_server.auth.jwt +======================== + +.. automodule:: atproto_server.auth.jwt + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/atproto/atproto_server.auth.rst b/docs/source/atproto/atproto_server.auth.rst new file mode 100644 index 00000000..50073bb7 --- /dev/null +++ b/docs/source/atproto/atproto_server.auth.rst @@ -0,0 +1,16 @@ +atproto\_server.auth +==================== + +.. automodule:: atproto_server.auth + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + atproto_server.auth.jwt + atproto_server.auth.utils diff --git a/docs/source/atproto/atproto_server.auth.utils.rst b/docs/source/atproto/atproto_server.auth.utils.rst new file mode 100644 index 00000000..75701181 --- /dev/null +++ b/docs/source/atproto/atproto_server.auth.utils.rst @@ -0,0 +1,7 @@ +atproto\_server.auth.utils +========================== + +.. automodule:: atproto_server.auth.utils + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/atproto/atproto_server.exceptions.rst b/docs/source/atproto/atproto_server.exceptions.rst new file mode 100644 index 00000000..cf092e5c --- /dev/null +++ b/docs/source/atproto/atproto_server.exceptions.rst @@ -0,0 +1,7 @@ +atproto\_server.exceptions +========================== + +.. automodule:: atproto_server.exceptions + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/atproto/atproto_server.rst b/docs/source/atproto/atproto_server.rst index 1962bf6f..f842a72f 100644 --- a/docs/source/atproto/atproto_server.rst +++ b/docs/source/atproto/atproto_server.rst @@ -5,3 +5,19 @@ atproto\_server :members: :undoc-members: :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + atproto_server.auth + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + atproto_server.exceptions diff --git a/docs/source/atproto/modules.rst b/docs/source/atproto/modules.rst index 1cb5a2e1..68d9a8fe 100644 --- a/docs/source/atproto/modules.rst +++ b/docs/source/atproto/modules.rst @@ -9,6 +9,7 @@ packages atproto_client atproto_codegen atproto_core + atproto_crypto atproto_firehose atproto_identity atproto_lexicon diff --git a/docs/source/atproto_server/auth.rst b/docs/source/atproto_server/auth.rst new file mode 100644 index 00000000..bebf4f28 --- /dev/null +++ b/docs/source/atproto_server/auth.rst @@ -0,0 +1,7 @@ +Auth +==== + +.. automodule:: atproto_server.auth.jwt + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/atproto_server/index.rst b/docs/source/atproto_server/index.rst new file mode 100644 index 00000000..f9056f25 --- /dev/null +++ b/docs/source/atproto_server/index.rst @@ -0,0 +1,12 @@ +Server +====== + +.. automodule:: atproto_server + :members: + :undoc-members: + :inherited-members: + +.. toctree:: + :maxdepth: 4 + + auth diff --git a/docs/source/exceptions.rst b/docs/source/exceptions.rst index c18f03df..fbe4a5f2 100644 --- a/docs/source/exceptions.rst +++ b/docs/source/exceptions.rst @@ -37,6 +37,14 @@ Identity :show-inheritance: +Server +###### + +.. automodule:: atproto_server.exceptions + :members: + :undoc-members: + :show-inheritance: + Lexicon ####### diff --git a/docs/source/index.rst b/docs/source/index.rst index 10c25797..8a8de8b5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ Documentation namespace models atproto_identity/identity + atproto_server/index .. toctree:: :caption: Core diff --git a/packages/atproto/__init__.py b/packages/atproto/__init__.py index a23926d2..2c452c73 100644 --- a/packages/atproto/__init__.py +++ b/packages/atproto/__init__.py @@ -16,18 +16,30 @@ from atproto_firehose import models as firehose_models from atproto_identity.cache.in_memory_cache import AsyncDidInMemoryCache, DidInMemoryCache from atproto_identity.resolver import AsyncIdResolver, IdResolver +from atproto_server.auth.jwt import ( + JwtPayload, + decode_jwt_payload, + get_jwt_payload, + parse_jwt, + validate_jwt_payload, + verify_jwt, + verify_jwt_async, +) __all__ = [ + # client 'AsyncClient', 'Client', 'client_utils', 'models', + # core 'CAR', 'CID', 'CIDType', 'DidDocument', 'NSID', 'AtUri', + # firehose 'AsyncFirehoseSubscribeLabelsClient', 'AsyncFirehoseSubscribeReposClient', 'FirehoseSubscribeLabelsClient', @@ -35,8 +47,17 @@ 'parse_subscribe_labels_message', 'parse_subscribe_repos_message', 'firehose_models', + # identity 'AsyncDidInMemoryCache', 'DidInMemoryCache', 'AsyncIdResolver', 'IdResolver', + # server jwt + 'JwtPayload', + 'decode_jwt_payload', + 'get_jwt_payload', + 'parse_jwt', + 'validate_jwt_payload', + 'verify_jwt', + 'verify_jwt_async', ] diff --git a/packages/atproto/exceptions.py b/packages/atproto/exceptions.py index d5deb7af..c6f1e3f7 100644 --- a/packages/atproto/exceptions.py +++ b/packages/atproto/exceptions.py @@ -3,3 +3,4 @@ from atproto_firehose.exceptions import * from atproto_identity.exceptions import * from atproto_lexicon.exceptions import * +from atproto_server.exceptions import * diff --git a/packages/atproto_client/client/auth.py b/packages/atproto_client/client/auth.py deleted file mode 100644 index 97660dfd..00000000 --- a/packages/atproto_client/client/auth.py +++ /dev/null @@ -1,19 +0,0 @@ -import typing as t -from dataclasses import dataclass - -import jwt as _jwt - - -@dataclass -class JwtPayload: - exp: int # expired at - iat: int # created at - scope: str - sub: str # DID - aud: t.Optional[str] = None # DID - jti: t.Optional[str] = None # in refresh token only - - -def get_jwt_payload(token: str) -> JwtPayload: - plain_payload = _jwt.decode(token, options={'verify_signature': False}) - return JwtPayload(**plain_payload) diff --git a/packages/atproto_client/client/methods_mixin/session.py b/packages/atproto_client/client/methods_mixin/session.py index 6073ebc8..43def184 100644 --- a/packages/atproto_client/client/methods_mixin/session.py +++ b/packages/atproto_client/client/methods_mixin/session.py @@ -3,12 +3,12 @@ from datetime import timedelta import typing_extensions as te - -from atproto_client.client.auth import get_jwt_payload +from atproto_server.auth.jwt import get_jwt_payload if t.TYPE_CHECKING: + from atproto_server.auth.jwt import JwtPayload + from atproto_client import models - from atproto_client.client.auth import JwtPayload @dataclass diff --git a/packages/atproto_core/did_doc/did_doc.py b/packages/atproto_core/did_doc/did_doc.py index b6191309..bcc60c17 100644 --- a/packages/atproto_core/did_doc/did_doc.py +++ b/packages/atproto_core/did_doc/did_doc.py @@ -80,6 +80,7 @@ def get_service_endpoint(did_doc: 'DidDocument', id_: str, type_: str) -> t.Opti """Returns the service endpoint of the given DID document. Args: + did_doc: The DID document. id_: The service ID. type_: The service type. diff --git a/packages/atproto_crypto/__init__.py b/packages/atproto_crypto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/atproto_crypto/verify.py b/packages/atproto_crypto/verify.py new file mode 100644 index 00000000..028b28e3 --- /dev/null +++ b/packages/atproto_crypto/verify.py @@ -0,0 +1,13 @@ +import typing as t +import warnings + + +def verify_signature(did_key: str, signing_input: t.Union[str, bytes], signature: t.Union[str, bytes]) -> bool: + # TODO(MarshalX): implement + warnings.warn( + 'verify_signature is not implemented yet. Do not trust to this signing_input', + RuntimeWarning, + stacklevel=0, + ) + + return True diff --git a/packages/atproto_server/__init__.py b/packages/atproto_server/__init__.py index 35ed5aac..e69de29b 100644 --- a/packages/atproto_server/__init__.py +++ b/packages/atproto_server/__init__.py @@ -1 +0,0 @@ -# TODO(MarshalX): When ATProto will stable diff --git a/packages/atproto_server/auth/__init__.py b/packages/atproto_server/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/atproto_server/auth/jwt.py b/packages/atproto_server/auth/jwt.py new file mode 100644 index 00000000..87178071 --- /dev/null +++ b/packages/atproto_server/auth/jwt.py @@ -0,0 +1,262 @@ +import binascii +import json +import typing as t +from dataclasses import dataclass +from datetime import datetime, timezone + +from atproto_crypto.verify import verify_signature + +from atproto_server.auth.utils import base64url_decode +from atproto_server.exceptions import ( + TokenDecodeError, + TokenExpiredSignatureError, + TokenImmatureSignatureError, + TokenInvalidAudienceError, + TokenInvalidIssuedAtError, + TokenInvalidSignatureError, +) + +GetSigningKeyCallback = t.Callable[[str, bool], str] +GetSigningKeyCallbackAsync = t.Callable[[str, bool], t.Coroutine[t.Any, t.Any, str]] + + +@dataclass +class JwtPayload: + """The payload of the JWT.""" + + # TODO(MarshalX): separate service and user tokens to different classes? + exp: int # expired at + iat: t.Optional[int] = None # created at + scope: t.Optional[str] = None + sub: t.Optional[str] = None # DID + iss: t.Optional[str] = None # DID. In service token + aud: t.Optional[str] = None # DID + jti: t.Optional[str] = None # in refresh token only + + +def parse_jwt(jwt: t.Union[str, bytes]) -> t.Tuple[bytes, bytes, t.Dict[str, t.Any], bytes]: + """Parse the given JWT. + + Args: + jwt: The JWT to parse. + + Returns: + :obj:`tuple` of :obj:`bytes`, :obj:`bytes`, :obj:`dict`, :obj:`bytes`: + The parsed JWT: payload, signing input, header, signature. + """ + if isinstance(jwt, str): + jwt = jwt.encode('UTF-8') + + if not isinstance(jwt, bytes): + raise TokenDecodeError(f'Invalid token type. Token must be a {bytes}') + + try: + signing_input, crypto_segment = jwt.rsplit(b'.', 1) + header_segment, payload_segment = signing_input.split(b'.', 1) + except ValueError as e: + raise TokenDecodeError('Not enough segments') from e + + try: + header_data = base64url_decode(header_segment) + except (TypeError, binascii.Error) as e: + raise TokenDecodeError('Invalid header padding') from e + + try: + header = json.loads(header_data) + except ValueError as e: + raise TokenDecodeError(f'Invalid header string: {e}') from e + + if not isinstance(header, dict): + raise TokenDecodeError('Invalid header string: must be a json object') + + try: + payload = base64url_decode(payload_segment) + except (TypeError, binascii.Error) as e: + raise TokenDecodeError('Invalid payload padding') from e + + try: + signature = base64url_decode(crypto_segment) + except (TypeError, binascii.Error) as e: + raise TokenDecodeError('Invalid crypto padding') from e + + return payload, signing_input, header, signature + + +def decode_jwt_payload(payload: t.Union[str, bytes]) -> JwtPayload: + """Decode the given JWT payload. + + Args: + payload: The JWT payload to decode. + + Returns: + :obj:`JwtPayload`: The decoded payload of the given JWT. + """ + try: + plain_payload = json.loads(payload) + except ValueError as e: + raise TokenDecodeError(f'Invalid payload string: {e}') from e + if not isinstance(plain_payload, dict): + raise TokenDecodeError('Invalid payload string: must be a json object') + + return JwtPayload(**plain_payload) + + +def get_jwt_payload(jwt: str) -> JwtPayload: + """Return the payload of the given JWT. + + Args: + jwt: The JWT to get the payload from. + + Returns: + :obj:`JwtPayload`: The payload of the given JWT. + """ + payload, *_ = parse_jwt(jwt) + return decode_jwt_payload(payload) + + +def _validate_exp( + exp: int, + now: float, + leeway: float, +) -> None: + try: + exp = int(exp) + except ValueError as e: + raise TokenDecodeError('Expiration Time claim (exp) must be an integer.') from e + + if exp <= (now - leeway): + raise TokenExpiredSignatureError('Signature has expired') + + +def _validate_iat( + iat: int, + now: float, + leeway: float, +) -> None: + try: + iat = int(iat) + except ValueError as e: + raise TokenInvalidIssuedAtError('Issued At claim (iat) must be an integer.') from e + if iat > (now + leeway): + raise TokenImmatureSignatureError('The token is not yet valid (iat)') + + +def validate_jwt_payload(payload: JwtPayload, leeway: int = 0) -> None: + """Validate the given JWT payload. + + Args: + payload: The JWT payload to validate. + leeway: The leeway in seconds to accept when verifying time claims (exp, iat). + + Returns: + :obj:`None`: The payload is valid. + + Raises: + TokenDecodeError: If the given JWT is invalid. + TokenExpiredSignatureError: If the given JWT is expired. + TokenImmatureSignatureError: If the given JWT is immature. + TokenInvalidIssuedAtError: If the given JWT has invalid issued at. + """ + now = datetime.now(tz=timezone.utc).timestamp() + + if payload.exp is not None: + _validate_exp(payload.exp, now, leeway) + if payload.iat is not None: + _validate_iat(payload.iat, now, leeway) + + +def _verify_signature(signing_key: str, signing_input: bytes, signature: bytes) -> bool: + try: + return verify_signature(signing_key, signing_input, signature) + except Exception as e: # noqa: BLE001 + raise TokenInvalidSignatureError('Could not verify JWT signature') from e + + +def verify_jwt( + jwt: str, get_signing_key_callback: GetSigningKeyCallback, own_did: t.Optional[str] = None +) -> JwtPayload: + """Verify the given JWT. + + Args: + jwt: The JWT to verify. + get_signing_key_callback: The callback to get the signing key. + own_did: The DID of the service (aud). + + Returns: + :obj:`JwtPayload`: The payload of the given JWT. + + Raises: + TokenDecodeError: If the given JWT is invalid. + TokenExpiredSignatureError: If the given JWT is expired. + TokenImmatureSignatureError: If the given JWT is immature. + TokenInvalidAudienceError: If the given JWT has invalid audience. + TokenInvalidIssuedAtError: If the given JWT has invalid issued at. + TokenInvalidSignatureError: If the given JWT has invalid signature. + """ + plain_payload, signing_input, header, signature = parse_jwt(jwt) + + payload = decode_jwt_payload(plain_payload) + validate_jwt_payload(payload) + + if own_did and payload.aud != own_did: + raise TokenInvalidAudienceError('Invalid subject') + + signing_key = get_signing_key_callback(payload.iss, False) + if _verify_signature(signing_key, signing_input, signature): + return payload + + fresh_signing_key = get_signing_key_callback(payload.iss, True) # get signing key without a cache + if fresh_signing_key == signing_key: + raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one') + + if _verify_signature(fresh_signing_key, signing_input, signature): + return payload + + # this code should be unreachable + # verifying methods must raise exception before + raise TokenInvalidSignatureError('Invalid signature') + + +async def verify_jwt_async( + jwt: str, get_signing_key_callback: GetSigningKeyCallbackAsync, own_did: t.Optional[str] = None +) -> JwtPayload: + """Asynchronously verifies the given JWT. + + Args: + jwt: The JWT to verify. + get_signing_key_callback: The callback to get the signing key. + own_did: The DID of the service (aud). + + Returns: + :obj:`JwtPayload`: The payload of the given JWT. + + Raises: + TokenDecodeError: If the given JWT is invalid. + TokenExpiredSignatureError: If the given JWT is expired. + TokenImmatureSignatureError: If the given JWT is immature. + TokenInvalidAudienceError: If the given JWT has invalid audience. + TokenInvalidIssuedAtError: If the given JWT has invalid issued at. + TokenInvalidSignatureError: If the given JWT has invalid signature. + """ + plain_payload, signing_input, header, signature = parse_jwt(jwt) + + payload = decode_jwt_payload(plain_payload) + validate_jwt_payload(payload) + + if own_did and payload.aud != own_did: + raise TokenInvalidAudienceError('Invalid subject') + + signing_key = await get_signing_key_callback(payload.iss, False) + if _verify_signature(signing_key, signing_input, signature): + return payload + + fresh_signing_key = await get_signing_key_callback(payload.iss, True) # get signing key without a cache + if fresh_signing_key == signing_key: + raise TokenInvalidSignatureError('Could not verify JWT signature. Fresh signing key is equal to the old one') + + if _verify_signature(fresh_signing_key, signing_input, signature): + return payload + + # this code should be unreachable + # verifying methods must raise exception before + raise TokenInvalidSignatureError('Invalid signature') diff --git a/packages/atproto_server/auth/utils.py b/packages/atproto_server/auth/utils.py new file mode 100644 index 00000000..50e28f81 --- /dev/null +++ b/packages/atproto_server/auth/utils.py @@ -0,0 +1,26 @@ +import base64 +import typing as t + + +def force_bytes(value: t.Union[bytes, str]) -> bytes: + if isinstance(value, str): + return value.encode('UTF-8') + if isinstance(value, bytes): + return value + + raise TypeError('Expected a string value') + + +def base64url_decode(input_data: t.Union[bytes, str]) -> bytes: + input_bytes = force_bytes(input_data) + + rem = len(input_bytes) % 4 + + if rem > 0: + input_bytes += b'=' * (4 - rem) + + return base64.urlsafe_b64decode(input_bytes) + + +def base64url_encode(input_data: bytes) -> bytes: + return base64.urlsafe_b64encode(input_data).replace(b'=', b'') diff --git a/packages/atproto_server/exceptions.py b/packages/atproto_server/exceptions.py new file mode 100644 index 00000000..a1aa01af --- /dev/null +++ b/packages/atproto_server/exceptions.py @@ -0,0 +1,29 @@ +from atproto_core.exceptions import AtProtocolError + + +class InvalidTokenError(AtProtocolError): + pass + + +class TokenDecodeError(InvalidTokenError): + pass + + +class TokenInvalidSignatureError(TokenDecodeError): + pass + + +class TokenInvalidAudienceError(InvalidTokenError): + pass + + +class TokenExpiredSignatureError(InvalidTokenError): + pass + + +class TokenInvalidIssuedAtError(InvalidTokenError): + pass + + +class TokenImmatureSignatureError(InvalidTokenError): + pass diff --git a/poetry.lock b/poetry.lock index f90edd46..84cf39a1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1009,26 +1009,6 @@ files = [ plugins = ["importlib-metadata"] windows-terminal = ["colorama (>=0.4.6)"] -[[package]] -name = "pyjwt" -version = "2.8.0" -description = "JSON Web Token implementation in Python" -optional = false -python-versions = ">=3.7" -files = [ - {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, - {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, -] - -[package.dependencies] -typing-extensions = {version = "*", markers = "python_version <= \"3.7\""} - -[package.extras] -crypto = ["cryptography (>=3.4.0)"] -dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] -docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] -tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] - [[package]] name = "pytest" version = "7.3.2" @@ -1603,4 +1583,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.7.1,<3.13" -content-hash = "3f3d04d2fdff122758173acf448b5d3a721efe6c23f06b3bf529482969767e33" +content-hash = "7c80756ebeafec89f395a6c2e93c5cc15eeb793d974bbebac533708c6bab4a3b" diff --git a/pyproject.toml b/pyproject.toml index 6e76463c..2c7c4f73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ packages = [ { include = "atproto_client", from = "packages" }, { include = "atproto_codegen", from = "packages" }, { include = "atproto_core", from = "packages" }, + { include = "atproto_crypto", from = "packages" }, { include = "atproto_firehose", from = "packages" }, { include = "atproto_identity", from = "packages" }, { include = "atproto_lexicon", from = "packages" }, @@ -51,7 +52,6 @@ atproto = "atproto_cli:atproto_cli" python = ">=3.7.1,<3.13" httpx = ">=0.24.0,<0.26.0" typing-extensions = ">=4.6.1,<5" -pyjwt = ">=2.7.0,<3" click = ">=8.1.3,<9" websockets = ">=11.0.3,<13" pydantic = ">=2.0,<3.0" diff --git a/tests/test_atproto_server/__init__.py b/tests/test_atproto_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_atproto_server/auth/__init__.py b/tests/test_atproto_server/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_atproto_server/auth/test_jwt.py b/tests/test_atproto_server/auth/test_jwt.py new file mode 100644 index 00000000..26124c8a --- /dev/null +++ b/tests/test_atproto_server/auth/test_jwt.py @@ -0,0 +1,78 @@ +import pytest +from atproto_server.auth.jwt import get_jwt_payload, parse_jwt, validate_jwt_payload, verify_jwt, verify_jwt_async +from atproto_server.exceptions import TokenDecodeError, TokenExpiredSignatureError, TokenInvalidAudienceError + +_TEST_JWT_EXPIRED = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6cGxjOmt2d3ZjbjVpcWZvb29wbXl6dmI0cXpiYSIsImF1ZCI6ImRpZDp3ZWI6ZmVlZC5hdHByb3RvLmJsdWUiLCJleHAiOjE3MDQ4NDExMzh9.50SlT6vw26HsDXVDM4D2D53_Dvzd6bjp3TDc5EyDVD4ob9i3EEB7fmaKE0XR4egMS9Kf9eMdVqH5gJNCaIah4Q' # noqa: E501 +# exp in 2033 +_TEST_JWT_INVALID_SIGN = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6cGxjOmt2d3ZjbjVpcWZvb29wbXl6dmI0cXpiYSIsImF1ZCI6ImRpZDp3ZWI6ZmVlZC5hdHByb3RvLmJsdWUiLCJleHAiOjIwMDAwMDAwMDB9.50SlT6vw26HsDXVDM4D2D53_Dvzd6bjp3TDc5EyDVD4ob9i3EEB7fmaKE0XR4egMS9Kf9eMdVqH5gJNCaIah4Q' # noqa: E501 + + +def test_parse_jwt_empty() -> None: + with pytest.raises(TokenDecodeError): + parse_jwt('') + + +def test_parse_jwt() -> None: + payload, signing_input, header, signature = parse_jwt(_TEST_JWT_EXPIRED) + + assert header['typ'] == 'JWT' + assert header['alg'] == 'ES256K' + + assert payload + assert signing_input + assert len(signature) == 64 + + +def test_get_jwt_payload() -> None: + payload = get_jwt_payload(_TEST_JWT_EXPIRED) + assert payload.iss == 'did:plc:kvwvcn5iqfooopmyzvb4qzba' + assert payload.aud == 'did:web:feed.atproto.blue' + assert payload.exp == 1704841138 + + +def test_validate_jwt_payload_expired() -> None: + payload = get_jwt_payload(_TEST_JWT_EXPIRED) + with pytest.raises(TokenExpiredSignatureError): + validate_jwt_payload(payload) + + +def test_validate_jwt_payload_valid() -> None: + payload = get_jwt_payload(_TEST_JWT_INVALID_SIGN) + validate_jwt_payload(payload) + + +def test_verify_jwt() -> None: + expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba' + expected_aud = 'did:web:feed.atproto.blue' + + def get_signing_key(iss: str, force_refresh: bool) -> str: + assert iss == expected_iss + + if force_refresh: + return 'refreshedKey' + return 'key' + + verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key) + verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud) + + with pytest.raises(TokenInvalidAudienceError): + verify_jwt(_TEST_JWT_INVALID_SIGN, get_signing_key, 'blabla') + + +@pytest.mark.asyncio +async def test_verify_jwt_async() -> None: + expected_iss = 'did:plc:kvwvcn5iqfooopmyzvb4qzba' + expected_aud = 'did:web:feed.atproto.blue' + + async def get_signing_key(iss: str, force_refresh: bool) -> str: + assert iss == expected_iss + + if force_refresh: + return 'refreshedKey' + return 'key' + + await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key) + await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key, expected_aud) + + with pytest.raises(TokenInvalidAudienceError): + await verify_jwt_async(_TEST_JWT_INVALID_SIGN, get_signing_key, 'blabla')