Skip to content

Commit

Permalink
more helpful error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Nov 23, 2024
1 parent 4a1afbc commit 47c3aab
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 28 deletions.
30 changes: 30 additions & 0 deletions examples/advanced_usage/validate_string_formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from atproto_client.models import string_formats
from pydantic import TypeAdapter, ValidationError

some_good_handle = 'test.bsky.social'
some_bad_handle = 'invalid@ @handle'

strict_validation_context = {'strict_string_format': True}
HandleTypeAdapter = TypeAdapter(string_formats.Handle)

assert string_formats._OPT_IN_KEY == 'strict_string_format' # noqa: S101

# values will not be validated if not opting in
sneaky_bad_handle = HandleTypeAdapter.validate_python(some_bad_handle)

assert sneaky_bad_handle == some_bad_handle # noqa: S101

print(f'{sneaky_bad_handle=}\n\n')

# values will be validated if opting in
validated_good_handle = HandleTypeAdapter.validate_python(some_good_handle, context=strict_validation_context)

assert validated_good_handle == some_good_handle # noqa: S101

print(f'{validated_good_handle=}\n\n')

try:
print('Trying to validate a bad handle with strict validation...')
HandleTypeAdapter.validate_python(some_bad_handle, context=strict_validation_context)
except ValidationError as e:
print(e)
38 changes: 23 additions & 15 deletions packages/atproto_client/models/string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,71 +44,79 @@ def wrapper(v: str, info: ValidationInfo) -> str:
@only_validate_if_strict
def validate_handle(v: str, info: ValidationInfo) -> str:
if not DOMAIN_RE.match(v.lower()) or len(v) > MAX_HANDLE_LENGTH:
raise ValueError('Invalid handle')
raise ValueError(
f'Invalid handle: must be a domain name (e.g. user.bsky.social) with max length {MAX_HANDLE_LENGTH}'
)
return v


@only_validate_if_strict
def validate_did(v: str, info: ValidationInfo) -> str:
if not DID_RE.match(v):
raise ValueError('Invalid DID')
raise ValueError('Invalid DID: must be in format did:method:identifier (e.g. did:plc:1234abcd)')
return v


@only_validate_if_strict
def validate_nsid(v: str, info: ValidationInfo) -> str:
if not NSID_RE.match(v) or '.' not in v or len(v) > MAX_NSID_LENGTH:
raise ValueError('Invalid NSID')
raise ValueError(
f'Invalid NSID: must be dot-separated segments (e.g. app.bsky.feed.post) with max length {MAX_NSID_LENGTH}'
)
return v


@only_validate_if_strict
def validate_language(v: str, info: ValidationInfo) -> str:
if not LANG_RE.match(v):
raise ValueError('Invalid language code')
raise ValueError('Invalid language code: must be ISO language code (e.g. en or en-US)')
return v


@only_validate_if_strict
def validate_record_key(v: str, info: ValidationInfo) -> str:
if v in INVALID_RECORD_KEYS or not RKEY_RE.match(v):
raise ValueError('Invalid record key')
raise ValueError(
'Invalid record key: must contain only alphanumeric, dot, underscore, colon, tilde, or hyphen characters'
)
return v


@only_validate_if_strict
def validate_cid(v: str, info: ValidationInfo) -> str:
if not CID_RE.match(v):
raise ValueError('Invalid CID')
raise ValueError('Invalid CID: must be a valid Content Identifier with minimum length 8')
return v


@only_validate_if_strict
def validate_at_uri(v: str, info: ValidationInfo) -> str:
if len(v) >= MAX_URI_LENGTH or '/./' in v or '/../' in v or v.endswith(('/.', '/..')):
raise ValueError('Invalid AT-URI')
raise ValueError(f'Invalid AT-URI: must be under {MAX_URI_LENGTH} chars and not contain /./ or /../ patterns')
if not AT_URI_RE.match(v):
raise ValueError('Invalid AT-URI format')
raise ValueError(
'Invalid AT-URI: must be in format at://authority/collection/record (e.g. at://user.bsky.social/posts/123)'
)
return v


