diff --git a/prototypes/testbed.ipynb b/prototypes/testbed.ipynb index 8d3825a..b4db67a 100644 --- a/prototypes/testbed.ipynb +++ b/prototypes/testbed.ipynb @@ -2,15 +2,15 @@ "cells": [ { "cell_type": "code", - "execution_count": 264, + "execution_count": 394, "id": "27ec72cd-9a31-44b1-8573-d381ef192780", "metadata": {}, "outputs": [], "source": [ - "from typing import Any, Sequence, get_args, get_origin, Callable\n", + "from typing import Any, Sequence, get_args, get_origin, Callable, Union\n", "import functools\n", "\n", - "from typing_extensions import Annotated, TypeVar\n", + "from typing_extensions import Annotated, TypeVar, reveal_type, get_original_bases, get_args, get_origin\n", "\n", "import numpy as np\n", "\n", @@ -18,17 +18,171 @@ "\n", "import pydantic\n", "from pydantic_core import core_schema\n", + "import pydantic_core\n", "from pydantic import (\n", " BaseModel,\n", " GetCoreSchemaHandler,\n", " WrapValidator,\n", " ValidationInfo,\n", " ValidatorFunctionWrapHandler,\n", + " Field,\n", ")\n", "\n", "import pint" ] }, + { + "cell_type": "code", + "execution_count": 396, + "id": "299eb9a0-f54e-4a01-9224-cd1b2670e8e9", + "metadata": {}, + "outputs": [], + "source": [ + "T = Annotated[float, Field(gt=0)]" + ] + }, + { + "cell_type": "code", + "execution_count": 397, + "id": "e2b300b3-3c21-40f3-b335-69af458ece03", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "float" + ] + }, + "execution_count": 397, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(T(0.3))" + ] + }, + { + "cell_type": "code", + "execution_count": 411, + "id": "27c8e21d-21a7-454b-8ff3-8e33082aaeda", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3.3" + ] + }, + "execution_count": 411, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adapter = pydantic.TypeAdapter(T)\n", + "adapter.validate_python(3.3)" + ] + }, + { + "cell_type": "code", + "execution_count": 388, + "id": "046fcd69-74b1-4266-90fe-167e8b2e7b71", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "NDArray[Shape['2 x, 2 y'], , , , , ]" + ] + }, + "execution_count": 388, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "numpydantic.NDArray[numpydantic.Shape['2 x, 2 y'], numpydantic.dtype.Float]" + ] + }, + { + "cell_type": "code", + "execution_count": 412, + "id": "d2606f08-fb69-467f-8ea1-10b33793cf1b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "centimeter" + ], + "text/latex": [ + "$\\mathrm{centimeter}$" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 412, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pint.Quantity(23, 'cm').units" + ] + }, + { + "cell_type": "code", + "execution_count": 387, + "id": "561c5d61-a952-4ec3-92a0-180d6c825dd0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "23 meter" + ], + "text/latex": [ + "$23\\ \\mathrm{meter}$" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 387, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pint.Quantity(value=23, units='m')" + ] + }, + { + "cell_type": "code", + "execution_count": 360, + "id": "79ea0339-3d2d-40da-8a23-8918e7f2a12d", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Subscripted generics cannot be used with class and instance checks", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[360], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mT\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/typing.py:1222\u001b[0m, in \u001b[0;36m_BaseGenericAlias.__instancecheck__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 1221\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__instancecheck__\u001b[39m(\u001b[38;5;28mself\u001b[39m, obj):\n\u001b[0;32m-> 1222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__subclasscheck__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/typing.py:1225\u001b[0m, in \u001b[0;36m_BaseGenericAlias.__subclasscheck__\u001b[0;34m(self, cls)\u001b[0m\n\u001b[1;32m 1224\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__subclasscheck__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28mcls\u001b[39m):\n\u001b[0;32m-> 1225\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSubscripted generics cannot be used with\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1226\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m class and instance checks\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mTypeError\u001b[0m: Subscripted generics cannot be used with class and instance checks" + ] + } + ], + "source": [ + "isinstance(f, T)" + ] + }, { "cell_type": "code", "execution_count": 2, diff --git a/src/libertem_schema/__init__.py b/src/libertem_schema/__init__.py index 4fe976d..acee1e0 100644 --- a/src/libertem_schema/__init__.py +++ b/src/libertem_schema/__init__.py @@ -1,7 +1,7 @@ -from typing import Any, Sequence, Callable +from typing import Any, Sequence, Callable, Union import functools -from typing_extensions import TypeVar, Generic +from typing_extensions import TypeVar, Generic, get_args, get_origin, Annotated import numpydantic import numpy as np @@ -11,7 +11,9 @@ BaseModel, GetCoreSchemaHandler, ValidationInfo, + TypeAdapter ) +from pydantic.errors import PydanticSchemaGenerationError import pint @@ -26,13 +28,13 @@ class DimensionError(ValueError): -def to_tuple(q: pint.Quantity): - base = q.to_base_units() +def to_tuple(q: pint.Quantity, base_units: pint.Unit): + base = q.to(base_units) return (base.magnitude, str(base.units)) -def to_array_tuple(q: pint.Quantity, info: ValidationInfo, array_serializer: Callable): - base = q.to_base_units() +def to_array_tuple(q: pint.Quantity, info: ValidationInfo, base_units: pint.Unit, array_serializer: Callable): + base = q.to(base_units) return (array_serializer(base.magnitude, info=info), str(base.units)) @@ -45,6 +47,12 @@ def get_basic_type(t): # TypeError: Too many parameters for .Single'>; actual 5, expected 1 t = t[0] + origin = get_origin(t) + if origin is not None and issubclass(origin, Annotated): + args = get_args(t) + # First argument is the bas type + t = args[0] + if t in (float, int, complex): return t dtype = np.dtype(t) @@ -63,7 +71,7 @@ def get_schema(t): raise NotImplementedError(t) -def _make_type(reference: pint.Quantity): +def make_type(reference: pint.Quantity): DType = TypeVar('DType') Shape = TypeVar('Shape') @@ -77,8 +85,12 @@ def __get_pydantic_core_schema__( ) -> core_schema.CoreSchema: (dtype, ) = _source_type.__args__ magnitude_schema = get_schema(dtype) - units = str(reference.to_base_units().units) - validate_function = numpydantic.schema.get_validate_interface(Any, dtype) + units = str(reference.units) + try: + adapter = TypeAdapter(dtype) + validate_function = adapter.validate_python + except PydanticSchemaGenerationError as e: + validate_function = numpydantic.schema.get_validate_interface(Any, dtype) target_type = get_basic_type(dtype) def validator(v: Any, info: ValidationInfo) -> pint.Quantity: @@ -86,7 +98,7 @@ def validator(v: Any, info: ValidationInfo) -> pint.Quantity: pass elif isinstance(v, Sequence): magnitude, unit = v - v = pint.Quantity(magnitude, unit) + v = pint.Quantity(value=magnitude, units=unit) else: raise ValueError(f"Don't know how to interpret type {type(v)}.") # Check dimension @@ -113,7 +125,7 @@ def validator(v: Any, info: ValidationInfo) -> pint.Quantity: json_schema=json_schema, python_schema=core_schema.with_info_plain_validator_function(validator), serialization=core_schema.plain_serializer_function_ser_schema( - to_tuple + functools.partial(to_tuple, base_units=reference.units) ), ) @@ -141,7 +153,7 @@ def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity: elif isinstance(v, Sequence): magnitude, unit = v # Turn into Quantity: magnitude * unit - v = pint.Quantity(magnitude=np.asarray(magnitude), unit=unit) + v = pint.Quantity(value=np.asarray(magnitude), units=unit) else: raise ValueError(f"Don't know how to interpret type {type(v)}.") # Check dimension @@ -165,7 +177,11 @@ def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity: core_schema.literal_schema([units]) ]) - serializer = functools.partial(to_array_tuple, array_serializer=magnitude_schema['serialization']['function']) + serializer = functools.partial( + to_array_tuple, + base_units=reference.units, + array_serializer=magnitude_schema['serialization']['function'] + ) return core_schema.json_or_python_schema( json_schema=json_schema, python_schema=core_schema.with_info_plain_validator_function(validator), @@ -177,9 +193,11 @@ def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity: return Single, Array -Length, LengthArray = _make_type(pint.Quantity(1, 'meter')) -Angle, AngleArray = _make_type(pint.Quantity(1, 'radian')) -Pixel, PixelArray = _make_type(pint.Quantity(1, 'pixel')) + +Length, LengthArray = make_type(pint.Quantity(1, 'meter')) +Angle, AngleArray = make_type(pint.Quantity(1, 'radian')) +Pixel, PixelArray = make_type(pint.Quantity(1, 'pixel')) + class Simple4DSTEMParams(BaseModel): ''' diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 22b4139..c957999 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -6,13 +6,15 @@ import jsonschema import numpy as np +from typing_extensions import Annotated, get_origin from pint import UnitRegistry, Quantity from pydantic_core import from_json -from pydantic import ValidationError, BaseModel +from pydantic import ValidationError, BaseModel, PositiveFloat, Field from numpydantic import Shape +import numpydantic.dtype -from libertem_schema import Simple4DSTEMParams, Length, LengthArray +from libertem_schema import Simple4DSTEMParams, Length, LengthArray, make_type ureg = UnitRegistry() @@ -444,3 +446,75 @@ def test_json_schema_dim(): instance=loaded, schema=json_schema ) + + +@pytest.mark.parametrize( + "dtype", ( + float, + numpydantic.dtype.Float, + Annotated[float, Field(strict=False, gt=0)], + np.float64, + PositiveFloat + ) +) +@pytest.mark.parametrize( + "array", (True, False) +) +def test_dtypes(dtype, array): + if dtype is numpydantic.dtype.Float: + pytest.xfail( + "FIXME find out how the numpydantic generic types can be integrated, " + "can't use them as argument as of now" + ) + if array: + origin = get_origin(dtype) + if origin is not None and issubclass(origin, Annotated): + pytest.xfail("FIXME make arrays and pydantic types compatible, somehow.") + + class T(BaseModel): + t: LengthArray[Shape['2 x, 2 y'], dtype] + + t = T(t=Quantity( + np.array([(1., 2.), (3., 4.)]), + 'm' + )) + else: + class T(BaseModel): + t: Length[dtype] + + t = T(t=Quantity(0.3, 'm')) + + json_schema = t.model_json_schema() + pprint.pprint(json_schema) + as_json = t.model_dump_json() + pprint.pprint(as_json) + t.model_validate_json(as_json) + loaded = json.loads(as_json) + t.model_validate(loaded) + jsonschema.validate( + instance=loaded, + schema=json_schema + ) + + +def test_other_unit(): + # we set the base unit to cm + Test, TestArray = make_type(Quantity(1, 'cm')) + + class T(BaseModel): + t: Test[float] + + t = T(t=Quantity(0.3, 'm')) + json_schema = t.model_json_schema() + pprint.pprint(json_schema) + as_json = t.model_dump_json() + pprint.pprint(as_json) + t.model_validate_json(as_json) + loaded = json.loads(as_json) + assert loaded['t'][0] == 30 + assert loaded['t'][1] == 'centimeter' + t.model_validate(loaded) + jsonschema.validate( + instance=loaded, + schema=json_schema + )