Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added PydanticPintQuantity as an option to enforce unit validation for fields #56

Merged
merged 6 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "infrasys"
version = "0.2.2"
version = "0.2.3"
description = ''
readme = "README.md"
requires-python = ">=3.10, <3.13"
Expand Down
200 changes: 200 additions & 0 deletions src/infrasys/pint_quantities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""Defines the Pydantic `pint.Quantity`."""

from __future__ import annotations

from numbers import Number
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler

import pint
from pint.facets.plain.quantity import PlainQuantity as Quantity
from pydantic_core import core_schema


class PydanticPintQuantity:
"""Pydantic-compatible annotation for validating and serializing `pint.Quantity` fields.

This class allows Pydantic to handle fields that represent quantities with units,
leveraging the `pint` library for unit conversion and validation.

Parameters
----------
units : str
The base units of the Pydantic field. All input units must be convertible
to these base units.
ureg : pint.UnitRegistry, optional
A custom Pint unit registry. If not provided, the default registry is used.
ureg_contexts : str or list of str, optional
A custom Pint context (or a list of contexts) for the default unit registry.
All contexts are applied during validation and conversion.
ser_mode : {"str", "dict"}, optional
The mode for serializing the field. Can be one of:
- `"str"`: Serialize to a string representation of the quantity (default in JSON mode).
- `"dict"`: Serialize to a dictionary representation.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems convenient, but I'm not completely following. It is being set per instance. Would the user set it after construction?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea after construction. We can remove it if we do not need it. It mostly import for serialization. Can we replace some of our logic here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's discuss it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we raise this to an issue?

By default, fields are serialized in Pydantic's `"python"` mode, which preserves
the `pint.Quantity` type. In `"json"` mode, the field is serialized as a string.
strict : bool, optional
If `True` (default), forces users to specify units. If `False`, a value without
units (provided by the user) is treated as having the base units of the field.

Notes
-----
This class integrates with Pydantic's validation and serialization system to ensure
that fields representing physical quantities are handled correctly with respect to units.
"""

def __init__(
self,
units: str,
*,
ureg: pint.UnitRegistry | None = None,
ser_mode: Literal["str", "dict"] | None = None,
strict: bool = True,
):
self.ser_mode = ser_mode.lower() if ser_mode else None
self.strict = strict
self.ureg = ureg if ureg else pint.UnitRegistry()
self.units = self.ureg(units)

def validate(
self,
input_value: Any,
info: core_schema.ValidationInfo | None = None,
) -> Quantity:
"""Validate a `PydanticPintQuantity`.

Parameters
----------
input_value : Any
The quantity to validate. This can be a dictionary containing keys `"magnitude"`
and `"units"`, a string representing the quantity, or a `Number` or `Quantity`
object that can be validated and converted to a `pint.Quantity`.
info : core_schema.ValidationInfo, optional
Additional validation information provided by the Pydantic schema. Default is `None`.

Returns
-------
pint.Quantity
The validated `pint.Quantity` with the correct units.

Raises
------
ValueError
If validation fails due to one of the following reasons:
- The provided `dict` does not contain the required `"magnitude"` and `"units"` keys.
- No units are provided when strict mode is enabled.
- The provided units cannot be converted to the base units.
- An unknown unit is provided.
- An invalid type is provided for the value.
TypeError
If the type is not supported.
"""
# NOTE: `self.ureg` when passed returns the right type
if not isinstance(input_value, Quantity):
input_value = self.ureg(input_value) # This convert string to numbers

if isinstance(input_value, Number | list):
pesap marked this conversation as resolved.
Show resolved Hide resolved
input_value = input_value * self.units

# At this point `input_value` should be a `pint.Quantity`.
if not isinstance(input_value, Quantity):
msg = f"{type(input_value)} not supported"
raise TypeError(msg)
try:
input_value = input_value.to(self.units)
except pint.DimensionalityError:
msg = f"Dimension mismatch from {input_value.units} to {self.units}"
raise ValueError(msg)
return input_value

def serialize(
self,
value: Quantity,
info: core_schema.SerializationInfo | None = None,
) -> dict[str, Any] | str | Quantity:
"""
Serialize a `PydanticPintQuantity`.

Parameters
----------
value : pint.Quantity
The quantity to serialize. This should be a `pint.Quantity` object.
info : core_schema.SerializationInfo, optional
The serialization information provided by the Pydantic schema. Default is `None`.

Returns
-------
dict, str, or pint.Quantity
The serialized representation of the quantity.
- If `ser_mode='dict'` or `info.mode='dict'` a dictionary with magnitude and units.