@only_validate_if_strict
def validate_datetime(v: str, info: ValidationInfo) -> str:
# could just use pydantic_extra_types.pendulum_dt.DateTime but
# see https://github.com/pydantic/pydantic-extra-types/issues/239
if 'T' not in v or not any(v.endswith(x) for x in ('Z', '+00:00')):
raise ValueError('Invalid datetime format')
raise ValueError('Invalid datetime format: must be ISO 8601 with timezone (e.g. 2023-01-01T12:00:00Z)')
try:
datetime.fromisoformat(v.replace('Z', '+00:00'))
return v
except ValueError:
raise ValueError('Invalid datetime format') from None
raise ValueError(
'Invalid datetime format: must be ISO 8601 with timezone (e.g. 2023-01-01T12:00:00Z)'
) from None


@only_validate_if_strict
def validate_tid(v: str, info: ValidationInfo) -> str:
if not TID_RE.match(v):
raise ValueError('Invalid TID format')
raise ValueError(f'Invalid TID format: must be exactly {TID_LENGTH} lowercase letters/numbers')
if ord(v[0]) & 0x40:
raise ValueError('Invalid TID: high bit cannot be 1')
return v
Expand All @@ -117,14 +125,14 @@ def validate_tid(v: str, info: ValidationInfo) -> str:
@only_validate_if_strict
def validate_uri(v: str, info: ValidationInfo) -> str:
if len(v) >= MAX_URI_LENGTH or ' ' in v:
raise ValueError('Invalid URI')
raise ValueError(f'Invalid URI: must be under {MAX_URI_LENGTH} chars and not contain spaces')
parsed = urlparse(v)
if not (
parsed.scheme
and parsed.scheme[0].isalpha()
and (parsed.netloc or parsed.path or parsed.query or parsed.fragment)
):
raise ValueError('Invalid URI')
raise ValueError('Invalid URI: must be a valid URI with scheme and authority/path (e.g. https://example.com)')
return v


Expand Down
24 changes: 11 additions & 13 deletions tests/test_atproto_client/models/tests/test_string_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

@pytest.fixture
def invalid_data():
# TODO: retrieve believable examples of invalid data
return {
'handle': 'invalid@ @handle',
'did': 'not-a-did',
Expand All @@ -16,14 +15,13 @@ def invalid_data():
'datetime': '2023-01-01',
'tid': 'invalid-tid',
'record_key': '..',
'uri': ' invalid uri ',
'uri': 'invalid-uri-no-scheme',
'language': 'invalid!',
}


@pytest.fixture
def valid_data():
# TODO: retrieve real examples of test data
return {
'handle': 'test.bsky.social',
'did': 'did:plc:1234abcd',
Expand All @@ -41,16 +39,16 @@ def valid_data():
@pytest.mark.parametrize(
'type_name,field_name,expected_error',
[
('Handle', 'handle', 'Invalid handle'),
('Did', 'did', 'Invalid DID'),
('Nsid', 'nsid', 'Invalid NSID'),
('AtUri', 'at_uri', 'Invalid AT-URI'),
('Cid', 'cid', 'Invalid CID'),
('DateTime', 'datetime', 'Invalid datetime format'),
('Tid', 'tid', 'Invalid TID format'),
('RecordKey', 'record_key', 'Invalid record key'),
('Uri', 'uri', 'Invalid URI'),
('Language', 'language', 'Invalid language code'),
('Handle', 'handle', 'must be a domain name'),
('Did', 'did', 'must be in format did:method:identifier'),
('Nsid', 'nsid', 'must be dot-separated segments'),
('AtUri', 'at_uri', 'must be in format at://authority/collection/record'),
('Cid', 'cid', 'must be a valid Content Identifier'),
('DateTime', 'datetime', 'must be ISO 8601 with timezone'),
('Tid', 'tid', 'must be exactly 13 lowercase letters/numbers'),
('RecordKey', 'record_key', 'must contain only alphanumeric'),
('Uri', 'uri', 'must be a valid URI with scheme and authority/path'),
('Language', 'language', 'must be ISO language code'),
],
)
def test_string_format_validation(
Expand Down

0 comments on commit 47c3aab

Please sign in to comment.