From 6df18916bdb0fb7cd0779f75ab810479d4b6dcc6 Mon Sep 17 00:00:00 2001 From: Dieter Weber Date: Tue, 3 Sep 2024 11:56:32 +0200 Subject: [PATCH] WIP It now complains about dimensionality mismatch TODO test cases on invalid JSON --- src/libertem_schema/__init__.py | 124 +++++++++++++++++++------------- tests/test_schemas.py | 35 +++++---- 2 files changed, 92 insertions(+), 67 deletions(-) diff --git a/src/libertem_schema/__init__.py b/src/libertem_schema/__init__.py index 14bb7dc..5378407 100644 --- a/src/libertem_schema/__init__.py +++ b/src/libertem_schema/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple +from typing import Any, Tuple, Sequence from numbers import Number from typing_extensions import Annotated @@ -8,7 +8,13 @@ GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationError, + AfterValidator, + BeforeValidator, + WrapValidator, + ValidationInfo, + ValidatorFunctionWrapHandler, ) + from pydantic.json_schema import JsonSchemaValue import pint @@ -22,62 +28,84 @@ class DimensionError(ValueError): pass -def _get_annotation(reference: pint.Quantity): - - dimensionality = reference.dimensionality - - class Annotation: - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - def validate_from_tuple(value: Tuple[Number, str]) -> pint.Quantity: - m, u = value - quantity = m * ureg(u) - result = pint.Quantity(quantity) - print("debug", quantity, dimensionality) - if not result.check(dimensionality): - raise DimensionError( - f"Dimensionality mismatch: Type {type(result)} expected {dimensionality}." - ) - return result - - from_tuple_schema = core_schema.chain_schema( +_pint_base_repr = core_schema.tuple_positional_schema(items_schema=[ + core_schema.float_schema(), + core_schema.str_schema() +]) + + +class PintAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + def validate_from_tuple(value: Tuple[Number, str]) -> pint.Quantity: + m, u = value + return m * ureg(u) + + from_tuple_schema = core_schema.chain_schema( + [ + _pint_base_repr, + core_schema.no_info_plain_validator_function(validate_from_tuple), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_tuple_schema, + python_schema=core_schema.chain_schema( [ - core_schema.tuple_positional_schema(items_schema=[ - core_schema.float_schema(), - core_schema.str_schema() - ]), - core_schema.no_info_plain_validator_function(validate_from_tuple), + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(pint.Quantity), ] - ) - - return core_schema.json_or_python_schema( - json_schema=from_tuple_schema, - python_schema=core_schema.union_schema( - [ - # check if it's an instance first before doing any further work - core_schema.is_instance_schema(pint.Quantity), - from_tuple_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: (float(instance.m), str(instance.u)) - ), - ) - return Annotation + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: (float(instance.m), str(instance.u)) + ), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + # Use the same schema that would be used for the tuple + return handler(_pint_base_repr) + + +_length_dim = ureg.meter.dimensionality +_angle_dim = ureg.radian.dimensionality +_pixel_dim = ureg.pixel.dimensionality + + +def _make_handler(dimensionality: str): + def is_matching( + q: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo + ) -> pint.Quantity: + if isinstance(q, pint.Quantity): + pass + elif isinstance(q, Sequence): + m, u = q + # Turn into Quantity: measure * unit + q = m * ureg(u) + else: + raise ValueError(f"Don't know how to interpret type {type(q)}.") + + if not q.check(dimensionality): + raise DimensionError(f"Expected dimensionality {dimensionality}, got quantity {q}.") + return q + + return is_matching Length = Annotated[ - pint.Quantity, _get_annotation(ureg.meter) + pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_length_dim)) ] Angle = Annotated[ - pint.Quantity, _get_annotation(ureg.radian) + pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_angle_dim)) ] Pixel = Annotated[ - pint.Quantity, _get_annotation(ureg.pixel) + pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_pixel_dim)) ] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index e8b40f2..892e06b 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -4,6 +4,7 @@ from pint import UnitRegistry from pydantic_core import from_json +from pydantic import ValidationError from libertem_schema import Simple4DSTEMParams, DimensionError @@ -30,22 +31,18 @@ def test_smoke(): def test_dimensionality(): - params = Simple4DSTEMParams( - overfocus=0.0015 * ureg.degree, # mismatch - scan_pixel_pitch=0.000001 * ureg.meter, - camera_length=0.15 * ureg.meter, - detector_pixel_pitch=0.000050 * ureg.meter, - semiconv=0.020 * ureg.radian, # rad - scan_rotation=330. * ureg.degree, - flip_y=False, - # Offset to avoid subchip gap in butted detectors - cy=(32 - 2) * ureg.pixel, - cx=(32 - 2) * ureg.pixel, - ) - pprint.pprint(params) - assert Simple4DSTEMParams.model_validate(params) - as_json = params.model_dump_json() - print(as_json) - from_j = from_json(as_json) - pprint.pprint(type(from_j['overfocus'])) - assert Simple4DSTEMParams.model_validate(from_j) + with pytest.raises(ValidationError): + Simple4DSTEMParams( + ### + overfocus=0.0015 * ureg.degree, # mismatch + ### + scan_pixel_pitch=0.000001 * ureg.meter, + camera_length=0.15 * ureg.meter, + detector_pixel_pitch=0.000050 * ureg.meter, + semiconv=0.020 * ureg.radian, # rad + scan_rotation=330. * ureg.degree, + flip_y=False, + # Offset to avoid subchip gap in butted detectors + cy=(32 - 2) * ureg.pixel, + cx=(32 - 2) * ureg.pixel, + )