From bc124a718c779ec95023459d5121cf71d8c6382a Mon Sep 17 00:00:00 2001 From: Dieter Weber Date: Fri, 13 Sep 2024 19:03:29 +0200 Subject: [PATCH] Integrate with numpydantic Allow specifying dtype and array shapes The dtype is broken down to Python basic types. TODO test and probably implement support for all the nifty tools of pydantic et al, specifying value ranges etc --- prototypes/testbed.ipynb | 668 ++++++++++++++++++++++++++++++++ pyproject.toml | 4 +- src/libertem_schema/__init__.py | 232 +++++++---- test_requirements.txt | 3 + tests/test_schemas.py | 102 ++++- 5 files changed, 934 insertions(+), 75 deletions(-) create mode 100644 prototypes/testbed.ipynb create mode 100644 test_requirements.txt diff --git a/prototypes/testbed.ipynb b/prototypes/testbed.ipynb new file mode 100644 index 0000000..1481651 --- /dev/null +++ b/prototypes/testbed.ipynb @@ -0,0 +1,668 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 264, + "id": "27ec72cd-9a31-44b1-8573-d381ef192780", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, Sequence, get_args, get_origin, Callable\n", + "import functools\n", + "\n", + "from typing_extensions import Annotated, TypeVar\n", + "\n", + "import numpy as np\n", + "\n", + "import numpydantic\n", + "\n", + "import pydantic\n", + "from pydantic_core import core_schema\n", + "from pydantic import (\n", + " BaseModel,\n", + " GetCoreSchemaHandler,\n", + " WrapValidator,\n", + " ValidationInfo,\n", + " ValidatorFunctionWrapHandler,\n", + ")\n", + "\n", + "import pint" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2b14c7cd-c1ff-4274-bab3-3cb8a9dd31d3", + "metadata": {}, + "outputs": [], + "source": [ + "ureg = pint.UnitRegistry()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0bd4e497-4f3b-4fa2-aae7-cda17c08ccbe", + "metadata": {}, + "outputs": [], + "source": [ + "class DimensionError(ValueError):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "2cd7e50f-6f2c-4eb5-8ce1-462402539d56", + "metadata": {}, + "outputs": [], + "source": [ + "_pint_base_repr = core_schema.tuple_positional_schema(items_schema=[\n", + " core_schema.float_schema(),\n", + " core_schema.str_schema()\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "35801fc1-94b1-48b1-aa45-58353d59862c", + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "(, , , , )", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[94], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mnumpydantic\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmaps\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnp_to_python\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnumpydantic\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mFloat\u001b[49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: (, , , , )" + ] + } + ], + "source": [ + "numpydantic.maps.np_to_python[numpydantic.dtype.Float]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4fdb7df5-c483-495a-ace8-61da166aae16", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'type': 'json-or-python',\n", + " 'json_schema': {'type': 'list',\n", + " 'items_schema': {'type': 'list',\n", + " 'items_schema': {'type': 'float'},\n", + " 
'min_length': 2,\n", + " 'max_length': 2,\n", + " 'metadata': {'name': 'y'}},\n", + " 'min_length': 2,\n", + " 'max_length': 2,\n", + " 'metadata': {'name': 'x'}},\n", + " 'python_schema': {'type': 'function-plain',\n", + " 'function': {'type': 'with-info',\n", + " 'function': .validate_interface(value: Any, info: Optional[ForwardRef('ValidationInfo')] = None) -> numpydantic.types.NDArrayType>}},\n", + " 'serialization': {'type': 'function-plain',\n", + " 'function': Union[list, dict]>,\n", + " 'info_arg': True,\n", + " 'when_used': 'json'}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "numpydantic.NDArray.__get_pydantic_core_schema__(_source_type=numpydantic.NDArray[numpydantic.Shape['2 x, 2 y'], numpydantic.dtype.Float], _handler=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "id": "81a19c2f-89c6-43fe-8827-a4b5ee95c64a", + "metadata": {}, + "outputs": [ + { + "ename": "DtypeError", + "evalue": "Invalid dtype! expected (, , , , , , , , , ), got float64", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDtypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[109], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mnumpydantic\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mschema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_validate_interface\u001b[49m\u001b[43m(\u001b[49m\u001b[43mAny\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnumpydantic\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInt\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2.3\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/site-packages/numpydantic/schema.py:251\u001b[0m, in \u001b[0;36mget_validate_interface..validate_interface\u001b[0;34m(value, info)\u001b[0m\n\u001b[1;32m 249\u001b[0m interface_cls \u001b[38;5;241m=\u001b[39m Interface\u001b[38;5;241m.\u001b[39mmatch(value)\n\u001b[1;32m 250\u001b[0m interface \u001b[38;5;241m=\u001b[39m interface_cls(shape, dtype)\n\u001b[0;32m--> 251\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[43minterface\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n", + "File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/site-packages/numpydantic/interface/interface.py:81\u001b[0m, in \u001b[0;36mInterface.validate\u001b[0;34m(self, array)\u001b[0m\n\u001b[1;32m 79\u001b[0m dtype \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_dtype(array)\n\u001b[1;32m 80\u001b[0m dtype_valid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalidate_dtype(dtype)\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraise_for_dtype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdtype_valid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 82\u001b[0m array \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mafter_validate_dtype(array)\n\u001b[1;32m 84\u001b[0m shape \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_shape(array)\n", + "File \u001b[0;32m~/miniconda3/envs/ltschema312/lib/python3.12/site-packages/numpydantic/interface/interface.py:150\u001b[0m, in \u001b[0;36mInterface.raise_for_dtype\u001b[0;34m(self, valid, dtype)\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;124;03mAfter validating, raise an exception if invalid\u001b[39;00m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;124;03mRaises:\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;124;03m :class:`~numpydantic.exceptions.DtypeError`\u001b[39;00m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m valid:\n\u001b[0;32m--> 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DtypeError(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid dtype! expected \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mDtypeError\u001b[0m: Invalid dtype! expected (, , , , , , , , , ), got float64" + ] + } + ], + "source": [ + "numpydantic.schema.get_validate_interface(Any, numpydantic.dtype.Int)(2.3)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "a94ab906-419a-4860-bbf9-86deed62876c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "typing.Union[str, type, typing.Any, numpy.generic]" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 322, + "id": "9ffb235e-bf5e-43cd-8da6-5f9a8d802d01", + "metadata": {}, + "outputs": [], + "source": [ + "def to_tuple(q: pint.Quantity):\n", + " base = q.to_base_units()\n", + " return (base.magnitude, str(base.units))\n", + "\n", + "def to_array_tuple(q: pint.Quantity, info, array_serializer: Callable):\n", + " base = q.to_base_units()\n", + " return (array_serializer(base.magnitude, info=info), str(base.units))" + ] + }, + { + "cell_type": "code", + "execution_count": 323, + "id": "9e81f832-ab66-4d41-8020-25e40deb0b58", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dummy 23 None\n" + ] + }, + { + "data": { + "text/plain": [ + "(None, 'meter')" + ] + }, + "execution_count": 323, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def dummy_serializer(value, info):\n", + " print(\"dummy\", value, info)\n", + "\n", + "functools.partial(to_array_tuple, array_serializer=dummy_serializer)(pint.Quantity(23, 'm'), info=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 277, + "id": "7c724ba0-b273-4a99-bdd5-09eb5237ca87", + "metadata": {}, + "outputs": [], + "source": [ + "def get_basic_type(t):\n", + " if isinstance(t, str):\n", + " t = np.dtype(t)\n", + " if isinstance(t, Sequence):\n", + " # numpydantic.dtype.Float is a sequence, for example\n", + " # They all map to the same basic Python type\n", + " t = t[0]\n", + " if t in (float, int, complex):\n", + " return t\n", + " dtype = np.dtype(t)\n", + " return numpydantic.maps.np_to_python[dtype.type]" + ] + }, + { + "cell_type": "code", + "execution_count": 278, + "id": "8680a8bb-da5f-452b-91cf-afdb7b991971", + "metadata": 
{}, + "outputs": [], + "source": [ + "def get_schema(t):\n", + " basic_type = get_basic_type(t)\n", + " if basic_type is float:\n", + " return core_schema.float_schema()\n", + " elif basic_type is int:\n", + " return core_schema.int_schema()\n", + " elif basic_type is complex:\n", + " return core_schema.complex_schema()\n", + " else:\n", + " raise NotImplementedError(t)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 279, + "id": "3a55b6f8-49f3-4865-80ce-16eedcee9d3c", + "metadata": {}, + "outputs": [], + "source": [ + "class PintAnnotation:\n", + " @classmethod\n", + " def __get_pydantic_core_schema__(\n", + " cls,\n", + " _source_type: Any,\n", + " _handler: GetCoreSchemaHandler,\n", + " ) -> core_schema.CoreSchema:\n", + " \n", + " return core_schema.json_or_python_schema(\n", + " json_schema=_pint_base_repr,\n", + " python_schema=core_schema.is_instance_schema(pint.Quantity),\n", + " serialization=core_schema.plain_serializer_function_ser_schema(\n", + " to_tuple\n", + " ),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 280, + "id": "b6c8496e-7c2c-4526-a0aa-8ebf562f367c", + "metadata": {}, + "outputs": [], + "source": [ + "class PintArrayAnnotation:\n", + " @classmethod\n", + " def __get_pydantic_core_schema__(\n", + " cls,\n", + " _source_type: Any,\n", + " _handler: GetCoreSchemaHandler,\n", + " ) -> core_schema.CoreSchema:\n", + " \n", + " return core_schema.json_or_python_schema(\n", + " json_schema=_pint_base_repr,\n", + " python_schema=core_schema.is_instance_schema(pint.Quantity),\n", + " serialization=core_schema.plain_serializer_function_ser_schema(\n", + " to_tuple\n", + " ),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 281, + "id": "f8ddbb28-bfe4-485b-846a-5f776dba9a4a", + "metadata": {}, + "outputs": [], + "source": [ + "_length_dim = ureg.meter.dimensionality" + ] + }, + { + "cell_type": "code", + "execution_count": 282, + "id": "2e06057a-62dc-4653-92df-85b025b1296e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'meter'" + ] + }, + "execution_count": 282, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "str((1 * ureg.meter).to_base_units().units)" + ] + }, + { + "cell_type": "code", + "execution_count": 283, + "id": "37dbf128-bb00-4e22-9ae3-95513fd517c5", + "metadata": {}, + "outputs": [], + "source": [ + "t = numpydantic.NDArray[numpydantic.Shape['2 x, 2 y'], float]" + ] + }, + { + "cell_type": "code", + "execution_count": 284, + "id": "5b12d306-152e-48ba-a9a9-14c79acba4f0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'type': 'json-or-python',\n", + " 'json_schema': {'type': 'list',\n", + " 'items_schema': {'type': 'list',\n", + " 'items_schema': {'type': 'float'},\n", + " 'min_length': 2,\n", + " 'max_length': 2,\n", + " 'metadata': {'name': 'y'}},\n", + " 'min_length': 2,\n", + " 'max_length': 2,\n", + " 'metadata': {'name': 'x'}},\n", + " 'python_schema': {'type': 'function-plain',\n", + " 'function': {'type': 'with-info',\n", + " 'function': .validate_interface(value: Any, info: Optional[ForwardRef('ValidationInfo')] = None) -> numpydantic.types.NDArrayType>}},\n", + " 'serialization': {'type': 'function-plain',\n", + " 'function': Union[list, dict]>,\n", + " 'info_arg': True,\n", + " 'when_used': 'json'}}" + ] + }, + "execution_count": 284, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.__get_pydantic_core_schema__(t, None)" + ] + }, + { + "cell_type": "code", + "execution_count": 326, + "id": 
"d131d835-8001-40fa-a3d8-efe5a9c2baad", + "metadata": {}, + "outputs": [], + "source": [ + "class Length[DType](pint.Quantity):\n", + " @classmethod\n", + " def __get_pydantic_core_schema__(\n", + " cls,\n", + " _source_type: Any,\n", + " _handler: GetCoreSchemaHandler,\n", + " ) -> core_schema.CoreSchema:\n", + " (dtype, ) = get_args(_source_type)\n", + " magnitude_schema = get_schema(dtype)\n", + " reference = 1 * ureg.meter\n", + " units = str(reference.to_base_units().units)\n", + " validate_function = numpydantic.schema.get_validate_interface(Any, dtype)\n", + " target_type = get_basic_type(dtype)\n", + " \n", + " def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:\n", + " if isinstance(v, pint.Quantity):\n", + " pass\n", + " elif isinstance(v, Sequence):\n", + " magnitude, unit = v\n", + " # Turn into Quantity: magnitude * unit\n", + " v = magnitude * ureg(unit)\n", + " else:\n", + " raise ValueError(f\"Don't know how to interpret type {type(v)}.\")\n", + " # Check dimension\n", + " if not v.check(reference.dimensionality):\n", + " raise DimensionError(f\"Expected dimensionality {reference.dimensionality}, got quantity {v}.\")\n", + " try:\n", + " # First, try as-is\n", + " validate_function(v.magnitude)\n", + " except Exception:\n", + " # See if we can go from int to float, for example\n", + " if np.can_cast(type(v.magnitude), target_type):\n", + " v = target_type(v.magnitude) * v.units\n", + " validate_function(v.magnitude)\n", + " else:\n", + " raise\n", + " # Return target type\n", + " return v\n", + " \n", + " json_schema = core_schema.tuple_positional_schema(items_schema=[\n", + " magnitude_schema,\n", + " core_schema.literal_schema([units])\n", + " ])\n", + " return core_schema.json_or_python_schema(\n", + " json_schema=json_schema,\n", + " python_schema=core_schema.with_info_plain_validator_function(validator),\n", + " serialization=core_schema.plain_serializer_function_ser_schema(\n", + " to_tuple\n", + " ),\n", + " )\n", + "\n", + "class LengthArray[Shape, DType](pint.Quantity):\n", + " @classmethod\n", + " def __get_pydantic_core_schema__(\n", + " cls,\n", + " _source_type: Any,\n", + " _handler: GetCoreSchemaHandler,\n", + " ) -> core_schema.CoreSchema:\n", + " shape, dtype = get_args(_source_type)\n", + " numpydantic_type = numpydantic.NDArray[shape, dtype]\n", + " reference = 1 * ureg.meter\n", + " units = str(reference.to_base_units().units)\n", + " validate_function = numpydantic.schema.get_validate_interface(shape, dtype)\n", + " target_type = get_basic_type(dtype)\n", + "\n", + " magnitude_schema = numpydantic_type.__get_pydantic_core_schema__(\n", + " _source_type=numpydantic_type,\n", + " _handler=_handler\n", + " )\n", + " \n", + " def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:\n", + " if isinstance(v, pint.Quantity):\n", + " pass\n", + " elif isinstance(v, Sequence):\n", + " magnitude, unit = v\n", + " # Turn into Quantity: magnitude * unit\n", + " v = np.asarray(magnitude) * ureg(unit)\n", + " else:\n", + " raise ValueError(f\"Don't know how to interpret type {type(v)}.\")\n", + " # Check dimension\n", + " if not v.check(reference.dimensionality):\n", + " raise DimensionError(f\"Expected dimensionality {reference.dimensionality}, got quantity {v}.\")\n", + " try:\n", + " # First, try as-is\n", + " validate_function(v.magnitude)\n", + " except Exception:\n", + " # See if we can go from int to float, for example\n", + " if np.can_cast(v.magnitude, target_type):\n", + " v = v.magnitude.astype(target_type) * 
v.units\n", + " validate_function(v.magnitude)\n", + " else:\n", + " raise\n", + " # Return target type\n", + " return v\n", + " \n", + " json_schema = core_schema.tuple_positional_schema(items_schema=[\n", + " magnitude_schema['json_schema'],\n", + " core_schema.literal_schema([units])\n", + " ])\n", + "\n", + " serializer = functools.partial(to_array_tuple, array_serializer=magnitude_schema['serialization']['function'])\n", + " return core_schema.json_or_python_schema(\n", + " json_schema=json_schema,\n", + " python_schema=core_schema.with_info_plain_validator_function(validator),\n", + " serialization=core_schema.plain_serializer_function_ser_schema(\n", + " function=serializer,\n", + " info_arg=True,\n", + " ),\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 327, + "id": "316c607e-cc06-437f-8bc3-4c8b326516c0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dummy 23 2\n" + ] + }, + { + "data": { + "text/plain": [ + "(None, 'meter')" + ] + }, + "execution_count": 327, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "functools.partial(to_array_tuple, array_serializer=dummy_serializer)(pint.Quantity(23, 'm'), 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 335, + "id": "10a4bdd1-1d2a-46be-a6d9-15de74735cd6", + "metadata": {}, + "outputs": [], + "source": [ + "class TestSchema(BaseModel):\n", + " l: LengthArray[numpydantic.Shape['* y, 2 x'], numpydantic.dtype.Float]\n", + " m: Length[complex]" + ] + }, + { + "cell_type": "code", + "execution_count": 338, + "id": "4803b3c8-5340-46d7-8cfd-45290780651b", + "metadata": {}, + "outputs": [], + "source": [ + "t = TestSchema(m=pint.Quantity(23.3, 'cm'), l=pint.Quantity(np.array([(1, 2), (3, 4), (5, 6)]), 'km'))" + ] + }, + { + "cell_type": "code", + "execution_count": 339, + "id": "40ce4c1f-4ce6-4db4-984f-889850b404ed", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'l': ([[1000.0, 2000.0], [3000.0, 4000.0], [5000.0, 6000.0]], 'meter'),\n", + " 'm': ((0.233+0j), 'meter')}" + ] + }, + "execution_count": 339, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.model_dump()" + ] + }, + { + "cell_type": "code", + "execution_count": 340, + "id": "2d1455af-17b6-4392-b411-8fc5264a1ec8", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'NDArray' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[340], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mNumpySchema\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43mBaseModel\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43ma\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mNDArray\u001b[49m\u001b[43m[\u001b[49m\u001b[43mShape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m2 x, 2 y\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mFloat\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\n", + "Cell \u001b[0;32mIn[340], line 2\u001b[0m, in 
\u001b[0;36mNumpySchema\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mNumpySchema\u001b[39;00m(BaseModel):\n\u001b[0;32m----> 2\u001b[0m a: \u001b[43mNDArray\u001b[49m[Shape[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m2 x, 2 y\u001b[39m\u001b[38;5;124m'\u001b[39m], Float]\n\u001b[1;32m 3\u001b[0m b: \u001b[38;5;28mfloat\u001b[39m\n", + "\u001b[0;31mNameError\u001b[0m: name 'NDArray' is not defined" + ] + } + ], + "source": [ + "class NumpySchema(BaseModel):\n", + " a: NDArray[Shape['2 x, 2 y'], Float]\n", + " b: float" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "c0450cc6-dd02-4ca9-bedf-5190c3b0d5c0", + "metadata": {}, + "outputs": [], + "source": [ + "n = NumpySchema(a=np.zeros((2, 2)), b=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "9c9f20b8-43fd-41d5-8a9a-1d8e4c907c26", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'properties': {'a': {'items': {'items': {'type': 'number'},\n", + " 'maxItems': 2,\n", + " 'minItems': 2,\n", + " 'type': 'array'},\n", + " 'maxItems': 2,\n", + " 'minItems': 2,\n", + " 'title': 'A',\n", + " 'type': 'array'},\n", + " 'b': {'title': 'B', 'type': 'number'}},\n", + " 'required': ['a', 'b'],\n", + " 'title': 'NumpySchema',\n", + " 'type': 'object'}" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n.model_json_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03ff7081-83d5-4cc8-b3c6-90a1d5faf4fd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 3663c0d..88a6ab6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,9 @@ dynamic = ["version", "readme"] dependencies = [ "typing-extensions", "pint", - "pydantic>1,<3" + "pydantic>1,<3", + "numpy", + "numpydantic", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/src/libertem_schema/__init__.py b/src/libertem_schema/__init__.py index 51eef27..f7adeb8 100644 --- a/src/libertem_schema/__init__.py +++ b/src/libertem_schema/__init__.py @@ -1,13 +1,16 @@ -from typing import Any, Sequence +from typing import Any, Sequence, Callable +import functools + +from typing_extensions import TypeVar + +import numpydantic +import numpy as np -from typing_extensions import Annotated from pydantic_core import core_schema from pydantic import ( BaseModel, GetCoreSchemaHandler, - WrapValidator, ValidationInfo, - ValidatorFunctionWrapHandler, ) import pint @@ -22,70 +25,159 @@ class DimensionError(ValueError): pass -_pint_base_repr = core_schema.tuple_positional_schema(items_schema=[ - core_schema.float_schema(), - core_schema.str_schema() -]) - def to_tuple(q: pint.Quantity): base = q.to_base_units() return (float(base.magnitude), str(base.units)) -class PintAnnotation: - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - return core_schema.json_or_python_schema( - json_schema=_pint_base_repr, - 
python_schema=core_schema.is_instance_schema(pint.Quantity), - serialization=core_schema.plain_serializer_function_ser_schema( - to_tuple - ), - ) - - -_length_dim = ureg.meter.dimensionality -_angle_dim = ureg.radian.dimensionality -_pixel_dim = ureg.pixel.dimensionality - - -def _make_handler(dimensionality: str): - def is_matching( - q: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo - ) -> pint.Quantity: - # Ensure target type - if isinstance(q, pint.Quantity): - pass - elif isinstance(q, Sequence): - magnitude, unit = q - # Turn into Quantity: measure * unit - q = magnitude * ureg(unit) - else: - raise ValueError(f"Don't know how to interpret type {type(q)}.") - # Check dimension - if not q.check(dimensionality): - raise DimensionError(f"Expected dimensionality {dimensionality}, got quantity {q}.") - # Return target type - return q - - return is_matching - - -Length = Annotated[ - pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_length_dim)) -] -Angle = Annotated[ - pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_angle_dim)) -] -Pixel = Annotated[ - pint.Quantity, PintAnnotation, WrapValidator(_make_handler(_pixel_dim)) -] - +def to_array_tuple(q: pint.Quantity, info: ValidationInfo, array_serializer: Callable): + base = q.to_base_units() + return (array_serializer(base.magnitude, info=info), str(base.units)) + + +def get_basic_type(t): + if isinstance(t, str): + t = np.dtype(t) + if isinstance(t, Sequence): + # numpydantic.dtype.Float is a sequence, for example + # They all map to the same basic Python type + t = t[0] + if t in (float, int, complex): + return t + dtype = np.dtype(t) + return numpydantic.maps.np_to_python[dtype.type] + + +def get_schema(t): + basic_type = get_basic_type(t) + if basic_type is float: + return core_schema.float_schema() + elif basic_type is int: + return core_schema.int_schema() + elif basic_type is complex: + return core_schema.complex_schema() + else: + raise NotImplementedError(t) + + +def _make_type(reference: pint.Quantity): + + DType = TypeVar('DType') + Shape = TypeVar('Shape') + + class Single[DType](pint.Quantity): + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + (dtype, ) = _source_type.__args__ + magnitude_schema = get_schema(dtype) + units = str(reference.to_base_units().units) + validate_function = numpydantic.schema.get_validate_interface(Any, dtype) + target_type = get_basic_type(dtype) + + def validator(v: Any, info: ValidationInfo) -> pint.Quantity: + if isinstance(v, pint.Quantity): + pass + elif isinstance(v, Sequence): + magnitude, unit = v + v = pint.Quantity(magnitude, unit) + else: + raise ValueError(f"Don't know how to interpret type {type(v)}.") + # Check dimension + if not v.check(reference.dimensionality): + raise DimensionError(f"Expected dimensionality {reference.dimensionality}, got quantity {v}.") + try: + # First, try as-is + validate_function(v.magnitude) + except Exception: + # See if we can go from int to float, for example + if np.can_cast(type(v.magnitude), target_type): + v = target_type(v.magnitude) * v.units + validate_function(v.magnitude) + else: + raise + # Return target type + return v + + json_schema = core_schema.tuple_positional_schema(items_schema=[ + magnitude_schema, + core_schema.literal_schema([units]) + ]) + return core_schema.json_or_python_schema( + json_schema=json_schema, + python_schema=core_schema.with_info_plain_validator_function(validator), + 
serialization=core_schema.plain_serializer_function_ser_schema( + to_tuple + ), + ) + + class Array[Shape, DType](pint.Quantity): + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + shape, dtype = _source_type.__args__ + numpydantic_type = numpydantic.NDArray[shape, dtype] + units = str(reference.to_base_units().units) + validate_function = numpydantic.schema.get_validate_interface(shape, dtype) + target_type = get_basic_type(dtype) + + magnitude_schema = numpydantic_type.__get_pydantic_core_schema__( + _source_type=numpydantic_type, + _handler=_handler + ) + + def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity: + if isinstance(v, pint.Quantity): + pass + elif isinstance(v, Sequence): + magnitude, unit = v + # Turn into Quantity: magnitude * unit + v = pint.Quantity(magnitude=np.asarray(magnitude), unit=unit) + else: + raise ValueError(f"Don't know how to interpret type {type(v)}.") + # Check dimension + if not v.check(reference.dimensionality): + raise DimensionError(f"Expected dimensionality {reference.dimensionality}, got quantity {v}.") + try: + # First, try as-is + validate_function(v.magnitude) + except Exception: + # See if we can go from int to float, for example + if np.can_cast(v.magnitude, target_type): + v = v.magnitude.astype(target_type) * v.units + validate_function(v.magnitude) + else: + raise + # Return target type + return v + + json_schema = core_schema.tuple_positional_schema(items_schema=[ + magnitude_schema['json_schema'], + core_schema.literal_schema([units]) + ]) + + serializer = functools.partial(to_array_tuple, array_serializer=magnitude_schema['serialization']['function']) + return core_schema.json_or_python_schema( + json_schema=json_schema, + python_schema=core_schema.with_info_plain_validator_function(validator), + serialization=core_schema.plain_serializer_function_ser_schema( + function=serializer, + info_arg=True, + ), + ) + + return Single, Array + +Length, LengthArray = _make_type(pint.Quantity(1, 'meter')) +Angle, AngleArray = _make_type(pint.Quantity(1, 'radian')) +Pixel, PixelArray = _make_type(pint.Quantity(1, 'pixel')) class Simple4DSTEMParams(BaseModel): ''' @@ -96,12 +188,12 @@ class Simple4DSTEMParams(BaseModel): and https://arxiv.org/abs/2403.08538 for the technical details. 
''' - overfocus: Length - scan_pixel_pitch: Length - camera_length: Length - detector_pixel_pitch: Length - semiconv: Angle - cy: Pixel - cx: Pixel - scan_rotation: Angle + overfocus: Length[float] + scan_pixel_pitch: Length[float] + camera_length: Length[float] + detector_pixel_pitch: Length[float] + semiconv: Angle[float] + cy: Pixel[float] + cx: Pixel[float] + scan_rotation: Angle[float] flip_y: bool diff --git a/test_requirements.txt b/test_requirements.txt new file mode 100644 index 0000000..9520650 --- /dev/null +++ b/test_requirements.txt @@ -0,0 +1,3 @@ +pytest +pytest-cov +jsonschema \ No newline at end of file diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 8bc56da..bfa59fd 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -4,12 +4,15 @@ import jsonschema.exceptions import pytest import jsonschema +import numpy as np from pint import UnitRegistry, Quantity from pydantic_core import from_json -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel -from libertem_schema import Simple4DSTEMParams +from numpydantic import Shape + +from libertem_schema import Simple4DSTEMParams, Length, LengthArray ureg = UnitRegistry() @@ -93,6 +96,99 @@ def test_carrots(): ) +def test_nocast(): + class T1(BaseModel): + t: Length[int] + + with pytest.raises(ValidationError): + t1 = T1(t=Quantity(0.3, 'm')) + + class T2(BaseModel): + t: LengthArray[Shape['2 x, 2 y'], int] + + with pytest.raises(ValidationError): + t2 = T2(t=Quantity( + np.array([(1, 2), (3, 4)]).astype(float), + 'm' + )) + + +def test_json_nocast(): + class T1(BaseModel): + t: Length[int] + + params = T1(t=Quantity(1, 'm')) + + json_schema = params.model_json_schema() + pprint.pprint(json_schema) + as_json = params.model_dump_json() + pprint.pprint(as_json) + loaded = json.loads(as_json) + loaded['t'][0] = 0.3 + + with pytest.raises(jsonschema.exceptions.ValidationError): + jsonschema.validate( + instance=loaded, + schema=json_schema + ) + + class T2(BaseModel): + t: LengthArray[Shape['2 x, 2 y'], int] + + params = T2(t=Quantity( + np.array([(1, 2), (3, 4)]), + 'm' + )) + + json_schema = params.model_json_schema() + pprint.pprint(json_schema) + as_json = params.model_dump_json() + pprint.pprint(as_json) + loaded = json.loads(as_json) + loaded['t'][0] = [[0.3, 0.4], [0.5, 0.6]] + + with pytest.raises(jsonschema.exceptions.ValidationError): + jsonschema.validate( + instance=loaded, + schema=json_schema + ) + + +def test_shape(): + class T(BaseModel): + t: LengthArray[Shape['2 x, 2 y'], complex] + + with pytest.raises(ValidationError): + t = T(t=Quantity( + # Shape mismatch + np.array([(1, 2), (3, 4), (5, 6)]).astype(float), + 'm' + )) + + +def test_json_shape(): + class T2(BaseModel): + t: LengthArray[Shape['2 x, 2 y'], int] + + params = T2(t=Quantity( + np.array([(1, 2), (3, 4)]), + 'm' + )) + + json_schema = params.model_json_schema() + pprint.pprint(json_schema) + as_json = params.model_dump_json() + pprint.pprint(as_json) + loaded = json.loads(as_json) + loaded['t'][0] = [[1, 2], [3, 4], [5, 6]] + + with pytest.raises(jsonschema.exceptions.ValidationError): + jsonschema.validate( + instance=loaded, + schema=json_schema + ) + + def test_dimensionality(): with pytest.raises(ValidationError): Simple4DSTEMParams( @@ -249,8 +345,6 @@ def test_json_schema_missing(): ) -# No dimensionality check in JSON Schema yet -@pytest.mark.xfail def test_json_schema_dim(): params = Simple4DSTEMParams( overfocus=0.0015 * ureg.meter,
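
For completeness, a usage sketch of the parametrized Simple4DSTEMParams model as changed by this patch, assuming this branch is installed. Apart from overfocus, which appears in the test suite, the values below are made up for illustration; note that the PEP 695 class syntax used for the Single/Array types requires Python 3.12 or newer.

import pint

from libertem_schema import Simple4DSTEMParams

ureg = pint.UnitRegistry()

params = Simple4DSTEMParams(
    # Length[float] fields: any length unit is accepted and converted to meter
    overfocus=0.0015 * ureg.meter,
    scan_pixel_pitch=1e-6 * ureg.meter,
    camera_length=0.15 * ureg.meter,
    detector_pixel_pitch=5e-5 * ureg.meter,
    # Angle[float] and Pixel[float] fields are checked against radian and pixel
    semiconv=0.020 * ureg.radian,
    cy=127.5 * ureg.pixel,
    cx=127.5 * ureg.pixel,
    scan_rotation=33.3 * ureg.degree,
    flip_y=False,
)
# Each quantity is dumped as a (magnitude, unit) pair reduced to base units
print(params.model_dump())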
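
The validators first try the magnitude as-is and only fall back to a cast when numpy.can_cast allows it, so widening (int to float) is accepted while lossy narrowing (float to int) is rejected, as exercised by test_nocast. A small sketch of that behaviour; the model names Widening and Narrowing are made up for illustration.

import numpy as np
from pint import Quantity
from pydantic import BaseModel, ValidationError
from numpydantic import Shape

from libertem_schema import Length, LengthArray

class Widening(BaseModel):
    grid: LengthArray[Shape['2 x, 2 y'], float]

# The int magnitudes can be cast to float, so validation passes
Widening(grid=Quantity(np.array([(1, 2), (3, 4)]), 'm'))

class Narrowing(BaseModel):
    t: Length[int]

# A float magnitude cannot be cast to int, so validation fails
try:
    Narrowing(t=Quantity(0.3, 'm'))
except ValidationError as exc:
    print(exc)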
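
Because the generated JSON Schema nests the numpydantic array schema inside a (magnitude, unit) tuple, a dumped model can be checked against its own schema with jsonschema, mirroring the round trips in tests/test_schemas.py. A sketch with an illustrative Grid model:

import json

import jsonschema
import numpy as np
from pint import Quantity
from pydantic import BaseModel
from numpydantic import Shape

from libertem_schema import LengthArray

class Grid(BaseModel):
    t: LengthArray[Shape['2 x, 2 y'], float]

params = Grid(t=Quantity(np.array([(1.0, 2.0), (3.0, 4.0)]), 'cm'))

# The payload holds the magnitudes converted to meter plus the unit string
payload = json.loads(params.model_dump_json())
jsonschema.validate(instance=payload, schema=params.model_json_schema())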
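
A note on the structure added in src/libertem_schema/__init__.py: _make_type is a factory that returns a (scalar, array) pair of quantity types for one reference quantity, and the module calls it three times for meter, radian and pixel. A sketch of how a further dimension could be derived the same way; Time and TimeArray are hypothetical and not part of this patch.

from pint import Quantity

from libertem_schema import _make_type

# Hypothetical pair for the time dimension, built exactly like
# Length/LengthArray, Angle/AngleArray and Pixel/PixelArray
Time, TimeArray = _make_type(Quantity(1, 'second'))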