Skip to content

Commit

Permalink
address review comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
zzstoatzz committed Dec 2, 2024
1 parent 9807a81 commit abda096
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 118 deletions.
40 changes: 25 additions & 15 deletions packages/atproto_client/models/string_formats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from datetime import datetime
from typing import Callable, Mapping, Set, Union
from inspect import signature
from typing import Mapping, Set, TypeAlias, Union, cast
from urllib.parse import urlparse

from atproto_core.nsid import validate_nsid as atproto_core_validate_nsid
Expand Down Expand Up @@ -50,20 +51,29 @@
r')?$'
)

WithOrWithoutInfoValidator: TypeAlias = Union[
core_schema.WithInfoValidatorFunction, core_schema.NoInfoValidatorFunction
]


def only_validate_if_strict(validate_fn: core_schema.WithInfoValidatorFunction) -> Callable:
"""Skip validation if not opting into strict validation."""
def only_validate_if_strict(validate_fn: WithOrWithoutInfoValidator) -> WithOrWithoutInfoValidator:
"""Skip pydantic validation if not opting into strict validation via context."""
params = list(signature(validate_fn).parameters.values())
validator_wants_info = len(params) > 1 and params[1].annotation is ValidationInfo

def wrapper(v: str, info: ValidationInfo) -> str:
"""Could likely be generalized to support arbitrary signatures."""
if info and isinstance(info.context, Mapping) and info.context.get(_OPT_IN_KEY, False):
return validate_fn(v, info)
if validator_wants_info:
return cast(core_schema.WithInfoValidatorFunction, validate_fn)(v, info)
return cast(core_schema.NoInfoValidatorFunction, validate_fn)(v)
return v

return wrapper


