Skip to content

Commit

Permalink
fix test case loading
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Nov 29, 2024
1 parent 1428671 commit 34444be
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 133 deletions.
115 changes: 43 additions & 72 deletions packages/atproto_client/models/string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Mapping, Set, Union
from urllib.parse import urlparse

from atproto_core.nsid import validate_nsid as atproto_core_validate_nsid
from pydantic import BeforeValidator, Field, ValidationInfo
from pydantic_core import core_schema
from typing_extensions import Annotated, Literal
Expand All @@ -28,26 +29,25 @@
r'[a-zA-Z0-9._%-]{1,2048}' # method-specific-id with length limit
r'(?<!:)$' # Cannot end with colon
)
NSID_RE = re.compile(
r'(?![0-9])' # Can't start with number
r'((?!-)[a-z0-9-]{1,63}(?<!-)\.){2,}' # At least 2 segments, each 1-63 chars
r'[a-zA-Z]' # Last segment must start with letter
r'[a-zA-Z0-9-]*' # Middle chars
r'[a-zA-Z]' # Must end with letter
r'$' # End of string
)
LANG_RE = re.compile(r'^(i|[a-z]{2,3})(-[A-Za-z0-9-]+)?$')
RKEY_RE = re.compile(r'^[A-Za-z0-9._:~-]{1,512}$')
TID_RE = re.compile(rf'^[2-7a-z]{{{TID_LENGTH}}}$')
CID_RE = re.compile(r'^[A-Za-z0-9+]{8,}$')
AT_URI_RE = re.compile(
r'^at://' # Protocol
r'(' # Authority is either:
r'([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z][a-zA-Z0-9-]*[a-zA-Z]' # Handle (must be domain)
r'^at://' # Must start with at://
r'(' # Authority group start
# For DIDs: the only allowed identifier chars are letters, numbers, period, and hyphen (no underscore or percent)
r'did:[a-z]+:[a-zA-Z0-9.-]+' # Notice removed underscore from allowed chars
r'|' # or
r'did:[a-z]+:[a-zA-Z0-9._%-]+' # DID
r')'
r'(/[a-z][a-zA-Z0-9-]*(\.[a-z][a-zA-Z0-9-]*)+(/[a-zA-Z0-9._:~-]+)?)?$' # Optional COLLECTION/RKEY
# Handle: require 2+ segments, TLD can't start with digit
r'[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?' # First segment
r'(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*' # Middle segments
r'\.[a-zA-Z][a-zA-Z0-9-]*' # TLD must start with letter
r')' # Authority group end
r'(?:' # Optional path group
r'/[a-z][a-zA-Z0-9-]*(\.[a-z][a-zA-Z0-9-])+' # NSID
r'(?:/[A-Za-z0-9._:~-]+)?' # Optional record key
r')?$'
)


