Skip to content

Commit

Permalink
tests n tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Nov 25, 2024
1 parent eb1d28b commit e0edaeb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 37 deletions.
57 changes: 40 additions & 17 deletions packages/atproto_client/models/string_formats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import string
from datetime import datetime
from typing import Callable, Mapping, Set, Union
from urllib.parse import urlparse
Expand All @@ -25,9 +24,13 @@
NSID_RE = re.compile(r'(?![0-9])((?!-)[a-z0-9-]{1,63}(?<!-)\.){2,}[a-zA-Z]{1,63}')
LANG_RE = re.compile(r'^(i|[a-z]{2,3})(-[A-Za-z0-9-]+)?$')
RKEY_RE = re.compile(r'^[A-Za-z0-9._:~-]{1,512}$')
TID_RE = re.compile(rf'^[{string.ascii_lowercase}234567]{{{TID_LENGTH}}}$')
TID_RE = re.compile(rf'^[2-7a-z]{{{TID_LENGTH}}}$')
CID_RE = re.compile(r'^[A-Za-z0-9+]{8,}$')
AT_URI_RE = re.compile(r'at://[^/]+(/[^/]+(/[^/]+)?)?')
AT_URI_RE = re.compile(
r'^at://' # Protocol
r'([a-z0-9][a-z0-9.-]*[a-z0-9]|did:[a-z]+:[a-z0-9.:%-]+)' # Authority: either domain or DID
r'(/[a-z][a-z0-9.-]*(\.[a-z][a-z0-9.-]*)*(/[a-z0-9.-]+)?)?$' # Optional path segments
)


def only_validate_if_strict(validate_fn: core_schema.WithInfoValidatorFunction) -> Callable:
Expand Down Expand Up @@ -91,34 +94,54 @@ def validate_cid(v: str, info: ValidationInfo) -> str:

@only_validate_if_strict
def validate_at_uri(v: str, info: ValidationInfo) -> str:
if len(v) >= MAX_URI_LENGTH or '/./' in v or '/../' in v or v.endswith(('/.', '/..')):
raise ValueError(f'Invalid AT-URI: must be under {MAX_URI_LENGTH} chars and not contain /./ or /../ patterns')
if len(v) >= MAX_URI_LENGTH:
raise ValueError(f'Invalid AT-URI: must be under {MAX_URI_LENGTH} chars')

if not AT_URI_RE.match(v):
raise ValueError('Invalid AT-URI: must be in format at://authority/collection/record')

if (
'/./' in v
or '/../' in v
or v.endswith('/')
or '#' in v
or
# Invalid percent encoding patterns
('%' in v and not re.match(r'%[0-9A-Fa-f]{2}', v[v.index('%') :]))
):
raise ValueError(
'Invalid AT-URI: must be in format at://authority/collection/record (e.g. at://user.bsky.social/posts/123)'
'Invalid AT-URI: must not contain /./, /../, trailing slashes, fragments, or invalid percent encoding'
)

return v


@only_validate_if_strict
def validate_datetime(v: str, info: ValidationInfo) -> str:
if 'T' not in v or not any(v.endswith(x) for x in ('Z', '+00:00')):
raise ValueError('Invalid datetime format: must be ISO 8601 with timezone (e.g. 2023-01-01T12:00:00Z)')
# Must contain T separator
if 'T' not in v:
raise ValueError('Invalid datetime: must contain T separator')

# Must have timezone
orig_val = v
v = re.sub(r'([+-][0-9]{2}:[0-9]{2}|Z)$', '', orig_val)
if v == orig_val:
raise ValueError('Invalid datetime: must include timezone (Z or +/-HH:MM)')

# Strip fractional seconds before parsing
v = re.sub(r'\.[0-9]+$', '', v)

try:
datetime.fromisoformat(v.replace('Z', '+00:00'))
return v
datetime.fromisoformat(v)
return orig_val
except ValueError:
raise ValueError(
'Invalid datetime format: must be ISO 8601 with timezone (e.g. 2023-01-01T12:00:00Z)'
) from None
raise ValueError('Invalid datetime: must be valid ISO 8601 format') from None


