diff --git a/packages/atproto_client/models/string_formats.py b/packages/atproto_client/models/string_formats.py index b46a989d..6bb2799e 100644 --- a/packages/atproto_client/models/string_formats.py +++ b/packages/atproto_client/models/string_formats.py @@ -1,6 +1,7 @@ import re from datetime import datetime -from typing import Callable, Mapping, Set, Union +from inspect import signature +from typing import Mapping, Set, TypeAlias, Union, cast from urllib.parse import urlparse from atproto_core.nsid import validate_nsid as atproto_core_validate_nsid @@ -50,20 +51,29 @@ r')?$' ) +WithOrWithoutInfoValidator: TypeAlias = Union[ + core_schema.WithInfoValidatorFunction, core_schema.NoInfoValidatorFunction +] + -def only_validate_if_strict(validate_fn: core_schema.WithInfoValidatorFunction) -> Callable: - """Skip validation if not opting into strict validation.""" +def only_validate_if_strict(validate_fn: WithOrWithoutInfoValidator) -> WithOrWithoutInfoValidator: + """Skip pydantic validation if not opting into strict validation via context.""" + params = list(signature(validate_fn).parameters.values()) + validator_wants_info = len(params) > 1 and params[1].annotation is ValidationInfo def wrapper(v: str, info: ValidationInfo) -> str: + """Could likely be generalized to support arbitrary signatures.""" if info and isinstance(info.context, Mapping) and info.context.get(_OPT_IN_KEY, False): - return validate_fn(v, info) + if validator_wants_info: + return cast(core_schema.WithInfoValidatorFunction, validate_fn)(v, info) + return cast(core_schema.NoInfoValidatorFunction, validate_fn)(v) return v return wrapper @only_validate_if_strict -def validate_handle(v: str, info: ValidationInfo) -> str: +def validate_handle(v: str) -> str: # Check ASCII first if not v.isascii(): raise ValueError('Invalid handle: must contain only ASCII characters') @@ -78,7 +88,7 @@ def validate_handle(v: str, info: ValidationInfo) -> str: @only_validate_if_strict -def validate_did(v: str, info: ValidationInfo) -> str: +def validate_did(v: str) -> str: # Check for invalid characters if any(c in v for c in '/?#[]@'): raise ValueError('Invalid DID: cannot contain /, ?, #, [, ], or @ characters') @@ -98,7 +108,7 @@ def validate_did(v: str, info: ValidationInfo) -> str: @only_validate_if_strict -def validate_nsid(v: str, info: ValidationInfo) -> str: +def validate_nsid(v: str) -> str: if ( not atproto_core_validate_nsid(v, soft_fail=True) or len(v) > MAX_NSID_LENGTH @@ -113,14 +123,14 @@ def validate_nsid(v: str, info: ValidationInfo) -> str: @only_validate_if_strict -def validate_language(v: str, info: ValidationInfo) -> str: +def validate_language(v: str) -> str: if not LANG_RE.match(v): raise ValueError('Invalid language code: must be ISO language code (e.g. en or en-US)') return v @only_validate_if_strict -def validate_record_key(v: str, info: ValidationInfo) -> str: +def validate_record_key(v: str) -> str: if v in INVALID_RECORD_KEYS or not RKEY_RE.match(v): raise ValueError( 'Invalid record key: must contain only alphanumeric, dot, underscore, colon, tilde, or hyphen characters' @@ -129,14 +139,14 @@ def validate_record_key(v: str, info: ValidationInfo) -> str: @only_validate_if_strict -def validate_cid(v: str, info: ValidationInfo) -> str: +def validate_cid(v: str) -> str: if not CID_RE.match(v): raise ValueError('Invalid CID: must be a valid Content Identifier with minimum length 8') return v @only_validate_if_strict -def validate_at_uri(v: str, info: ValidationInfo) -> str: +def validate_at_uri(v: str) -> str: if len(v) >= MAX_AT_URI_LENGTH: raise ValueError(f'Invalid AT-URI: must be under {MAX_AT_URI_LENGTH} chars') @@ -147,7 +157,7 @@ def validate_at_uri(v: str, info: ValidationInfo) -> str: @only_validate_if_strict -def validate_datetime(v: str, info: ValidationInfo) -> str: +def validate_datetime(v: str) -> str: # Must contain uppercase T and Z if used if v != v.strip(): raise ValueError('Invalid datetime: no whitespace allowed') @@ -180,14 +190,14 @@ def validate_datetime(v: str, info: ValidationInfo) -> str: @only_validate_if_strict -def validate_tid(v: str, info: ValidationInfo) -> str: +def validate_tid(v: str) -> str: if not TID_RE.match(v) or (ord(v[0]) & 0x40): raise ValueError(f'Invalid TID: must be exactly {TID_LENGTH} lowercase letters/numbers') return v @only_validate_if_strict -def validate_uri(v: str, info: ValidationInfo) -> str: +def validate_uri(v: str) -> str: if len(v) >= MAX_URI_LENGTH or ' ' in v: raise ValueError(f'Invalid URI: must be under {MAX_URI_LENGTH} chars and not contain spaces') parsed = urlparse(v) @@ -212,7 +222,7 @@ def validate_uri(v: str, info: ValidationInfo) -> str: Uri = Annotated[str, BeforeValidator(validate_uri)] # Any valid ATProto string format -ATProtoString = Annotated[ +AtProtoString = Annotated[ Union[Handle, Did, Nsid, AtUri, Cid, DateTime, Tid, RecordKey, Uri, Language], Field(description='ATProto string format'), ] diff --git a/packages/atproto_client/models/utils.py b/packages/atproto_client/models/utils.py index 30e2cfdc..94ae0d43 100644 --- a/packages/atproto_client/models/utils.py +++ b/packages/atproto_client/models/utils.py @@ -2,7 +2,7 @@ import typing as t import typing_extensions as te -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from pydantic_core import from_json, to_json from atproto_client import models @@ -108,7 +108,11 @@ def _get_or_create( return model_data try: - return model.model_validate(model_data, context={'strict_string_format': strict_string_format}) + if issubclass(model, BaseModel): + return model.model_validate(model_data, context={'strict_string_format': strict_string_format}) + if not isinstance(model_data, t.Mapping): + raise ModelError(f'Cannot parse model of type {model}') + return model(**model_data) except ValidationError as e: raise ModelError(str(e)) from e diff --git a/tests/test_atproto_client/models/tests/test_string_formats.py b/tests/test_atproto_client/models/tests/test_string_formats.py index cb7ce7c4..87203ebe 100644 --- a/tests/test_atproto_client/models/tests/test_string_formats.py +++ b/tests/test_atproto_client/models/tests/test_string_formats.py @@ -1,3 +1,4 @@ +from functools import lru_cache from pathlib import Path from typing import List @@ -13,32 +14,6 @@ INTEROP_TEST_FILES_DIR: Path = Path('tests/test_atproto_client/interop-test-files/syntax') -# TODO: 230 passed, 11 xfailed -# These cases appear in both _valid.txt and _invalid.txt files. -# Need investigation to determine if our validation is incorrect or if test data needs updating: -SKIP_THESE_VALUES = [ - ( - string_formats.AtUri, - 'at://did:plc:asdf123', - ), # Listed as both valid and invalid in AT-URI files under "enforces spec basics" - ( - string_formats.AtUri, - 'at://did:plc:asdf123/com.atproto.feed.post', - ), # Same AT-URI pattern - appears in both valid/invalid files - ( - string_formats.DateTime, - '1985-04-12T23:20:50.123Z', - ), # Listed as "preferred" in valid but also appears in invalid under RFC-3339 section - ( - string_formats.DateTime, - '1985-04-12T23:20:50.123-00:00', - ), # Listed as "supported" in valid but marked invalid under timezone formats - (string_formats.DateTime, '1985-04-12T23:20Z'), # Similar timezone format discrepancy between valid/invalid files - (string_formats.Handle, 'john.test'), # Base pattern appears valid but numeric suffix versions are marked invalid - (string_formats.Nsid, 'one.two.three'), # Same pattern - base form valid but numeric suffixes marked invalid -] - - def get_test_cases(filename: str) -> List[str]: """Get non-comment, non-empty lines from an interop test file. @@ -60,60 +35,44 @@ def get_test_cases(filename: str) -> List[str]: ] -@pytest.fixture -def valid_handles() -> List[str]: - return get_test_cases('handle_syntax_valid.txt') - - -@pytest.fixture -def valid_dids() -> List[str]: - return get_test_cases('did_syntax_valid.txt') - - -@pytest.fixture -def valid_nsids() -> List[str]: - return get_test_cases('nsid_syntax_valid.txt') - - -@pytest.fixture -def valid_aturis() -> List[str]: - return get_test_cases('aturi_syntax_valid.txt') - - -@pytest.fixture -def valid_datetimes() -> List[str]: - return get_test_cases('datetime_syntax_valid.txt') - - -@pytest.fixture -def valid_tids() -> List[str]: - return get_test_cases('tid_syntax_valid.txt') - - -@pytest.fixture -def valid_record_keys() -> List[str]: - return get_test_cases('recordkey_syntax_valid.txt') +@lru_cache +def read_test_data() -> dict: + """Load all test data once at session start""" + return { + 'valid': { + 'handle': get_test_cases('handle_syntax_valid.txt'), + 'did': get_test_cases('did_syntax_valid.txt'), + 'nsid': get_test_cases('nsid_syntax_valid.txt'), + 'at_uri': get_test_cases('aturi_syntax_valid.txt'), + 'datetime': get_test_cases('datetime_syntax_valid.txt'), + 'tid': get_test_cases('tid_syntax_valid.txt'), + 'record_key': get_test_cases('recordkey_syntax_valid.txt'), + }, + 'invalid': { + 'handle': get_test_cases('handle_syntax_invalid.txt'), + 'did': get_test_cases('did_syntax_invalid.txt'), + 'nsid': get_test_cases('nsid_syntax_invalid.txt'), + 'at_uri': get_test_cases('aturi_syntax_invalid.txt'), + 'datetime': get_test_cases('datetime_syntax_invalid.txt'), + 'tid': get_test_cases('tid_syntax_invalid.txt'), + 'record_key': get_test_cases('recordkey_syntax_invalid.txt'), + }, + } @pytest.fixture -def valid_data( - valid_handles: List[str], - valid_dids: List[str], - valid_nsids: List[str], - valid_aturis: List[str], - valid_datetimes: List[str], - valid_tids: List[str], - valid_record_keys: List[str], -) -> dict: +def valid_data() -> dict: + """Get first valid example of each type plus constants""" + test_data = read_test_data() return { - 'handle': valid_handles[0], - 'did': valid_dids[0], - 'nsid': valid_nsids[0], - 'at_uri': valid_aturis[0], + 'handle': test_data['valid']['handle'][0], + 'did': test_data['valid']['did'][0], + 'nsid': test_data['valid']['nsid'][0], + 'at_uri': test_data['valid']['at_uri'][0], 'cid': 'bafyreidfayvfuwqa2beehqn7axeeeaej5aqvaowxgwcdt2rw', # No interop test file for CID - 'datetime': valid_datetimes[0], - 'tid': valid_tids[0], - 'record_key': valid_record_keys[0], + 'datetime': test_data['valid']['datetime'][0], + 'tid': test_data['valid']['tid'][0], + 'record_key': test_data['valid']['record_key'][0], 'uri': 'https://example.com', # No interop test file for URI 'language': 'en-US', # No interop test file for language } @@ -121,15 +80,17 @@ def valid_data( @pytest.fixture def invalid_data() -> dict: + """Get first invalid example of each type plus constants""" + test_data = read_test_data() return { - 'handle': get_test_cases('handle_syntax_invalid.txt')[0], - 'did': get_test_cases('did_syntax_invalid.txt')[0], - 'nsid': get_test_cases('nsid_syntax_invalid.txt')[0], - 'at_uri': get_test_cases('aturi_syntax_invalid.txt')[0], + 'handle': test_data['invalid']['handle'][0], + 'did': test_data['invalid']['did'][0], + 'nsid': test_data['invalid']['nsid'][0], + 'at_uri': test_data['invalid']['at_uri'][0], 'cid': 'short', # No interop test file for CID - 'datetime': get_test_cases('datetime_syntax_invalid.txt')[0], - 'tid': get_test_cases('tid_syntax_invalid.txt')[0], - 'record_key': get_test_cases('recordkey_syntax_invalid.txt')[0], + 'datetime': test_data['invalid']['datetime'][0], + 'tid': test_data['invalid']['tid'][0], + 'record_key': test_data['invalid']['record_key'][0], 'uri': 'invalid-uri-no-scheme', # No interop test file for URI 'language': 'invalid!', # No interop test file for language } @@ -137,13 +98,19 @@ def invalid_data() -> dict: @pytest.mark.parametrize( 'validator_type,field_name,invalid_value', - [(string_formats.AtUri, 'at_uri', c) for c in get_test_cases('aturi_syntax_invalid.txt')] - + [(string_formats.DateTime, 'datetime', c) for c in get_test_cases('datetime_syntax_invalid.txt')] - + [(string_formats.Handle, 'handle', c) for c in get_test_cases('handle_syntax_invalid.txt')] - + [(string_formats.Did, 'did', c) for c in get_test_cases('did_syntax_invalid.txt')] - + [(string_formats.Nsid, 'nsid', c) for c in get_test_cases('nsid_syntax_invalid.txt')] - + [(string_formats.Tid, 'tid', c) for c in get_test_cases('tid_syntax_invalid.txt')] - + [(string_formats.RecordKey, 'record_key', c) for c in get_test_cases('recordkey_syntax_invalid.txt')], + [ + (validator_type, field_name, invalid_value) + for validator_type, field_name in [ + (string_formats.AtUri, 'at_uri'), + (string_formats.DateTime, 'datetime'), + (string_formats.Handle, 'handle'), + (string_formats.Did, 'did'), + (string_formats.Nsid, 'nsid'), + (string_formats.Tid, 'tid'), + (string_formats.RecordKey, 'record_key'), + ] + for invalid_value in read_test_data()['invalid'][field_name] + ], ) def test_string_format_validation(validator_type: type, field_name: str, invalid_value: str, valid_data: dict) -> None: """Test validation for each string format type.""" @@ -179,7 +146,7 @@ def test_string_format_validation(validator_type: type, field_name: str, invalid def test_generic_string_format_validation(valid_value: str) -> None: """Test that ATProtoString accepts each valid string format.""" - validated = TypeAdapter(string_formats.ATProtoString).validate_python(valid_value, context={_OPT_IN_KEY: True}) + validated = TypeAdapter(string_formats.AtProtoString).validate_python(valid_value, context={_OPT_IN_KEY: True}) assert validated == valid_value @@ -199,18 +166,14 @@ class FooModel(BaseModel): assert instance.did == valid_data['did'] # Test invalid handle fails - try: + with pytest.raises(ModelError) as exc_info: get_or_create({'handle': invalid_data['handle'], 'did': valid_data['did']}, FooModel, strict_string_format=True) - pytest.fail('Handle validation should have failed') - except ModelError as e: - assert 'must be a domain name' in str(e) + assert 'must be a domain name' in str(exc_info.value) # Test invalid did fails - try: + with pytest.raises(ModelError) as exc_info: get_or_create({'handle': valid_data['handle'], 'did': invalid_data['did']}, FooModel, strict_string_format=True) - pytest.fail('Did validation should have failed') - except ModelError as e: - assert 'must be in format did:method:identifier' in str(e) + assert 'must be in format did:method:identifier' in str(exc_info.value) # Test that validation is skipped when strict_string_format=False instance = get_or_create( diff --git a/tests/test_atproto_client/models/tests/test_utils.py b/tests/test_atproto_client/models/tests/test_utils.py index f08427bc..f58a5339 100644 --- a/tests/test_atproto_client/models/tests/test_utils.py +++ b/tests/test_atproto_client/models/tests/test_utils.py @@ -1,5 +1,8 @@ +from dataclasses import dataclass + import pytest -from atproto_client.models.utils import is_json, load_json +from atproto_client.exceptions import ModelError +from atproto_client.models.utils import get_or_create, is_json, load_json def test_load_json() -> None: @@ -17,7 +20,7 @@ def test_load_json() -> None: assert load_json('{"key": "value"}'.encode('UTF-16'), strict=False) is None with pytest.raises(TypeError): - load_json(None) + load_json(None) # type: ignore[reportArgumentType] def test_is_json() -> None: @@ -33,4 +36,18 @@ def test_is_json() -> None: assert is_json(b'{}') is True with pytest.raises(TypeError): - load_json(None) + load_json(None) # type: ignore[reportArgumentType] + + +def test_get_or_create_works_with_dataclasses() -> None: + """Test that get_or_create works with dataclasses.""" + + @dataclass + class Foo: + bar: str + + with pytest.raises(ModelError, match='Cannot parse model of type'): + get_or_create(('bar', 'not a mapping'), Foo) + + result = get_or_create({'bar': 'baz'}, Foo) + assert result == Foo(bar='baz')