@only_validate_if_strict
def validate_handle(v: str, info: ValidationInfo) -> str:
def validate_handle(v: str) -> str:
# Check ASCII first
if not v.isascii():
raise ValueError('Invalid handle: must contain only ASCII characters')
Expand All @@ -78,7 +88,7 @@ def validate_handle(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_did(v: str, info: ValidationInfo) -> str:
def validate_did(v: str) -> str:
# Check for invalid characters
if any(c in v for c in '/?#[]@'):
raise ValueError('Invalid DID: cannot contain /, ?, #, [, ], or @ characters')
Expand All @@ -98,7 +108,7 @@ def validate_did(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_nsid(v: str, info: ValidationInfo) -> str:
def validate_nsid(v: str) -> str:
if (
not atproto_core_validate_nsid(v, soft_fail=True)
or len(v) > MAX_NSID_LENGTH
Expand All @@ -113,14 +123,14 @@ def validate_nsid(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_language(v: str, info: ValidationInfo) -> str:
def validate_language(v: str) -> str:
if not LANG_RE.match(v):
raise ValueError('Invalid language code: must be ISO language code (e.g. en or en-US)')
return v


@only_validate_if_strict
def validate_record_key(v: str, info: ValidationInfo) -> str:
def validate_record_key(v: str) -> str:
if v in INVALID_RECORD_KEYS or not RKEY_RE.match(v):
raise ValueError(
'Invalid record key: must contain only alphanumeric, dot, underscore, colon, tilde, or hyphen characters'
Expand All @@ -129,14 +139,14 @@ def validate_record_key(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_cid(v: str, info: ValidationInfo) -> str:
def validate_cid(v: str) -> str:
if not CID_RE.match(v):
raise ValueError('Invalid CID: must be a valid Content Identifier with minimum length 8')
return v


@only_validate_if_strict
def validate_at_uri(v: str, info: ValidationInfo) -> str:
def validate_at_uri(v: str) -> str:
if len(v) >= MAX_AT_URI_LENGTH:
raise ValueError(f'Invalid AT-URI: must be under {MAX_AT_URI_LENGTH} chars')

Expand All @@ -147,7 +157,7 @@ def validate_at_uri(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_datetime(v: str, info: ValidationInfo) -> str:
def validate_datetime(v: str) -> str:
# Must contain uppercase T and Z if used
if v != v.strip():
raise ValueError('Invalid datetime: no whitespace allowed')
Expand Down Expand Up @@ -180,14 +190,14 @@ def validate_datetime(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_tid(v: str, info: ValidationInfo) -> str:
def validate_tid(v: str) -> str:
if not TID_RE.match(v) or (ord(v[0]) & 0x40):
raise ValueError(f'Invalid TID: must be exactly {TID_LENGTH} lowercase letters/numbers')
return v


@only_validate_if_strict
def validate_uri(v: str, info: ValidationInfo) -> str:
def validate_uri(v: str) -> str:
if len(v) >= MAX_URI_LENGTH or ' ' in v:
raise ValueError(f'Invalid URI: must be under {MAX_URI_LENGTH} chars and not contain spaces')
parsed = urlparse(v)
Expand All @@ -212,7 +222,7 @@ def validate_uri(v: str, info: ValidationInfo) -> str:
Uri = Annotated[str, BeforeValidator(validate_uri)]

# Any valid ATProto string format
ATProtoString = Annotated[
AtProtoString = Annotated[
Union[Handle, Did, Nsid, AtUri, Cid, DateTime, Tid, RecordKey, Uri, Language],
Field(description='ATProto string format'),
]
8 changes: 6 additions & 2 deletions packages/atproto_client/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing as t

import typing_extensions as te
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from pydantic_core import from_json, to_json

from atproto_client import models
Expand Down Expand Up @@ -108,7 +108,11 @@ def _get_or_create(
return model_data

try:
return model.model_validate(model_data, context={'strict_string_format': strict_string_format})
if issubclass(model, BaseModel):
return model.model_validate(model_data, context={'strict_string_format': strict_string_format})
if not isinstance(model_data, t.Mapping):
raise ModelError(f'Cannot parse model of type {model}')
return model(**model_data)
except ValidationError as e:
raise ModelError(str(e)) from e

Expand Down
159 changes: 61 additions & 98 deletions tests/test_atproto_client/models/tests/test_string_formats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import lru_cache
from pathlib import Path
from typing import List

Expand All @@ -13,32 +14,6 @@
INTEROP_TEST_FILES_DIR: Path = Path('tests/test_atproto_client/interop-test-files/syntax')


# TODO: 230 passed, 11 xfailed
# These cases appear in both _valid.txt and _invalid.txt files.
# Need investigation to determine if our validation is incorrect or if test data needs updating:
SKIP_THESE_VALUES = [
(
string_formats.AtUri,
'at://did:plc:asdf123',
), # Listed as both valid and invalid in AT-URI files under "enforces spec basics"
(
string_formats.AtUri,
'at://did:plc:asdf123/com.atproto.feed.post',
), # Same AT-URI pattern - appears in both valid/invalid files
(
string_formats.DateTime,
'1985-04-12T23:20:50.123Z',
), # Listed as "preferred" in valid but also appears in invalid under RFC-3339 section
(
string_formats.DateTime,
'1985-04-12T23:20:50.123-00:00',
), # Listed as "supported" in valid but marked invalid under timezone formats
(string_formats.DateTime, '1985-04-12T23:20Z'), # Similar timezone format discrepancy between valid/invalid files
(string_formats.Handle, 'john.test'), # Base pattern appears valid but numeric suffix versions are marked invalid
(string_formats.Nsid, 'one.two.three'), # Same pattern - base form valid but numeric suffixes marked invalid
]


def get_test_cases(filename: str) -> List[str]:
"""Get non-comment, non-empty lines from an interop test file.
Expand All @@ -60,90 +35,82 @@ def get_test_cases(filename: str) -> List[str]:
]


@pytest.fixture
def valid_handles() -> List[str]:
return get_test_cases('handle_syntax_valid.txt')


@pytest.fixture
def valid_dids() -> List[str]:
return get_test_cases('did_syntax_valid.txt')


@pytest.fixture
def valid_nsids() -> List[str]:
return get_test_cases('nsid_syntax_valid.txt')


@pytest.fixture
def valid_aturis() -> List[str]:
return get_test_cases('aturi_syntax_valid.txt')


@pytest.fixture
def valid_datetimes() -> List[str]:
return get_test_cases('datetime_syntax_valid.txt')


@pytest.fixture
def valid_tids() -> List[str]:
return get_test_cases('tid_syntax_valid.txt')


@pytest.fixture
def valid_record_keys() -> List[str]:
return get_test_cases('recordkey_syntax_valid.txt')
@lru_cache
def read_test_data() -> dict:
"""Load all test data once at session start"""
return {
'valid': {
'handle': get_test_cases('handle_syntax_valid.txt'),
'did': get_test_cases('did_syntax_valid.txt'),
'nsid': get_test_cases('nsid_syntax_valid.txt'),
'at_uri': get_test_cases('aturi_syntax_valid.txt'),
'datetime': get_test_cases('datetime_syntax_valid.txt'),
'tid': get_test_cases('tid_syntax_valid.txt'),
'record_key': get_test_cases('recordkey_syntax_valid.txt'),
},
'invalid': {
'handle': get_test_cases('handle_syntax_invalid.txt'),
'did': get_test_cases('did_syntax_invalid.txt'),
'nsid': get_test_cases('nsid_syntax_invalid.txt'),
'at_uri': get_test_cases('aturi_syntax_invalid.txt'),
'datetime': get_test_cases('datetime_syntax_invalid.txt'),
'tid': get_test_cases('tid_syntax_invalid.txt'),
'record_key': get_test_cases('recordkey_syntax_invalid.txt'),
},
}


@pytest.fixture
def valid_data(
valid_handles: List[str],
valid_dids: List[str],
valid_nsids: List[str],
valid_aturis: List[str],
valid_datetimes: List[str],
valid_tids: List[str],
valid_record_keys: List[str],
) -> dict:
def valid_data() -> dict:
"""Get first valid example of each type plus constants"""
test_data = read_test_data()
return {
'handle': valid_handles[0],
'did': valid_dids[0],
'nsid': valid_nsids[0],
'at_uri': valid_aturis[0],
'handle': test_data['valid']['handle'][0],
'did': test_data['valid']['did'][0],
'nsid': test_data['valid']['nsid'][0],
'at_uri': test_data['valid']['at_uri'][0],
'cid': 'bafyreidfayvfuwqa2beehqn7axeeeaej5aqvaowxgwcdt2rw', # No interop test file for CID
'datetime': valid_datetimes[0],
'tid': valid_tids[0],
'record_key': valid_record_keys[0],
'datetime': test_data['valid']['datetime'][0],
'tid': test_data['valid']['tid'][0],
'record_key': test_data['valid']['record_key'][0],
'uri': 'https://example.com', # No interop test file for URI
'language': 'en-US', # No interop test file for language
}


@pytest.fixture
def invalid_data() -> dict:
"""Get first invalid example of each type plus constants"""
test_data = read_test_data()
return {
'handle': get_test_cases('handle_syntax_invalid.txt')[0],
'did': get_test_cases('did_syntax_invalid.txt')[0],
'nsid': get_test_cases('nsid_syntax_invalid.txt')[0],
'at_uri': get_test_cases('aturi_syntax_invalid.txt')[0],
'handle': test_data['invalid']['handle'][0],
'did': test_data['invalid']['did'][0],
'nsid': test_data['invalid']['nsid'][0],
'at_uri': test_data['invalid']['at_uri'][0],
'cid': 'short', # No interop test file for CID
'datetime': get_test_cases('datetime_syntax_invalid.txt')[0],
'tid': get_test_cases('tid_syntax_invalid.txt')[0],
'record_key': get_test_cases('recordkey_syntax_invalid.txt')[0],
'datetime': test_data['invalid']['datetime'][0],
'tid': test_data['invalid']['tid'][0],
'record_key': test_data['invalid']['record_key'][0],
'uri': 'invalid-uri-no-scheme', # No interop test file for URI
'language': 'invalid!', # No interop test file for language
}


@pytest.mark.parametrize(
'validator_type,field_name,invalid_value',
[(string_formats.AtUri, 'at_uri', c) for c in get_test_cases('aturi_syntax_invalid.txt')]
+ [(string_formats.DateTime, 'datetime', c) for c in get_test_cases('datetime_syntax_invalid.txt')]
+ [(string_formats.Handle, 'handle', c) for c in get_test_cases('handle_syntax_invalid.txt')]
+ [(string_formats.Did, 'did', c) for c in get_test_cases('did_syntax_invalid.txt')]
+ [(string_formats.Nsid, 'nsid', c) for c in get_test_cases('nsid_syntax_invalid.txt')]
+ [(string_formats.Tid, 'tid', c) for c in get_test_cases('tid_syntax_invalid.txt')]
+ [(string_formats.RecordKey, 'record_key', c) for c in get_test_cases('recordkey_syntax_invalid.txt')],
[
(validator_type, field_name, invalid_value)
for validator_type, field_name in [
(string_formats.AtUri, 'at_uri'),
(string_formats.DateTime, 'datetime'),
(string_formats.Handle, 'handle'),
(string_formats.Did, 'did'),
(string_formats.Nsid, 'nsid'),
(string_formats.Tid, 'tid'),
(string_formats.RecordKey, 'record_key'),
]
for invalid_value in read_test_data()['invalid'][field_name]
],
)
def test_string_format_validation(validator_type: type, field_name: str, invalid_value: str, valid_data: dict) -> None:
"""Test validation for each string format type."""
Expand Down Expand Up @@ -179,7 +146,7 @@ def test_string_format_validation(validator_type: type, field_name: str, invalid
def test_generic_string_format_validation(valid_value: str) -> None:
"""Test that ATProtoString accepts each valid string format."""

validated = TypeAdapter(string_formats.ATProtoString).validate_python(valid_value, context={_OPT_IN_KEY: True})
validated = TypeAdapter(string_formats.AtProtoString).validate_python(valid_value, context={_OPT_IN_KEY: True})
assert validated == valid_value


Expand All @@ -199,18 +166,14 @@ class FooModel(BaseModel):
assert instance.did == valid_data['did']

# Test invalid handle fails
try:
with pytest.raises(ModelError) as exc_info:
get_or_create({'handle': invalid_data['handle'], 'did': valid_data['did']}, FooModel, strict_string_format=True)
pytest.fail('Handle validation should have failed')
except ModelError as e:
assert 'must be a domain name' in str(e)
assert 'must be a domain name' in str(exc_info.value)

# Test invalid did fails
try:
with pytest.raises(ModelError) as exc_info:
get_or_create({'handle': valid_data['handle'], 'did': invalid_data['did']}, FooModel, strict_string_format=True)
pytest.fail('Did validation should have failed')
except ModelError as e:
assert 'must be in format did:method:identifier' in str(e)
assert 'must be in format did:method:identifier' in str(exc_info.value)

# Test that validation is skipped when strict_string_format=False
instance = get_or_create(
Expand Down
Loading

0 comments on commit abda096

Please sign in to comment.