diff --git a/packages/atproto_client/models/string_formats.py b/packages/atproto_client/models/string_formats.py index d0df87d3..b46a989d 100644 --- a/packages/atproto_client/models/string_formats.py +++ b/packages/atproto_client/models/string_formats.py @@ -3,6 +3,7 @@ from typing import Callable, Mapping, Set, Union from urllib.parse import urlparse +from atproto_core.nsid import validate_nsid as atproto_core_validate_nsid from pydantic import BeforeValidator, Field, ValidationInfo from pydantic_core import core_schema from typing_extensions import Annotated, Literal @@ -28,26 +29,25 @@ r'[a-zA-Z0-9._%-]{1,2048}' # method-specific-id with length limit r'(? str: @only_validate_if_strict def validate_nsid(v: str, info: ValidationInfo) -> str: if ( - not NSID_RE.match(v) - or '..' in v # No double dots + not atproto_core_validate_nsid(v, soft_fail=True) or len(v) > MAX_NSID_LENGTH or any(c in v for c in '@_*#!') # Explicitly disallow special chars - or not all(seg for seg in v.split('.')) # No empty segments or any(len(seg) > 63 for seg in v.split('.')) # Max segment length or any(seg[-1].isdigit() for seg in v.split('.')) # No segments ending in numbers - or any(seg.endswith('-') for seg in v.split('.')) # No segments ending in hyphen ): raise ValueError( f'Invalid NSID: must be dot-separated segments (e.g. app.bsky.feed.post) with max length {MAX_NSID_LENGTH}' @@ -140,72 +137,46 @@ def validate_cid(v: str, info: ValidationInfo) -> str: @only_validate_if_strict def validate_at_uri(v: str, info: ValidationInfo) -> str: - # Check length if len(v) >= MAX_AT_URI_LENGTH: raise ValueError(f'Invalid AT-URI: must be under {MAX_AT_URI_LENGTH} chars') - # Check for invalid path patterns - if ( - '/./' in v # No dot segments - or '/../' in v - or v.endswith('/') # No trailing slashes - or '#' in v # No fragments - or '?' in v # No query params - or ( - '%' in v # Check percent encoding - and not re.match(r'%[0-9A-Fa-f]{2}', v[v.index('%') :]) - ) - ): - raise ValueError( - ( - 'Invalid AT-URI: must not contain /./, /../, trailing slashes,' - ' fragments, queries, or invalid percent encoding' - ) - ) - - parts = v[5:].split('/') # Skip 'at://' - - # Must have valid authority - if len(parts) < 1: - raise ValueError('Invalid AT-URI: missing authority') - - # If there's a path, must have collection - if len(parts) > 1: - if len(parts) == 2: - try: - validate_nsid(parts[1], info) - except ValueError: - raise ValueError('Invalid AT-URI: collection must be a valid NSID') from None - else: - raise ValueError('Invalid AT-URI: must be in format at://authority/collection[/record]') - - # Basic format check still needed for authority validation if not AT_URI_RE.match(v): - raise ValueError('Invalid AT-URI: invalid authority format') + raise ValueError('Invalid AT-URI: invalid format') return v @only_validate_if_strict def validate_datetime(v: str, info: ValidationInfo) -> str: - # Must contain T separator - if 'T' not in v: - raise ValueError('Invalid datetime: must contain T separator') - - # Must have timezone - orig_val = v - v = re.sub(r'([+-][0-9]{2}:[0-9]{2}|Z)$', '', orig_val) - if v == orig_val: - raise ValueError('Invalid datetime: must include timezone (Z or +/-HH:MM)') - - # Strip fractional seconds before parsing - v = re.sub(r'\.[0-9]+$', '', v) - + # Must contain uppercase T and Z if used + if v != v.strip(): + raise ValueError('Invalid datetime: no whitespace allowed') + + # Must contain uppercase T + if 'T' not in v or ('z' in v and 'Z' not in v): + raise ValueError('Invalid datetime: must contain uppercase T separator') + + # Must have seconds (HH:MM:SS) + time_part = v.split('T')[1] + if len(time_part.split(':')) != 3: + raise ValueError('Invalid datetime: seconds are required') + + # If has decimal point, must have digits after it + if '.' in v and not re.search(r'\.\d+', v): + raise ValueError('Invalid datetime: invalid fractional seconds format') + + # Must match exactly timezone pattern with nothing after + if v.endswith('-00:00'): + raise ValueError('Invalid datetime: -00:00 timezone not allowed') + if not (re.match(r'.*Z$', v) or re.match(r'.*[+-]\d{2}:\d{2}$', v)): + raise ValueError('Invalid datetime: must include timezone') + + # Final validation try: - datetime.fromisoformat(v) - return orig_val + datetime.fromisoformat(v.replace('Z', '+00:00')) + return v except ValueError: - raise ValueError('Invalid datetime: must be valid ISO 8601 format') from None + raise ValueError('Invalid datetime: invalid format') from None @only_validate_if_strict diff --git a/tests/test_atproto_client/models/tests/test_string_formats.py b/tests/test_atproto_client/models/tests/test_string_formats.py index ba983153..cb7ce7c4 100644 --- a/tests/test_atproto_client/models/tests/test_string_formats.py +++ b/tests/test_atproto_client/models/tests/test_string_formats.py @@ -40,11 +40,23 @@ def get_test_cases(filename: str) -> List[str]: - """Get non-comment, non-empty lines from an interop test file.""" + """Get non-comment, non-empty lines from an interop test file. + + Important: Preserves whitespace in test cases. This is critical for + format validators where leading/trailing/internal whitespace makes a + value invalid. For example, ' 1985-04-12T23:20:50.123Z' (with leading space) + should be invalid for datetime validation. + + Args: + filename: Name of the test file to read from interop test files directory + + Returns: + List of test cases with original whitespace preserved + """ return [ - line.strip() + line for line in INTEROP_TEST_FILES_DIR.joinpath(filename).read_text().splitlines() - if line.strip() and not line.startswith('#') + if line and not line.startswith('#') ] @@ -124,40 +136,17 @@ def invalid_data() -> dict: @pytest.mark.parametrize( - 'validator_type,field_name,error_keywords,invalid_value', - [(string_formats.AtUri, 'at_uri', ['Invalid AT-URI'], c) for c in get_test_cases('aturi_syntax_invalid.txt')] - + [ - (string_formats.DateTime, 'datetime', ['Invalid datetime'], c) - for c in get_test_cases('datetime_syntax_invalid.txt') - ] - + [ - (string_formats.Handle, 'handle', 'must be a domain name', c) - for c in get_test_cases('handle_syntax_invalid.txt') - ] - + [ - (string_formats.Did, 'did', 'must be in format did:method:identifier', c) - for c in get_test_cases('did_syntax_invalid.txt') - ] - + [ - (string_formats.Nsid, 'nsid', 'must be dot-separated segments', c) - for c in get_test_cases('nsid_syntax_invalid.txt') - ] - + [ - (string_formats.Tid, 'tid', 'must be exactly 13 lowercase letters/numbers', c) - for c in get_test_cases('tid_syntax_invalid.txt') - ] - + [ - (string_formats.RecordKey, 'record_key', 'must contain only alphanumeric', c) - for c in get_test_cases('recordkey_syntax_invalid.txt') - ], + 'validator_type,field_name,invalid_value', + [(string_formats.AtUri, 'at_uri', c) for c in get_test_cases('aturi_syntax_invalid.txt')] + + [(string_formats.DateTime, 'datetime', c) for c in get_test_cases('datetime_syntax_invalid.txt')] + + [(string_formats.Handle, 'handle', c) for c in get_test_cases('handle_syntax_invalid.txt')] + + [(string_formats.Did, 'did', c) for c in get_test_cases('did_syntax_invalid.txt')] + + [(string_formats.Nsid, 'nsid', c) for c in get_test_cases('nsid_syntax_invalid.txt')] + + [(string_formats.Tid, 'tid', c) for c in get_test_cases('tid_syntax_invalid.txt')] + + [(string_formats.RecordKey, 'record_key', c) for c in get_test_cases('recordkey_syntax_invalid.txt')], ) -def test_string_format_validation( - validator_type: type, field_name: str, error_keywords: List[str], invalid_value: str, valid_data: dict -) -> None: +def test_string_format_validation(validator_type: type, field_name: str, invalid_value: str, valid_data: dict) -> None: """Test validation for each string format type.""" - if any(invalid_value == skip_value for _, skip_value in SKIP_THESE_VALUES): - pytest.xfail(f'TODO: Fix validation for {invalid_value}') - SomeTypeAdapter = TypeAdapter(validator_type) # Test that validation is skipped by default @@ -168,10 +157,8 @@ def test_string_format_validation( assert validated_value == valid_data[field_name] # Test that invalid data fails strict validation - with pytest.raises(ValidationError) as exc_info: + with pytest.raises(ValidationError): SomeTypeAdapter.validate_python(invalid_value, context={_OPT_IN_KEY: True}) - error_msg = str(exc_info.value) - assert any(keyword in error_msg for keyword in error_keywords) @pytest.mark.parametrize( @@ -232,26 +219,3 @@ class FooModel(BaseModel): assert isinstance(instance, FooModel) assert instance.handle == invalid_data['handle'] assert instance.did == invalid_data['did'] - - -@pytest.mark.parametrize('validator_type,value', SKIP_THESE_VALUES) -def test_skipped_validation_cases(validator_type: type, value: str) -> None: - """ - Test each skipped case that we suspect is a discrepancy between valid/invalid test files - """ - _TAdapter = TypeAdapter(validator_type) - - # Should validate successfully with strict validation - validated = _TAdapter.validate_python(value, context={_OPT_IN_KEY: True}) - assert validated == value - - # Also verify it appears in the corresponding invalid test file - invalid_filename = { - string_formats.AtUri: 'aturi_syntax_invalid.txt', - string_formats.DateTime: 'datetime_syntax_invalid.txt', - string_formats.Handle: 'handle_syntax_invalid.txt', - string_formats.Nsid: 'nsid_syntax_invalid.txt', - }[validator_type] - - invalid_cases = get_test_cases(invalid_filename) - assert value in invalid_cases, f'{value} not found in {invalid_filename} despite being marked as invalid'