diff --git a/examples/advanced_usage/validate_string_formats.py b/examples/advanced_usage/validate_string_formats.py new file mode 100644 index 00000000..323d8f12 --- /dev/null +++ b/examples/advanced_usage/validate_string_formats.py @@ -0,0 +1,30 @@ +from atproto_client.models import string_formats +from pydantic import TypeAdapter, ValidationError + +some_good_handle = 'test.bsky.social' +some_bad_handle = 'invalid@ @handle' + +strict_validation_context = {'strict_string_format': True} +HandleTypeAdapter = TypeAdapter(string_formats.Handle) + +assert string_formats._OPT_IN_KEY == 'strict_string_format' # noqa: S101 + +# values will not be validated if not opting in +sneaky_bad_handle = HandleTypeAdapter.validate_python(some_bad_handle) + +assert sneaky_bad_handle == some_bad_handle # noqa: S101 + +print(f'{sneaky_bad_handle=}\n\n') + +# values will be validated if opting in +validated_good_handle = HandleTypeAdapter.validate_python(some_good_handle, context=strict_validation_context) + +assert validated_good_handle == some_good_handle # noqa: S101 + +print(f'{validated_good_handle=}\n\n') + +try: + print('Trying to validate a bad handle with strict validation...') + HandleTypeAdapter.validate_python(some_bad_handle, context=strict_validation_context) +except ValidationError as e: + print(e) diff --git a/packages/atproto_client/models/string_formats.py b/packages/atproto_client/models/string_formats.py index 62eaa66f..19da26d8 100644 --- a/packages/atproto_client/models/string_formats.py +++ b/packages/atproto_client/models/string_formats.py @@ -44,71 +44,79 @@ def wrapper(v: str, info: ValidationInfo) -> str: @only_validate_if_strict def validate_handle(v: str, info: ValidationInfo) -> str: if not DOMAIN_RE.match(v.lower()) or len(v) > MAX_HANDLE_LENGTH: - raise ValueError('Invalid handle') + raise ValueError( + f'Invalid handle: must be a domain name (e.g. user.bsky.social) with max length {MAX_HANDLE_LENGTH}' + ) return v @only_validate_if_strict def validate_did(v: str, info: ValidationInfo) -> str: if not DID_RE.match(v): - raise ValueError('Invalid DID') + raise ValueError('Invalid DID: must be in format did:method:identifier (e.g. did:plc:1234abcd)') return v @only_validate_if_strict def validate_nsid(v: str, info: ValidationInfo) -> str: if not NSID_RE.match(v) or '.' not in v or len(v) > MAX_NSID_LENGTH: - raise ValueError('Invalid NSID') + raise ValueError( + f'Invalid NSID: must be dot-separated segments (e.g. app.bsky.feed.post) with max length {MAX_NSID_LENGTH}' + ) return v @only_validate_if_strict def validate_language(v: str, info: ValidationInfo) -> str: if not LANG_RE.match(v): - raise ValueError('Invalid language code') + raise ValueError('Invalid language code: must be ISO language code (e.g. en or en-US)') return v @only_validate_if_strict def validate_record_key(v: str, info: ValidationInfo) -> str: if v in INVALID_RECORD_KEYS or not RKEY_RE.match(v): - raise ValueError('Invalid record key') + raise ValueError( + 'Invalid record key: must contain only alphanumeric, dot, underscore, colon, tilde, or hyphen characters' + ) return v @only_validate_if_strict def validate_cid(v: str, info: ValidationInfo) -> str: if not CID_RE.match(v): - raise ValueError('Invalid CID') + raise ValueError('Invalid CID: must be a valid Content Identifier with minimum length 8') return v @only_validate_if_strict def validate_at_uri(v: str, info: ValidationInfo) -> str: if len(v) >= MAX_URI_LENGTH or '/./' in v or '/../' in v or v.endswith(('/.', '/..')): - raise ValueError('Invalid AT-URI') + raise ValueError(f'Invalid AT-URI: must be under {MAX_URI_LENGTH} chars and not contain /./ or /../ patterns') if not AT_URI_RE.match(v): - raise ValueError('Invalid AT-URI format') + raise ValueError( + 'Invalid AT-URI: must be in format at://authority/collection/record (e.g. at://user.bsky.social/posts/123)' + ) return v @only_validate_if_strict def validate_datetime(v: str, info: ValidationInfo) -> str: - # could just use pydantic_extra_types.pendulum_dt.DateTime but - # see https://github.com/pydantic/pydantic-extra-types/issues/239 if 'T' not in v or not any(v.endswith(x) for x in ('Z', '+00:00')): - raise ValueError('Invalid datetime format') + raise ValueError('Invalid datetime format: must be ISO 8601 with timezone (e.g. 2023-01-01T12:00:00Z)') try: datetime.fromisoformat(v.replace('Z', '+00:00')) return v except ValueError: - raise ValueError('Invalid datetime format') from None + raise ValueError( + 'Invalid datetime format: must be ISO 8601 with timezone (e.g. 2023-01-01T12:00:00Z)' + ) from None @only_validate_if_strict def validate_tid(v: str, info: ValidationInfo) -> str: if not TID_RE.match(v): - raise ValueError('Invalid TID format') + raise ValueError(f'Invalid TID format: must be exactly {TID_LENGTH} lowercase letters/numbers') if ord(v[0]) & 0x40: raise ValueError('Invalid TID: high bit cannot be 1') return v @@ -117,14 +125,14 @@ def validate_tid(v: str, info: ValidationInfo) -> str: @only_validate_if_strict def validate_uri(v: str, info: ValidationInfo) -> str: if len(v) >= MAX_URI_LENGTH or ' ' in v: - raise ValueError('Invalid URI') + raise ValueError(f'Invalid URI: must be under {MAX_URI_LENGTH} chars and not contain spaces') parsed = urlparse(v) if not ( parsed.scheme and parsed.scheme[0].isalpha() and (parsed.netloc or parsed.path or parsed.query or parsed.fragment) ): - raise ValueError('Invalid URI') + raise ValueError('Invalid URI: must be a valid URI with scheme and authority/path (e.g. https://example.com)') return v diff --git a/tests/test_atproto_client/models/tests/test_string_formats.py b/tests/test_atproto_client/models/tests/test_string_formats.py index 69564493..28b8024e 100644 --- a/tests/test_atproto_client/models/tests/test_string_formats.py +++ b/tests/test_atproto_client/models/tests/test_string_formats.py @@ -6,7 +6,6 @@ @pytest.fixture def invalid_data(): - # TODO: retrieve believable examples of invalid data return { 'handle': 'invalid@ @handle', 'did': 'not-a-did', @@ -16,14 +15,13 @@ def invalid_data(): 'datetime': '2023-01-01', 'tid': 'invalid-tid', 'record_key': '..', - 'uri': ' invalid uri ', + 'uri': 'invalid-uri-no-scheme', 'language': 'invalid!', } @pytest.fixture def valid_data(): - # TODO: retrieve real examples of test data return { 'handle': 'test.bsky.social', 'did': 'did:plc:1234abcd', @@ -41,16 +39,16 @@ def valid_data(): @pytest.mark.parametrize( 'type_name,field_name,expected_error', [ - ('Handle', 'handle', 'Invalid handle'), - ('Did', 'did', 'Invalid DID'), - ('Nsid', 'nsid', 'Invalid NSID'), - ('AtUri', 'at_uri', 'Invalid AT-URI'), - ('Cid', 'cid', 'Invalid CID'), - ('DateTime', 'datetime', 'Invalid datetime format'), - ('Tid', 'tid', 'Invalid TID format'), - ('RecordKey', 'record_key', 'Invalid record key'), - ('Uri', 'uri', 'Invalid URI'), - ('Language', 'language', 'Invalid language code'), + ('Handle', 'handle', 'must be a domain name'), + ('Did', 'did', 'must be in format did:method:identifier'), + ('Nsid', 'nsid', 'must be dot-separated segments'), + ('AtUri', 'at_uri', 'must be in format at://authority/collection/record'), + ('Cid', 'cid', 'must be a valid Content Identifier'), + ('DateTime', 'datetime', 'must be ISO 8601 with timezone'), + ('Tid', 'tid', 'must be exactly 13 lowercase letters/numbers'), + ('RecordKey', 'record_key', 'must contain only alphanumeric'), + ('Uri', 'uri', 'must be a valid URI with scheme and authority/path'), + ('Language', 'language', 'must be ISO language code'), ], ) def test_string_format_validation(