@only_validate_if_strict
def validate_tid(v: str, info: ValidationInfo) -> str:
if not TID_RE.match(v):
raise ValueError(f'Invalid TID format: must be exactly {TID_LENGTH} lowercase letters/numbers')
if ord(v[0]) & 0x40:
raise ValueError('Invalid TID: high bit cannot be 1')
if not TID_RE.match(v) or (ord(v[0]) & 0x40):
raise ValueError(f'Invalid TID: must be exactly {TID_LENGTH} lowercase letters/numbers')
return v


Expand Down
51 changes: 31 additions & 20 deletions tests/test_atproto_client/models/tests/test_string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,40 +98,51 @@ def invalid_data() -> dict:


@pytest.mark.parametrize(
'validator_type,field_name,expected_error',
[
(string_formats.Handle, 'handle', 'must be a domain name'),
(string_formats.Did, 'did', 'must be in format did:method:identifier'),
(string_formats.Nsid, 'nsid', 'must be dot-separated segments'),
(string_formats.AtUri, 'at_uri', 'must be in format at://authority/collection/record'),
(string_formats.Cid, 'cid', 'must be a valid Content Identifier'),
(string_formats.DateTime, 'datetime', 'must be ISO 8601 with timezone'),
(string_formats.Tid, 'tid', 'must be exactly 13 lowercase letters/numbers'),
(string_formats.RecordKey, 'record_key', 'must contain only alphanumeric'),
(string_formats.Uri, 'uri', 'must be a valid URI with scheme and authority/path'),
(string_formats.Language, 'language', 'must be ISO language code'),
'validator_type,field_name,error_keywords,invalid_value',
[(string_formats.AtUri, 'at_uri', ['Invalid AT-URI'], case) for case in get_test_cases('aturi_syntax_invalid.txt')]
+ [
(string_formats.DateTime, 'datetime', ['Invalid datetime'], case)
for case in get_test_cases('datetime_syntax_invalid.txt')
]
+ [
(string_formats.Handle, 'handle', 'must be a domain name', case)
for case in get_test_cases('handle_syntax_invalid.txt')
]
+ [
(string_formats.Did, 'did', 'must be in format did:method:identifier', case)
for case in get_test_cases('did_syntax_invalid.txt')
]
+ [
(string_formats.Nsid, 'nsid', 'must be dot-separated segments', case)
for case in get_test_cases('nsid_syntax_invalid.txt')
]
+ [
(string_formats.Tid, 'tid', 'must be exactly 13 lowercase letters/numbers', case)
for case in get_test_cases('tid_syntax_invalid.txt')
]
+ [
(string_formats.RecordKey, 'record_key', 'must contain only alphanumeric', case)
for case in get_test_cases('recordkey_syntax_invalid.txt')
],
)
def test_string_format_validation(
validator_type: type, field_name: str, expected_error: str, valid_data: dict, invalid_data: dict
validator_type: type, field_name: str, error_keywords: List[str], invalid_value: str, valid_data: dict
) -> None:
"""Test validation for each string format type."""
SomeTypeAdapter = TypeAdapter(validator_type)

# Test that validation is skipped by default
assert SomeTypeAdapter.validate_python(invalid_data[field_name]) == invalid_data[field_name]
assert SomeTypeAdapter.validate_python(invalid_value) == invalid_value

# Test that valid data passes strict validation
validated_value = SomeTypeAdapter.validate_python(valid_data[field_name], context={_OPT_IN_KEY: True})
assert validated_value == valid_data[field_name]

# Test that invalid data fails strict validation
try:
SomeTypeAdapter.validate_python(invalid_data[field_name], context={_OPT_IN_KEY: True})
pytest.fail(f'{validator_type.__name__} validation should have failed')
except ValidationError as e:
error = e.errors()[0]
assert expected_error in error['msg']
with pytest.raises(ValidationError) as exc_info:
SomeTypeAdapter.validate_python(invalid_value, context={_OPT_IN_KEY: True})
error_msg = str(exc_info.value)
assert any(keyword in error_msg for keyword in error_keywords)


@pytest.mark.parametrize(
Expand Down

0 comments on commit e0edaeb

Please sign in to comment.