Notes
-----
This method is useful when working with `PydanticPintQuantity` fields outside
of Pydantic models, as it allows control over the serialization format
(e.g., JSON-compatible representation).
"""
if info is not None:
mode = info.mode
else:
mode = self.ser_mode

if mode == "dict":
return {
"magnitude": value.magnitude,
"units": f"{value.units}",
}
elif mode == "str" or mode == "json":
return str(value)
else:
return value

def __get_pydantic_core_schema__(
self,
source_type: Any,
handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
_from_typedict_schema = {
"magnitude": core_schema.typed_dict_field(
core_schema.str_schema(coerce_numbers_to_str=True)
),
"units": core_schema.typed_dict_field(core_schema.str_schema()),
}

validate_schema = core_schema.chain_schema(
[
core_schema.union_schema(
[
core_schema.is_instance_schema(Quantity),
core_schema.str_schema(coerce_numbers_to_str=True),
core_schema.typed_dict_schema(_from_typedict_schema),
]
),
core_schema.with_info_plain_validator_function(self.validate),
]
)

validate_json_schema = core_schema.chain_schema(
[
core_schema.union_schema(
[
core_schema.str_schema(coerce_numbers_to_str=True),
core_schema.typed_dict_schema(_from_typedict_schema),
]
),
core_schema.no_info_plain_validator_function(self.validate),
]
)

serialize_schema = core_schema.plain_serializer_function_ser_schema(
self.serialize,
info_arg=True,
)

return core_schema.json_or_python_schema(
json_schema=validate_json_schema,
python_schema=validate_schema,
serialization=serialize_schema,
)
92 changes: 92 additions & 0 deletions tests/test_pint_quantities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest
from typing import Annotated

from pydantic import ValidationError, Field
from infrasys.base_quantity import ureg
from infrasys.component import Component
from infrasys.pint_quantities import PydanticPintQuantity
from infrasys.quantities import Voltage
from pint import Quantity


class PintQuantityStrict(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts")]


class PintQuantityNoStrict(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts", strict=False)]


class PintQuantityStrictDict(Component):
daniel-thom marked this conversation as resolved.
Show resolved Hide resolved
voltage: Annotated[Quantity, PydanticPintQuantity("volts", ser_mode="dict")]


class PintQuantityStrictDictPositive(Component):
voltage: Annotated[Quantity, PydanticPintQuantity("volts", ser_mode="dict"), Field(gt=0)]


@pytest.mark.parametrize(
"input_value",
[10.0 * ureg.volts, Quantity(10.0, "volt"), Voltage(10.0, "volts")],
ids=["float", "Quantity", "BaseQuantity"],
)
def test_pydantic_pint_multiple_input(input_value):
component = PintQuantityStrict(name="TestComponent", voltage=input_value)
assert isinstance(component.voltage, Quantity)
assert component.voltage.magnitude == 10.0
assert component.voltage.units == "volt"


def test_pydantic_pint_validation():
with pytest.raises(ValidationError):
_ = PintQuantityStrict(name="test", voltage=10.0 * ureg.meter)

# Pass wrong type
with pytest.raises(ValidationError):
_ = PintQuantityStrict(name="test", voltage={10: 2})


def test_compatibility_with_base_quantity():
voltage = Voltage(10.0, "volts")
component = PintQuantityStrict(name="TestComponent", voltage=voltage)
assert isinstance(component.voltage, Quantity)
assert isinstance(component.voltage, Voltage)
assert component.voltage.magnitude == 10.0
assert component.voltage.units == "volt"


def test_pydantic_pint_arguments():
# Single float should work
component = PintQuantityNoStrict(name="TestComponent", voltage=10.0)
assert isinstance(component.voltage, Quantity)
assert component.voltage.magnitude == 10.0
assert component.voltage.units == "volt"

with pytest.raises(ValidationError):
_ = PintQuantityStrictDictPositive(name="TestComponent", voltage=-10)


def test_serialization():
component = PintQuantityStrict(name="TestComponent", voltage=10.0 * ureg.volts)
component_serialized = component.model_dump()
assert isinstance(component_serialized["voltage"], Quantity)
assert component_serialized["voltage"].magnitude == 10.0
assert component_serialized["voltage"].units == "volt"

component_json = component.model_dump(mode="json")
assert component_json["voltage"] == "10.0 volt"

component_dict = component.model_dump(mode="dict")
assert isinstance(component_dict["voltage"], dict)
assert component_dict["voltage"].get("magnitude", False)
assert component_dict["voltage"].get("units", False)
assert component_dict["voltage"]["magnitude"] == 10.0
assert str(component_dict["voltage"]["units"]) == "volt"

component = PintQuantityStrict(name="TestComponent", voltage=10.0 * ureg.volts)
component_json = component.model_dump(mode="json")
assert isinstance(component_dict["voltage"], dict)
assert component_dict["voltage"].get("magnitude", False)
assert component_dict["voltage"].get("units", False)
assert component_dict["voltage"]["magnitude"] == 10.0
assert component_dict["voltage"]["units"] == "volt"
Loading