Skip to content

Commit

Permalink
WIP It now complains about dimensionality mismatch
Browse files Browse the repository at this point in the history
TODO test cases on invalid JSON
  • Loading branch information
uellue authored and sk1p committed Sep 3, 2024
1 parent bfceeb2 commit 6df1891
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 67 deletions.
124 changes: 76 additions & 48 deletions src/libertem_schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Tuple
from typing import Any, Tuple, Sequence
from numbers import Number

from typing_extensions import Annotated
Expand All @@ -8,7 +8,13 @@
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
AfterValidator,
BeforeValidator,
WrapValidator,
ValidationInfo,
ValidatorFunctionWrapHandler,
)

from pydantic.json_schema import JsonSchemaValue
import pint

Expand All @@ -22,62 +28,84 @@ class DimensionError(ValueError):
pass


def _get_annotation(reference: pint.Quantity):

dimensionality = reference.dimensionality

class Annotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_tuple(value: Tuple[Number, str]) -> pint.Quantity:
m, u = value
quantity = m * ureg(u)
result = pint.Quantity(quantity)
print("debug", quantity, dimensionality)
if not result.check(dimensionality):
raise DimensionError(
f"Dimensionality mismatch: Type {type(result)} expected {dimensionality}."
)
return result

from_tuple_schema = core_schema.chain_schema(
_pint_base_repr = core_schema.tuple_positional_schema(items_schema=[
core_schema.float_schema(),
core_schema.str_schema()
])


class PintAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_tuple(value: Tuple[Number, str]) -> pint.Quantity:
m, u = value
return m * ureg(u)

from_tuple_schema = core_schema.chain_schema(
[
_pint_base_repr,
core_schema.no_info_plain_validator_function(validate_from_tuple),
]
)

return core_schema.json_or_python_schema(
json_schema=from_tuple_schema,
python_schema=core_schema.chain_schema(
[
core_schema.tuple_positional_schema(items_schema=[
core_schema.float_schema(),
core_schema.str_schema()
]),
core_schema.no_info_plain_validator_function(validate_from_tuple),
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(pint.Quantity),
]
)

return core_schema.json_or_python_schema(
json_schema=from_tuple_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(pint.Quantity),
from_tuple_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: (float(instance.m), str(instance.u))
),
)
return Annotation
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: (float(instance.m), str(instance.u))
),
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Use the same schema that would be used for the tuple
return handler(_pint_base_repr)


_length_dim = ureg.meter.dimensionality
_angle_dim = ureg.radian.dimensionality
_pixel_dim = ureg.pixel.dimensionality


def _make_handler(dimensionality: str):
def is_matching(
q: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> pint.Quantity:
if isinstance(q, pint.Quantity):
pass
elif isinstance(q, Sequence):
m, u = q
# Turn into Quantity: measure * unit
q = m * ureg(u)
else:
raise ValueError(f"Don't know how to interpret type {type(q)}.")

if not q.check(dimensionality):
raise DimensionError(f"Expected dimensionality {dimensionality}, got quantity {q}.")
return q

return is_matching


Length = Annotated[
pint.Quantity, _get_annotation(ureg.meter)
pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_length_dim))
]
Angle = Annotated[
pint.Quantity, _get_annotation(ureg.radian)
pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_angle_dim))
]
Pixel = Annotated[
pint.Quantity, _get_annotation(ureg.pixel)
pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_pixel_dim))
]


Expand Down
35 changes: 16 additions & 19 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pint import UnitRegistry
from pydantic_core import from_json
from pydantic import ValidationError

from libertem_schema import Simple4DSTEMParams, DimensionError

Expand All @@ -30,22 +31,18 @@ def test_smoke():


def test_dimensionality():
params = Simple4DSTEMParams(
overfocus=0.0015 * ureg.degree, # mismatch
scan_pixel_pitch=0.000001 * ureg.meter,
camera_length=0.15 * ureg.meter,
detector_pixel_pitch=0.000050 * ureg.meter,
semiconv=0.020 * ureg.radian, # rad
scan_rotation=330. * ureg.degree,
flip_y=False,
# Offset to avoid subchip gap in butted detectors
cy=(32 - 2) * ureg.pixel,
cx=(32 - 2) * ureg.pixel,
)
pprint.pprint(params)
assert Simple4DSTEMParams.model_validate(params)
as_json = params.model_dump_json()
print(as_json)
from_j = from_json(as_json)
pprint.pprint(type(from_j['overfocus']))
assert Simple4DSTEMParams.model_validate(from_j)
with pytest.raises(ValidationError):
Simple4DSTEMParams(
###
overfocus=0.0015 * ureg.degree, # mismatch
###
scan_pixel_pitch=0.000001 * ureg.meter,
camera_length=0.15 * ureg.meter,
detector_pixel_pitch=0.000050 * ureg.meter,
semiconv=0.020 * ureg.radian, # rad
scan_rotation=330. * ureg.degree,
flip_y=False,
# Offset to avoid subchip gap in butted detectors
cy=(32 - 2) * ureg.pixel,
cx=(32 - 2) * ureg.pixel,
)

0 comments on commit 6df1891

Please sign in to comment.