Skip to content

Commit

Permalink
Further improvements
Browse files Browse the repository at this point in the history
* Reference quantity defines base unit for schema, allows for more flexibility
* Interoperability with both pydantic-style and numpydantic-style numeric types for individual field values

Types like numpydantic.dtype.Float are still TBD; they don't work with `Generic` yet
  • Loading branch information
uellue committed Sep 18, 2024
1 parent 9cc0555 commit bcf334a
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 21 deletions.
160 changes: 157 additions & 3 deletions prototypes/testbed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,187 @@
"cells": [
{
"cell_type": "code",
"execution_count": 264,
"execution_count": 394,
"id": "27ec72cd-9a31-44b1-8573-d381ef192780",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Sequence, get_args, get_origin, Callable\n",
"from typing import Any, Sequence, get_args, get_origin, Callable, Union\n",
"import functools\n",
"\n",
"from typing_extensions import Annotated, TypeVar\n",
"from typing_extensions import Annotated, TypeVar, reveal_type, get_original_bases, get_args, get_origin\n",
"\n",
"import numpy as np\n",
"\n",
"import numpydantic\n",
"\n",
"import pydantic\n",
"from pydantic_core import core_schema\n",
"import pydantic_core\n",
"from pydantic import (\n",
" BaseModel,\n",
" GetCoreSchemaHandler,\n",
" WrapValidator,\n",
" ValidationInfo,\n",
" ValidatorFunctionWrapHandler,\n",
" Field,\n",
")\n",
"\n",
"import pint"
]
},
{
"cell_type": "code",
"execution_count": 396,
"id": "299eb9a0-f54e-4a01-9224-cd1b2670e8e9",
"metadata": {},
"outputs": [],
"source": [
"T = Annotated[float, Field(gt=0)]"
]
},
{
"cell_type": "code",
"execution_count": 397,
"id": "e2b300b3-3c21-40f3-b335-69af458ece03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"float"
]
},
"execution_count": 397,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(T(0.3))"
]
},
{
"cell_type": "code",
"execution_count": 411,
"id": "27c8e21d-21a7-454b-8ff3-8e33082aaeda",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.3"
]
},
"execution_count": 411,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adapter = pydantic.TypeAdapter(T)\n",
"adapter.validate_python(3.3)"
]
},
{
"cell_type": "code",
"execution_count": 388,
"id": "046fcd69-74b1-4266-90fe-167e8b2e7b71",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"NDArray[Shape['2 x, 2 y'], <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float32'>, <class 'numpy.float64'>]"
]
},
"execution_count": 388,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpydantic.NDArray[numpydantic.Shape['2 x, 2 y'], numpydantic.dtype.Float]"
]
},
{
"cell_type": "code",
"execution_count": 412,
"id": "d2606f08-fb69-467f-8ea1-10b33793cf1b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"centimeter"
],
"text/latex": [
"$\\mathrm{centimeter}$"
],
"text/plain": [
"<Unit('centimeter')>"
]
},
"execution_count": 412,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pint.Quantity(23, 'cm').units"
]
},
{
"cell_type": "code",
"execution_count": 387,
"id": "561c5d61-a952-4ec3-92a0-180d6c825dd0",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"23 meter"
],
"text/latex": [
"$23\\ \\mathrm{meter}$"
],
"text/plain": [
"<Quantity(23, 'meter')>"
]
},
"execution_count": 387,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pint.Quantity(value=23, units='m')"
]
},
{
"cell_type": "code",
"execution_count": 360,
"id": "79ea0339-3d2d-40da-8a23-8918e7f2a12d",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "Subscripted generics cannot be used with class and instance checks",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[360], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mT\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/typing.py:1222\u001b[0m, in \u001b[0;36m_BaseGenericAlias.__instancecheck__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 1221\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__instancecheck__\u001b[39m(\u001b[38;5;28mself\u001b[39m, obj):\n\u001b[0;32m-> 1222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__subclasscheck__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/typing.py:1225\u001b[0m, in \u001b[0;36m_BaseGenericAlias.__subclasscheck__\u001b[0;34m(self, cls)\u001b[0m\n\u001b[1;32m 1224\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__subclasscheck__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28mcls\u001b[39m):\n\u001b[0;32m-> 1225\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSubscripted generics cannot be used with\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1226\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m class and instance checks\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mTypeError\u001b[0m: Subscripted generics cannot be used with class and instance checks"
]
}
],
"source": [
"isinstance(f, T)"
]
},
{
"cell_type": "code",
"execution_count": 2,
Expand Down
50 changes: 34 additions & 16 deletions src/libertem_schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Sequence, Callable
from typing import Any, Sequence, Callable, Union
import functools

from typing_extensions import TypeVar, Generic
from typing_extensions import TypeVar, Generic, get_args, get_origin, Annotated

import numpydantic
import numpy as np
Expand All @@ -11,7 +11,9 @@
BaseModel,
GetCoreSchemaHandler,
ValidationInfo,
TypeAdapter
)
from pydantic.errors import PydanticSchemaGenerationError

import pint

Expand All @@ -26,13 +28,13 @@ class DimensionError(ValueError):



def to_tuple(q: pint.Quantity):
base = q.to_base_units()
def to_tuple(q: pint.Quantity, base_units: pint.Unit):
base = q.to(base_units)
return (base.magnitude, str(base.units))


def to_array_tuple(q: pint.Quantity, info: ValidationInfo, array_serializer: Callable):
base = q.to_base_units()
def to_array_tuple(q: pint.Quantity, info: ValidationInfo, base_units: pint.Unit, array_serializer: Callable):
base = q.to(base_units)
return (array_serializer(base.magnitude, info=info), str(base.units))


Expand All @@ -45,6 +47,12 @@ def get_basic_type(t):
# TypeError: Too many parameters for <class
# 'libertem_schema._make_type.<locals>.Single'>; actual 5, expected 1
t = t[0]
origin = get_origin(t)
if origin is not None and issubclass(origin, Annotated):
args = get_args(t)
# First argument is the base type
t = args[0]

if t in (float, int, complex):
return t
dtype = np.dtype(t)
Expand All @@ -63,7 +71,7 @@ def get_schema(t):
raise NotImplementedError(t)


def _make_type(reference: pint.Quantity):
def make_type(reference: pint.Quantity):

DType = TypeVar('DType')
Shape = TypeVar('Shape')
Expand All @@ -77,16 +85,20 @@ def __get_pydantic_core_schema__(
) -> core_schema.CoreSchema:
(dtype, ) = _source_type.__args__
magnitude_schema = get_schema(dtype)
units = str(reference.to_base_units().units)
validate_function = numpydantic.schema.get_validate_interface(Any, dtype)
units = str(reference.units)
try:
adapter = TypeAdapter(dtype)
validate_function = adapter.validate_python
except PydanticSchemaGenerationError as e:
validate_function = numpydantic.schema.get_validate_interface(Any, dtype)
target_type = get_basic_type(dtype)

def validator(v: Any, info: ValidationInfo) -> pint.Quantity:
if isinstance(v, pint.Quantity):
pass
elif isinstance(v, Sequence):
magnitude, unit = v
v = pint.Quantity(magnitude, unit)
v = pint.Quantity(value=magnitude, units=unit)
else:
raise ValueError(f"Don't know how to interpret type {type(v)}.")
# Check dimension
Expand All @@ -113,7 +125,7 @@ def validator(v: Any, info: ValidationInfo) -> pint.Quantity:
json_schema=json_schema,
python_schema=core_schema.with_info_plain_validator_function(validator),
serialization=core_schema.plain_serializer_function_ser_schema(
to_tuple
functools.partial(to_tuple, base_units=reference.units)
),
)

Expand Down Expand Up @@ -141,7 +153,7 @@ def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:
elif isinstance(v, Sequence):
magnitude, unit = v
# Turn into Quantity: magnitude * unit
v = pint.Quantity(magnitude=np.asarray(magnitude), unit=unit)
v = pint.Quantity(value=np.asarray(magnitude), units=unit)
else:
raise ValueError(f"Don't know how to interpret type {type(v)}.")
# Check dimension
Expand All @@ -165,7 +177,11 @@ def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:
core_schema.literal_schema([units])
])

serializer = functools.partial(to_array_tuple, array_serializer=magnitude_schema['serialization']['function'])
serializer = functools.partial(
to_array_tuple,
base_units=reference.units,
array_serializer=magnitude_schema['serialization']['function']
)
return core_schema.json_or_python_schema(
json_schema=json_schema,
python_schema=core_schema.with_info_plain_validator_function(validator),
Expand All @@ -177,9 +193,11 @@ def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:

return Single, Array

Length, LengthArray = _make_type(pint.Quantity(1, 'meter'))
Angle, AngleArray = _make_type(pint.Quantity(1, 'radian'))
Pixel, PixelArray = _make_type(pint.Quantity(1, 'pixel'))

Length, LengthArray = make_type(pint.Quantity(1, 'meter'))
Angle, AngleArray = make_type(pint.Quantity(1, 'radian'))
Pixel, PixelArray = make_type(pint.Quantity(1, 'pixel'))


class Simple4DSTEMParams(BaseModel):
'''
Expand Down
Loading

0 comments on commit bcf334a

Please sign in to comment.