From db72d789e5f9ed65dcb13b167111fb633d9472ff Mon Sep 17 00:00:00 2001 From: "Ilya (Marshal)" Date: Fri, 26 Jul 2024 12:46:00 +0200 Subject: [PATCH 1/3] Use jiter instead of built-in json module to improve performance --- packages/atproto_client/models/utils.py | 10 ++--- packages/atproto_client/request.py | 4 +- packages/atproto_lexicon/parser.py | 5 ++- packages/atproto_server/auth/jwt.py | 8 ++-- .../models/fetch_test_data.py | 4 +- .../models/test_data/feed_record.json | 2 +- .../models/test_data/get_follows.json | 41 ++++++++++++------- .../thread_view_post_with_embed_media.json | 10 +++-- .../test_atproto_client/models/tests/utils.py | 5 ++- tests/test_atproto_crypto/test_verify.py | 4 +- 10 files changed, 56 insertions(+), 37 deletions(-) diff --git a/packages/atproto_client/models/utils.py b/packages/atproto_client/models/utils.py index 1c51a8e5..ce5cd903 100644 --- a/packages/atproto_client/models/utils.py +++ b/packages/atproto_client/models/utils.py @@ -1,9 +1,9 @@ -import json import types import typing as t import typing_extensions as te from pydantic import ValidationError +from pydantic_core import from_json, to_json from atproto_client import models from atproto_client.exceptions import ( @@ -126,19 +126,19 @@ def get_model_as_dict(model: t.Union[DotDict, BlobRef, ModelBase]) -> t.Dict[str def get_model_as_json(model: t.Union[DotDict, BlobRef, ModelBase]) -> str: if isinstance(model, DotDict): - return json.dumps(get_model_as_dict(model)) + return to_json(get_model_as_dict(model)).decode('UTF-8') return model.model_dump_json(exclude_none=True, by_alias=True) def is_json(json_data: t.Union[str, bytes]) -> bool: if isinstance(json_data, bytes): - json_data.decode('UTF-8') + json_data.decode('UTF-8', errors='ignore') try: - json.loads(json_data) + from_json(json_data) return True - except: # noqa + except ValueError: return False diff --git a/packages/atproto_client/request.py b/packages/atproto_client/request.py index 2e2e602d..40345125 100644 --- a/packages/atproto_client/request.py +++ b/packages/atproto_client/request.py @@ -1,9 +1,9 @@ -import json import typing as t from dataclasses import dataclass import httpx import typing_extensions as te +from pydantic_core import from_json from atproto_client import exceptions from atproto_client.models.common import XrpcError @@ -66,7 +66,7 @@ def _handle_response(response: httpx.Response) -> httpx.Response: headers=_convert_headers_to_dict(response.headers), ) if response.content and is_json(response.content): - data: t.Dict[str, t.Any] = json.loads(response.content) + data: t.Dict[str, t.Any] = from_json(response.content) error_response.content = t.cast(XrpcError, get_or_create(data, XrpcError, strict=False)) if response.status_code in {401, 403}: diff --git a/packages/atproto_lexicon/parser.py b/packages/atproto_lexicon/parser.py index 3393194d..495bad87 100644 --- a/packages/atproto_lexicon/parser.py +++ b/packages/atproto_lexicon/parser.py @@ -1,8 +1,9 @@ -import json import os import typing as t from pathlib import Path +from pydantic_core import from_json + from atproto_lexicon import models from atproto_lexicon.exceptions import LexiconParsingError @@ -19,7 +20,7 @@ def lexicon_parse(data: dict, model_class: t.Optional[t.Type[L]] = models.Lexico def lexicon_parse_file(lexicon_path: t.Union[Path, str], *, soft_fail: bool = False) -> t.Optional[models.LexiconDoc]: try: with open(lexicon_path, encoding='UTF-8') as f: - plain_lexicon = json.loads(f.read()) + plain_lexicon = from_json(f.read()) return lexicon_parse(plain_lexicon) except Exception as e: # noqa: BLE001 if soft_fail: diff --git a/packages/atproto_server/auth/jwt.py b/packages/atproto_server/auth/jwt.py index d1da59b0..de556eba 100644 --- a/packages/atproto_server/auth/jwt.py +++ b/packages/atproto_server/auth/jwt.py @@ -1,10 +1,10 @@ import binascii -import json import typing as t from datetime import datetime, timezone from atproto_crypto.verify import verify_signature from pydantic import BaseModel, ConfigDict +from pydantic_core import from_json from atproto_server.auth.utils import base64url_decode from atproto_server.exceptions import ( @@ -64,14 +64,14 @@ def parse_jwt(jwt: t.Union[str, bytes]) -> t.Tuple[bytes, bytes, t.Dict[str, t.A raise TokenDecodeError('Invalid header padding') from e try: - header = json.loads(header_data) + header = from_json(header_data) except ValueError as e: raise TokenDecodeError(f'Invalid header string: {e}') from e if not isinstance(header, dict): raise TokenDecodeError('Invalid header string: must be a json object') - header = t.cast(t.Dict[str, t.Any], json.loads(header_data)) # we expect object in header + header = t.cast(t.Dict[str, t.Any], from_json(header_data)) # we expect an object in header try: payload = base64url_decode(payload_segment) @@ -96,7 +96,7 @@ def decode_jwt_payload(payload: t.Union[str, bytes]) -> JwtPayload: :obj:`JwtPayload`: The decoded payload of the given JWT. """ try: - plain_payload = json.loads(payload) + plain_payload = from_json(payload) except ValueError as e: raise TokenDecodeError(f'Invalid payload string: {e}') from e if not isinstance(plain_payload, dict): diff --git a/tests/test_atproto_client/models/fetch_test_data.py b/tests/test_atproto_client/models/fetch_test_data.py index eddfc872..027a6e70 100644 --- a/tests/test_atproto_client/models/fetch_test_data.py +++ b/tests/test_atproto_client/models/fetch_test_data.py @@ -1,10 +1,10 @@ -import json import logging import os import typing as t from atproto_client import Client from atproto_client.request import Request, Response +from pydantic_core import to_json if t.TYPE_CHECKING: from atproto_client.models.common import XrpcError @@ -73,7 +73,7 @@ def get_unique_filename(name: str) -> str: def get_pretty_json(data: dict) -> str: - return json.dumps(data, indent=4) + return to_json(data, indent=4).decode('UTF-8') def save_response(name: str) -> None: diff --git a/tests/test_atproto_client/models/test_data/feed_record.json b/tests/test_atproto_client/models/test_data/feed_record.json index b65f2b1b..a4772c9d 100644 --- a/tests/test_atproto_client/models/test_data/feed_record.json +++ b/tests/test_atproto_client/models/test_data/feed_record.json @@ -2,7 +2,6 @@ "uri": "at://did:plc:s6jnht6koorxz7trghirytmf/app.bsky.feed.generator/atproto", "cid": "bafyreifiexhek65jxj3ucz6y6zstj45rmtigybzygjkg6lretyqgtge5ai", "value": { - "did": "did:web:feed.atproto.blue", "$type": "app.bsky.feed.generator", "avatar": { "$type": "blob", @@ -14,6 +13,7 @@ }, "createdAt": "2023-07-20T10:17:40.298101", "description": "Posts related to the protocol. Powered by The AT Protocol SDK for Python", + "did": "did:web:feed.atproto.blue", "displayName": "AT Protocol" } } \ No newline at end of file diff --git a/tests/test_atproto_client/models/test_data/get_follows.json b/tests/test_atproto_client/models/test_data/get_follows.json index c49d2631..df045303 100644 --- a/tests/test_atproto_client/models/test_data/get_follows.json +++ b/tests/test_atproto_client/models/test_data/get_follows.json @@ -7,7 +7,7 @@ "avatar": "https://cdn.bsky.app/img/avatar/plain/did:plc:vpkhqolt662uhesyj6nxm7ys/bafkreietenc2aywhzxdyvalgffxl35tvxri6wv2iwryfdcdoecdfg553fy@jpeg", "associated": { "chat": { - "allowIncoming": "following" + "allowIncoming": "all" } }, "viewer": { @@ -16,17 +16,18 @@ "following": "at://did:plc:kvwvcn5iqfooopmyzvb4qzba/app.bsky.graph.follow/3jueqtsliw22n" }, "labels": [], - "description": "Technical advisor to @bluesky, first engineer at Protocol Labs. Wizard Utopian", - "indexedAt": "2024-05-14T15:42:22.889Z" + "createdAt": "2022-11-17T01:04:43.624Z", + "description": "Technical advisor to @bluesky, first engineer at Protocol Labs. Wizard Utopian.", + "indexedAt": "2024-06-07T03:19:53.642Z" }, { "did": "did:plc:l3rouwludahu3ui3bt66mfvj", "handle": "divy.zone", - "displayName": "devin ivy \ud83d\udc0b", + "displayName": "devin ivy 🐋", "avatar": "https://cdn.bsky.app/img/avatar/plain/did:plc:l3rouwludahu3ui3bt66mfvj/bafkreicg6y3mlr3eszmbjm3swyuncwf4ruzohcsvtcbrjzyhkwibtx7nyy@jpeg", "associated": { "chat": { - "allowIncoming": "following" + "allowIncoming": "all" } }, "viewer": { @@ -36,6 +37,7 @@ "cid": "bafyreigk3kmjipz5emav5vqmvdtivwz757tfpg5lnsbfoscnqx7wpjjime", "name": "test mute list", "purpose": "app.bsky.graph.defs#modlist", + "listItemCount": 1, "indexedAt": "2023-08-28T10:08:27.442Z", "labels": [], "viewer": { @@ -46,13 +48,14 @@ "following": "at://did:plc:kvwvcn5iqfooopmyzvb4qzba/app.bsky.graph.follow/3jueqt6dbqs2g" }, "labels": [], - "description": "\ud83c\udf00 bluesky team", - "indexedAt": "2024-03-08T04:03:32.618Z" + "createdAt": "2022-11-17T00:39:19.084Z", + "description": "🌀 bluesky team", + "indexedAt": "2024-06-14T20:22:02.642Z" }, { "did": "did:plc:oky5czdrnfjpqslsw2a5iclo", "handle": "jay.bsky.team", - "displayName": "Jay \ud83e\udd8b", + "displayName": "Jay 🦋", "avatar": "https://cdn.bsky.app/img/avatar/plain/did:plc:oky5czdrnfjpqslsw2a5iclo/bafkreihidru2xruxdxlvvcixc7lbgoudzicjbrdgacdhdhxyfw4yut4nfq@jpeg", "associated": { "chat": { @@ -65,7 +68,8 @@ "following": "at://did:plc:kvwvcn5iqfooopmyzvb4qzba/app.bsky.graph.follow/3judl7ak7gp2f" }, "labels": [], - "description": "CEO of Bluesky, steward of AT Protocol. \n\nLet\u2019s build a federated republic, starting with this server. \ud83c\udf31 \ud83e\udeb4 \ud83c\udf33 ", + "createdAt": "2022-11-17T06:31:40.296Z", + "description": "CEO of Bluesky, steward of AT Protocol. \n\nLet’s build a federated republic, starting with this server. 🌱 🪴 🌳 ", "indexedAt": "2024-02-06T22:21:45.352Z" }, { @@ -84,7 +88,8 @@ "following": "at://did:plc:kvwvcn5iqfooopmyzvb4qzba/app.bsky.graph.follow/3judkza2vrb2y" }, "labels": [], - "description": "Official Bluesky account (check domain\ud83d\udc46)\n\nFollow for updates and announcements", + "createdAt": "2023-04-12T04:53:57.057Z", + "description": "Official Bluesky account (check domain👆)\n\nFollow for updates and announcements", "indexedAt": "2024-01-25T23:46:28.776Z" }, { @@ -92,13 +97,19 @@ "handle": "vercel.com", "displayName": "Vercel", "avatar": "https://cdn.bsky.app/img/avatar/plain/did:plc:m2jwplpernhxkzbo4ev5ljwj/bafkreicebob2yf5lv6yg72luzv5qwsr6ob65j6oc3jciyowqkfiz736oqu@jpeg", + "associated": { + "chat": { + "allowIncoming": "following" + } + }, "viewer": { "muted": false, "blockedBy": false, "following": "at://did:plc:kvwvcn5iqfooopmyzvb4qzba/app.bsky.graph.follow/3judkwijszc25" }, "labels": [], - "description": "Vercel\u2019s frontend cloud gives developers the frameworks, workflows, and infrastructure to build a faster, more personalized web. Creators of @nextjs.org.", + "createdAt": "2023-04-25T00:08:45.850Z", + "description": "Vercel’s frontend cloud gives developers the frameworks, workflows, and infrastructure to build a faster, more personalized web. Creators of @nextjs.org.", "indexedAt": "2024-02-16T21:57:28.740Z" }, { @@ -118,7 +129,8 @@ "followedBy": "at://did:plc:s6jnht6koorxz7trghirytmf/app.bsky.graph.follow/3jucc25a7qs2k" }, "labels": [], - "description": "Software Engineer\n\n\ud83d\udc0d The AT Protocol SDK for Python: https://atproto.blue/\n\ud83c\udf7f Custom Feed in Python: https://github.com/MarshalX/bluesky-feed-generator\n\ud83c\udfce\ufe0f Fast DAG-CBOR decoder for Python: https://github.com/MarshalX/python-libipld\n\nhttps://marshal.dev", + "createdAt": "2023-04-12T11:14:00.501Z", + "description": "Software Engineer\n\n🐍 The AT Protocol SDK for Python: https://atproto.blue/\n🍿 Custom Feed in Python: https://github.com/MarshalX/bluesky-feed-generator\n🏎️ Fast DAG-CBOR decoder for Python: https://github.com/MarshalX/python-libipld\n\nhttps://marshal.dev", "indexedAt": "2024-01-26T00:15:07.447Z" } ], @@ -137,7 +149,8 @@ "blockedBy": false }, "labels": [], - "description": "account for tests", - "indexedAt": "2024-05-22T21:04:02.588Z" + "createdAt": "2023-04-26T19:05:34.249Z", + "description": "account for tests\n\nAuthor: @marshal.dev\nGitHub: https://github.com/MarshalX/atproto\nWebsite: https://atproto.blue", + "indexedAt": "2024-05-22T21:19:48.088Z" } } \ No newline at end of file diff --git a/tests/test_atproto_client/models/test_data/thread_view_post_with_embed_media.json b/tests/test_atproto_client/models/test_data/thread_view_post_with_embed_media.json index 84e14c63..16c4690c 100644 --- a/tests/test_atproto_client/models/test_data/thread_view_post_with_embed_media.json +++ b/tests/test_atproto_client/models/test_data/thread_view_post_with_embed_media.json @@ -18,7 +18,8 @@ "muted": false, "blockedBy": false }, - "labels": [] + "labels": [], + "createdAt": "2023-04-26T19:05:34.249Z" }, "record": { "$type": "app.bsky.feed.post", @@ -85,7 +86,8 @@ "muted": false, "blockedBy": false }, - "labels": [] + "labels": [], + "createdAt": "2023-04-26T19:05:34.249Z" }, "value": { "$type": "app.bsky.feed.post", @@ -108,7 +110,9 @@ "repostCount": 0, "likeCount": 0, "indexedAt": "2023-09-28T12:49:36.735Z", - "viewer": {}, + "viewer": { + "threadMuted": false + }, "labels": [] }, "replies": [] diff --git a/tests/test_atproto_client/models/tests/utils.py b/tests/test_atproto_client/models/tests/utils.py index 574d2d45..e0342f37 100644 --- a/tests/test_atproto_client/models/tests/utils.py +++ b/tests/test_atproto_client/models/tests/utils.py @@ -1,9 +1,10 @@ -import json import os +from pydantic_core import from_json + TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'test_data') def load_data_from_file(test_name: str) -> dict: with open(os.path.join(TEST_DATA_PATH, f'{test_name}.json'), 'r') as f: - return json.load(f) + return from_json(f.read()) diff --git a/tests/test_atproto_crypto/test_verify.py b/tests/test_atproto_crypto/test_verify.py index 73349622..ca2ca5ed 100644 --- a/tests/test_atproto_crypto/test_verify.py +++ b/tests/test_atproto_crypto/test_verify.py @@ -1,9 +1,9 @@ import base64 -import json import os import pytest from atproto_crypto.verify import verify_signature +from pydantic_core import from_json # Ref: https://github.com/bluesky-social/atproto/blob/main/interop-test-files/crypto/signature-fixtures.json _FIXTURES_FILE_PATH = os.path.join(os.path.dirname(__file__), 'signature-fixtures.json') @@ -11,7 +11,7 @@ def _load_test_cases() -> list: with open(_FIXTURES_FILE_PATH, encoding='UTF-8') as file: - return json.load(file) + return from_json(file.read()) def _fix_base64_padding(data: str) -> str: From ff322798c014c3a499ed0370bc6cbb713ae892b5 Mon Sep 17 00:00:00 2001 From: "Ilya (Marshal)" Date: Fri, 26 Jul 2024 12:53:47 +0200 Subject: [PATCH 2/3] love you windows --- docs/fix_title_of_models.py | 2 +- docs/gen_aliases_db.py | 2 +- examples/advanced_usage/session_reuse.py | 4 ++-- packages/atproto_codegen/clients/generate_async_client.py | 2 +- tests/test_atproto_client/models/tests/utils.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/fix_title_of_models.py b/docs/fix_title_of_models.py index af218741..53d8b02b 100644 --- a/docs/fix_title_of_models.py +++ b/docs/fix_title_of_models.py @@ -7,7 +7,7 @@ def main(path: Path) -> None: for file in files: if file.startswith('atproto_client.models.'): file_path = os.path.join(root, file) - with open(file_path, 'r', encoding='UTF-8') as f: + with open(file_path, encoding='UTF-8') as f: content = f.read() content = content.replace('atproto\\_client.models.', '') with open(file_path, 'w', encoding='UTF-8') as f: diff --git a/docs/gen_aliases_db.py b/docs/gen_aliases_db.py index 38f8c5c1..8173eb22 100644 --- a/docs/gen_aliases_db.py +++ b/docs/gen_aliases_db.py @@ -5,7 +5,7 @@ def main(init_path: Path, output_path: Path) -> None: aliases_db = ['ALIASES_DB = {'] - with open(init_path, 'r', encoding='UTF-8') as f: + with open(init_path, encoding='UTF-8') as f: tree = ast.parse(f.read()) for node in ast.walk(tree): if isinstance(node, ast.ImportFrom) and node.names: diff --git a/examples/advanced_usage/session_reuse.py b/examples/advanced_usage/session_reuse.py index 1307ca09..7cda8671 100644 --- a/examples/advanced_usage/session_reuse.py +++ b/examples/advanced_usage/session_reuse.py @@ -5,14 +5,14 @@ def get_session() -> Optional[str]: try: - with open('session.txt') as f: + with open('session.txt', encoding='UTF-8') as f: return f.read() except FileNotFoundError: return None def save_session(session_string: str) -> None: - with open('session.txt', 'w') as f: + with open('session.txt', 'w', encoding='UTF-8') as f: f.write(session_string) diff --git a/packages/atproto_codegen/clients/generate_async_client.py b/packages/atproto_codegen/clients/generate_async_client.py index 5a9fcb9b..e4533462 100644 --- a/packages/atproto_codegen/clients/generate_async_client.py +++ b/packages/atproto_codegen/clients/generate_async_client.py @@ -8,7 +8,7 @@ def gen_client(input_filename: str, output_filename: str) -> None: - with open(_CLIENT_DIR.joinpath(input_filename), 'r', encoding='UTF-8') as f: + with open(_CLIENT_DIR.joinpath(input_filename), encoding='UTF-8') as f: code = f.read() # TODO(MarshalX): Get automatically diff --git a/tests/test_atproto_client/models/tests/utils.py b/tests/test_atproto_client/models/tests/utils.py index e0342f37..01899369 100644 --- a/tests/test_atproto_client/models/tests/utils.py +++ b/tests/test_atproto_client/models/tests/utils.py @@ -6,5 +6,5 @@ def load_data_from_file(test_name: str) -> dict: - with open(os.path.join(TEST_DATA_PATH, f'{test_name}.json'), 'r') as f: + with open(os.path.join(TEST_DATA_PATH, f'{test_name}.json'), encoding='UTF-8') as f: return from_json(f.read()) From 131d9516e3944021bb4c4b9ee6346af97e3b486e Mon Sep 17 00:00:00 2001 From: "Ilya (Marshal)" Date: Fri, 26 Jul 2024 13:28:31 +0200 Subject: [PATCH 3/3] use jiter for httpx responses; remove twice deserialization on XrpcError --- packages/atproto_client/models/utils.py | 15 ++++---- packages/atproto_client/request.py | 10 +++--- .../did/resolvers/plc_resolver.py | 5 +-- .../did/resolvers/web_resolver.py | 5 +-- .../models/tests/test_utils.py | 36 +++++++++++++++++++ update_lexicons.py | 3 +- 6 files changed, 58 insertions(+), 16 deletions(-) create mode 100644 tests/test_atproto_client/models/tests/test_utils.py diff --git a/packages/atproto_client/models/utils.py b/packages/atproto_client/models/utils.py index ce5cd903..ae16a44e 100644 --- a/packages/atproto_client/models/utils.py +++ b/packages/atproto_client/models/utils.py @@ -132,14 +132,17 @@ def get_model_as_json(model: t.Union[DotDict, BlobRef, ModelBase]) -> str: def is_json(json_data: t.Union[str, bytes]) -> bool: - if isinstance(json_data, bytes): - json_data.decode('UTF-8', errors='ignore') + return load_json(json_data, strict=False) is not None + +def load_json(json_data: t.Union[str, bytes], strict: bool = True) -> t.Optional[t.Dict[str, t.Any]]: try: - from_json(json_data) - return True - except ValueError: - return False + return from_json(json_data) + except ValueError as e: + if strict: + raise e + + return None def is_record_type(model: t.Union[ModelBase, DotDict], expected_type: t.Union[str, types.ModuleType]) -> bool: diff --git a/packages/atproto_client/request.py b/packages/atproto_client/request.py index 40345125..31c72951 100644 --- a/packages/atproto_client/request.py +++ b/packages/atproto_client/request.py @@ -7,7 +7,7 @@ from atproto_client import exceptions from atproto_client.models.common import XrpcError -from atproto_client.models.utils import get_or_create, is_json +from atproto_client.models.utils import get_or_create, load_json @dataclass @@ -35,7 +35,7 @@ def _convert_headers_to_dict(headers: httpx.Headers) -> t.Dict[str, str]: def _parse_response(response: httpx.Response) -> Response: content = response.content if response.headers.get('content-type') == 'application/json; charset=utf-8': - content = response.json() + content = from_json(response.content) return Response( success=True, @@ -65,9 +65,9 @@ def _handle_response(response: httpx.Response) -> httpx.Response: content=response.content, headers=_convert_headers_to_dict(response.headers), ) - if response.content and is_json(response.content): - data: t.Dict[str, t.Any] = from_json(response.content) - error_response.content = t.cast(XrpcError, get_or_create(data, XrpcError, strict=False)) + error_content = load_json(response.content, strict=False) + if error_content: + error_response.content = t.cast(XrpcError, get_or_create(error_content, XrpcError, strict=False)) if response.status_code in {401, 403}: raise exceptions.UnauthorizedError(error_response) diff --git a/packages/atproto_identity/did/resolvers/plc_resolver.py b/packages/atproto_identity/did/resolvers/plc_resolver.py index 20f08e1b..90b50f33 100644 --- a/packages/atproto_identity/did/resolvers/plc_resolver.py +++ b/packages/atproto_identity/did/resolvers/plc_resolver.py @@ -1,6 +1,7 @@ import typing as t import httpx +from pydantic_core import from_json from atproto_identity.did.resolvers.base_resolver import AsyncBaseResolver, BaseResolver from atproto_identity.exceptions import DidPlcResolverError @@ -31,7 +32,7 @@ def resolve_without_validation(self, did: str) -> t.Optional[t.Dict[str, t.Any]] return None response.raise_for_status() - return response.json() + return from_json(response.content) except httpx.HTTPError as e: raise DidPlcResolverError(f'Error resolving DID {did}') from e @@ -58,6 +59,6 @@ async def resolve_without_validation(self, did: str) -> t.Optional[t.Dict[str, t return None response.raise_for_status() - return response.json() + return from_json(response.content) except httpx.HTTPError as e: raise DidPlcResolverError(f'Error resolving DID {did}') from e diff --git a/packages/atproto_identity/did/resolvers/web_resolver.py b/packages/atproto_identity/did/resolvers/web_resolver.py index d302d470..79cc4dcd 100644 --- a/packages/atproto_identity/did/resolvers/web_resolver.py +++ b/packages/atproto_identity/did/resolvers/web_resolver.py @@ -1,6 +1,7 @@ import typing as t import httpx +from pydantic_core import from_json from atproto_identity.did.resolvers.base_resolver import AsyncBaseResolver, BaseResolver from atproto_identity.exceptions import DidWebResolverError, PoorlyFormattedDidError, UnsupportedDidWebPathError @@ -45,7 +46,7 @@ def resolve_without_validation(self, did: str) -> t.Dict[str, t.Any]: try: response = self._client.get(url, timeout=self._timeout) response.raise_for_status() - return response.json() + return from_json(response.content) except httpx.HTTPError as e: raise DidWebResolverError(f'Error resolving DID {did}') from e @@ -68,6 +69,6 @@ async def resolve_without_validation(self, did: str) -> t.Dict[str, t.Any]: try: response = await self._client.get(url, timeout=self._timeout) response.raise_for_status() - return response.json() + return from_json(response.content) except httpx.HTTPError as e: raise DidWebResolverError(f'Error resolving DID {did}') from e diff --git a/tests/test_atproto_client/models/tests/test_utils.py b/tests/test_atproto_client/models/tests/test_utils.py new file mode 100644 index 00000000..f08427bc --- /dev/null +++ b/tests/test_atproto_client/models/tests/test_utils.py @@ -0,0 +1,36 @@ +import pytest +from atproto_client.models.utils import is_json, load_json + + +def test_load_json() -> None: + assert load_json('{"key": "value"}') + assert load_json(b'{"key": "value"}') + + assert load_json('{"key": "value"', strict=False) is None + with pytest.raises(ValueError): + load_json(b'{"key": "value"') + + assert load_json('{"key": "value"', strict=False) is None + with pytest.raises(ValueError): + load_json(b'{"key": "value"') + + assert load_json('{"key": "value"}'.encode('UTF-16'), strict=False) is None + + with pytest.raises(TypeError): + load_json(None) + + +def test_is_json() -> None: + assert is_json('{"key": "value"}') is True + assert is_json(b'{"key": "value"}') is True + + assert is_json('{"key": "value"') is False + assert is_json(b'{"key": "value"') is False + + assert is_json('{"key": "value"}'.encode('UTF-16')) is False + + assert is_json(b'') is False + assert is_json(b'{}') is True + + with pytest.raises(TypeError): + load_json(None) diff --git a/update_lexicons.py b/update_lexicons.py index d6af2bef..8691082c 100755 --- a/update_lexicons.py +++ b/update_lexicons.py @@ -8,6 +8,7 @@ from pathlib import Path import httpx +from pydantic_core import from_json _GITHUB_BASE_URL = 'https://github.com' _GITHUB_API_BASE_URL = 'https://api.github.com' @@ -44,7 +45,7 @@ def _get_last_commit_info() -> t.Tuple[str, str, str]: ) response.raise_for_status() - response_json = response.json() + response_json = from_json(response.content) commit_info = response_json[0] sha = commit_info['sha']