Expand Down Expand Up @@ -100,14 +100,11 @@ def validate_did(v: str, info: ValidationInfo) -> str:
@only_validate_if_strict
def validate_nsid(v: str, info: ValidationInfo) -> str:
if (
not NSID_RE.match(v)
or '..' in v # No double dots
not atproto_core_validate_nsid(v, soft_fail=True)
or len(v) > MAX_NSID_LENGTH
or any(c in v for c in '@_*#!') # Explicitly disallow special chars
or not all(seg for seg in v.split('.')) # No empty segments
or any(len(seg) > 63 for seg in v.split('.')) # Max segment length
or any(seg[-1].isdigit() for seg in v.split('.')) # No segments ending in numbers
or any(seg.endswith('-') for seg in v.split('.')) # No segments ending in hyphen
):
raise ValueError(
f'Invalid NSID: must be dot-separated segments (e.g. app.bsky.feed.post) with max length {MAX_NSID_LENGTH}'
Expand Down Expand Up @@ -140,72 +137,46 @@ def validate_cid(v: str, info: ValidationInfo) -> str:

@only_validate_if_strict
def validate_at_uri(v: str, info: ValidationInfo) -> str:
# Check length
if len(v) >= MAX_AT_URI_LENGTH:
raise ValueError(f'Invalid AT-URI: must be under {MAX_AT_URI_LENGTH} chars')

# Check for invalid path patterns
if (
'/./' in v # No dot segments
or '/../' in v
or v.endswith('/') # No trailing slashes
or '#' in v # No fragments
or '?' in v # No query params
or (
'%' in v # Check percent encoding
and not re.match(r'%[0-9A-Fa-f]{2}', v[v.index('%') :])
)
):
raise ValueError(
(
'Invalid AT-URI: must not contain /./, /../, trailing slashes,'
' fragments, queries, or invalid percent encoding'
)
)

parts = v[5:].split('/') # Skip 'at://'

# Must have valid authority
if len(parts) < 1:
raise ValueError('Invalid AT-URI: missing authority')

# If there's a path, must have collection
if len(parts) > 1:
if len(parts) == 2:
try:
validate_nsid(parts[1], info)
except ValueError:
raise ValueError('Invalid AT-URI: collection must be a valid NSID') from None
else:
raise ValueError('Invalid AT-URI: must be in format at://authority/collection[/record]')

# Basic format check still needed for authority validation
if not AT_URI_RE.match(v):
raise ValueError('Invalid AT-URI: invalid authority format')
raise ValueError('Invalid AT-URI: invalid format')

return v


@only_validate_if_strict
def validate_datetime(v: str, info: ValidationInfo) -> str:
# Must contain T separator
if 'T' not in v:
raise ValueError('Invalid datetime: must contain T separator')

# Must have timezone
orig_val = v
v = re.sub(r'([+-][0-9]{2}:[0-9]{2}|Z)$', '', orig_val)
if v == orig_val:
raise ValueError('Invalid datetime: must include timezone (Z or +/-HH:MM)')

# Strip fractional seconds before parsing
v = re.sub(r'\.[0-9]+$', '', v)

# Must contain uppercase T and Z if used
if v != v.strip():
raise ValueError('Invalid datetime: no whitespace allowed')

# Must contain uppercase T
if 'T' not in v or ('z' in v and 'Z' not in v):
raise ValueError('Invalid datetime: must contain uppercase T separator')

# Must have seconds (HH:MM:SS)
time_part = v.split('T')[1]
if len(time_part.split(':')) != 3:
raise ValueError('Invalid datetime: seconds are required')

# If has decimal point, must have digits after it
if '.' in v and not re.search(r'\.\d+', v):
raise ValueError('Invalid datetime: invalid fractional seconds format')

# Must end with a timezone designator (Z or +/-HH:MM) and nothing after it; -00:00 is rejected
if v.endswith('-00:00'):
raise ValueError('Invalid datetime: -00:00 timezone not allowed')
if not (re.match(r'.*Z$', v) or re.match(r'.*[+-]\d{2}:\d{2}$', v)):
raise ValueError('Invalid datetime: must include timezone')

# Final validation
try:
datetime.fromisoformat(v)
return orig_val
datetime.fromisoformat(v.replace('Z', '+00:00'))
return v
except ValueError:
raise ValueError('Invalid datetime: must be valid ISO 8601 format') from None
raise ValueError('Invalid datetime: invalid format') from None


@only_validate_if_strict
Expand Down
86 changes: 25 additions & 61 deletions tests/test_atproto_client/models/tests/test_string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,23 @@


def get_test_cases(filename: str) -> List[str]:
"""Get non-comment, non-empty lines from an interop test file."""
"""Get non-comment, non-empty lines from an interop test file.
Important: Preserves whitespace in test cases. This is critical for
format validators where leading/trailing/internal whitespace makes a
value invalid. For example, ' 1985-04-12T23:20:50.123Z' (with leading space)
should be invalid for datetime validation.
Args:
filename: Name of the test file to read from interop test files directory
Returns:
List of test cases with original whitespace preserved
"""
return [
line.strip()
line
for line in INTEROP_TEST_FILES_DIR.joinpath(filename).read_text().splitlines()
if line.strip() and not line.startswith('#')
if line and not line.startswith('#')
]


Expand Down Expand Up @@ -124,40 +136,17 @@ def invalid_data() -> dict:


@pytest.mark.parametrize(
'validator_type,field_name,error_keywords,invalid_value',
[(string_formats.AtUri, 'at_uri', ['Invalid AT-URI'], c) for c in get_test_cases('aturi_syntax_invalid.txt')]
+ [
(string_formats.DateTime, 'datetime', ['Invalid datetime'], c)
for c in get_test_cases('datetime_syntax_invalid.txt')
]
+ [
(string_formats.Handle, 'handle', 'must be a domain name', c)
for c in get_test_cases('handle_syntax_invalid.txt')
]
+ [
(string_formats.Did, 'did', 'must be in format did:method:identifier', c)
for c in get_test_cases('did_syntax_invalid.txt')
]
+ [
(string_formats.Nsid, 'nsid', 'must be dot-separated segments', c)
for c in get_test_cases('nsid_syntax_invalid.txt')
]
+ [
(string_formats.Tid, 'tid', 'must be exactly 13 lowercase letters/numbers', c)
for c in get_test_cases('tid_syntax_invalid.txt')
]
+ [
(string_formats.RecordKey, 'record_key', 'must contain only alphanumeric', c)
for c in get_test_cases('recordkey_syntax_invalid.txt')
],
'validator_type,field_name,invalid_value',
[(string_formats.AtUri, 'at_uri', c) for c in get_test_cases('aturi_syntax_invalid.txt')]
+ [(string_formats.DateTime, 'datetime', c) for c in get_test_cases('datetime_syntax_invalid.txt')]
+ [(string_formats.Handle, 'handle', c) for c in get_test_cases('handle_syntax_invalid.txt')]
+ [(string_formats.Did, 'did', c) for c in get_test_cases('did_syntax_invalid.txt')]
+ [(string_formats.Nsid, 'nsid', c) for c in get_test_cases('nsid_syntax_invalid.txt')]
+ [(string_formats.Tid, 'tid', c) for c in get_test_cases('tid_syntax_invalid.txt')]
+ [(string_formats.RecordKey, 'record_key', c) for c in get_test_cases('recordkey_syntax_invalid.txt')],
)
def test_string_format_validation(
validator_type: type, field_name: str, error_keywords: List[str], invalid_value: str, valid_data: dict
) -> None:
def test_string_format_validation(validator_type: type, field_name: str, invalid_value: str, valid_data: dict) -> None:
"""Test validation for each string format type."""
if any(invalid_value == skip_value for _, skip_value in SKIP_THESE_VALUES):
pytest.xfail(f'TODO: Fix validation for {invalid_value}')

SomeTypeAdapter = TypeAdapter(validator_type)

# Test that validation is skipped by default
Expand All @@ -168,10 +157,8 @@ def test_string_format_validation(
assert validated_value == valid_data[field_name]

# Test that invalid data fails strict validation
with pytest.raises(ValidationError) as exc_info:
with pytest.raises(ValidationError):
SomeTypeAdapter.validate_python(invalid_value, context={_OPT_IN_KEY: True})
error_msg = str(exc_info.value)
assert any(keyword in error_msg for keyword in error_keywords)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -232,26 +219,3 @@ class FooModel(BaseModel):
assert isinstance(instance, FooModel)
assert instance.handle == invalid_data['handle']
assert instance.did == invalid_data['did']


@pytest.mark.parametrize('validator_type,value', SKIP_THESE_VALUES)
def test_skipped_validation_cases(validator_type: type, value: str) -> None:
"""
Test each skipped case that we suspect is a discrepancy between valid/invalid test files
"""
_TAdapter = TypeAdapter(validator_type)

# Should validate successfully with strict validation
validated = _TAdapter.validate_python(value, context={_OPT_IN_KEY: True})
assert validated == value

# Also verify it appears in the corresponding invalid test file
invalid_filename = {
string_formats.AtUri: 'aturi_syntax_invalid.txt',
string_formats.DateTime: 'datetime_syntax_invalid.txt',
string_formats.Handle: 'handle_syntax_invalid.txt',
string_formats.Nsid: 'nsid_syntax_invalid.txt',
}[validator_type]

invalid_cases = get_test_cases(invalid_filename)
assert value in invalid_cases, f'{value} not found in {invalid_filename} despite being marked as invalid'

0 comments on commit 34444be

Please sign in to comment.