Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update PyDantic to V2 #79

Merged
merged 14 commits into from
Dec 17, 2024
15 changes: 0 additions & 15 deletions .flake8

This file was deleted.

8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install coverage flake8
pip install hatch
pip install .
- name: Lint with flake8
- name: Lint with ruff
run: |
flake8 . --count --exit-zero --show-source --statistics
hatch run dev:check
- name: Test with unittest
run: |
coverage run -m unittest
hatch run dev:cov
- name: Upload Coverage to Codecov
if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9
uses: codecov/codecov-action@v2
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/

# Translations
*.mo
Expand Down Expand Up @@ -130,6 +131,7 @@ dmypy.json

# IDEs
.vscode
.idea

# MacOS
.DS_Store
File renamed without changes.
28 changes: 18 additions & 10 deletions bpx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
"""BPX schema and parsers"""
# flake8: noqa F401
from .expression_parser import ExpressionParser
from .function import Function
from .interpolated_table import InterpolatedTable
from .parsers import parse_bpx_file, parse_bpx_obj, parse_bpx_str
from .schema import BPX, check_sto_limits
from .utilities import get_electrode_concentrations, get_electrode_stoichiometries

__version__ = "0.4.0"


from .interpolated_table import InterpolatedTable
from .expression_parser import ExpressionParser
from .function import Function
from .validators import check_sto_limits
from .schema import BPX
from .parsers import parse_bpx_str, parse_bpx_obj, parse_bpx_file
from .utilities import get_electrode_stoichiometries, get_electrode_concentrations
__all__ = [
"BPX",
"ExpressionParser",
"Function",
"InterpolatedTable",
"check_sto_limits",
"get_electrode_concentrations",
"get_electrode_stoichiometries",
"parse_bpx_file",
"parse_bpx_obj",
"parse_bpx_str",
]
23 changes: 23 additions & 0 deletions bpx/base_extra_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import ClassVar

from pydantic import BaseModel, ConfigDict


class ExtraBaseModel(BaseModel):
"""
A base model that forbids extra fields
"""

model_config = ConfigDict(extra="forbid")

class Settings:
"""
Class with BPX-related settings.
It might be worth moving it to a separate file if it grows bigger.
"""

tolerances: ClassVar[dict] = {
"Voltage [V]": 1e-3, # Absolute tolerance in [V] to validate the voltage limits
}
21 changes: 7 additions & 14 deletions bpx/expression_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ class ExpressionParser:

ParseException = pp.ParseException

def __init__(self):
def __init__(self) -> None:
fnumber = ppc.number()
ident = pp.Literal("x")
fn_ident = pp.Literal("x")

fn_ident = pp.Word(pp.alphas, pp.alphanums)
plus, minus, mult, div = map(pp.Literal, "+-*/")
Expand All @@ -31,21 +30,15 @@ def __init__(self):

expr_list = pp.delimitedList(pp.Group(expr))

def insert_fn_argcount_tuple(t):
def insert_fn_argcount_tuple(t: tuple) -> None:
fn = t.pop(0)
num_args = len(t[0])
t.insert(0, (fn, num_args))

fn_call = (fn_ident + lpar - pp.Group(expr_list) + rpar).setParseAction(
insert_fn_argcount_tuple
)
fn_call = (fn_ident + lpar - pp.Group(expr_list) + rpar).setParseAction(insert_fn_argcount_tuple)

atom = (
addop[...]
+ (
(fn_call | fnumber | ident).set_parse_action(self.push_first)
| pp.Group(lpar + expr + rpar)
)
addop[...] + ((fn_call | fnumber | ident).set_parse_action(self.push_first) | pp.Group(lpar + expr + rpar))
).set_parse_action(self.push_unary_minus)

# by defining exponentiation as "atom [ ^ factor ]..." instead of "atom
Expand All @@ -59,16 +52,16 @@ def insert_fn_argcount_tuple(t):
self.expr_stack = []
self.parser = expr

def push_first(self, toks):
def push_first(self, toks: tuple) -> None:
self.expr_stack.append(toks[0])

def push_unary_minus(self, toks):
def push_unary_minus(self, toks: tuple) -> None:
for t in toks:
if t == "-":
self.expr_stack.append("unary -")
else:
break

def parse_string(self, model_str, parse_all=True):
def parse_string(self, model_str: str, *, parse_all: bool = True) -> None:
self.expr_stack = []
self.parser.parseString(model_str, parseAll=parse_all)
66 changes: 49 additions & 17 deletions bpx/function.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

import copy
from importlib import util
import tempfile
from typing import Callable
from importlib import util
from pathlib import Path
from typing import TYPE_CHECKING, Any

from pydantic_core import CoreSchema, core_schema

from bpx import ExpressionParser

if TYPE_CHECKING:
from collections.abc import Callable

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler


class Function(str):
"""
Expand All @@ -16,31 +25,47 @@ class Function(str):
- single variable 'x'
"""

__slots__ = ()

parser = ExpressionParser()
default_preamble = "from math import exp, tanh, cosh"

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(examples=["1 + x", "1.9793 * exp(-39.3631 * x)" "2 * x**2"])
def __get_pydantic_json_schema__(
cls,
core_schema: CoreSchema,
handler: GetJsonSchemaHandler,
) -> dict[str, Any]:
json_schema = handler(core_schema)
json_schema["examples"] = ["1 + x", "1.9793 * exp(-39.3631 * x)" "2 * x**2"]
return handler.resolve_ref_schema(json_schema)

@classmethod
def validate(cls, v: str) -> Function:
if not isinstance(v, str):
raise TypeError("string required")
error_msg = "string required"
raise TypeError(error_msg)
try:
cls.parser.parse_string(v)
except ExpressionParser.ParseException as e:
raise ValueError(str(e))
raise ValueError(str(e)) from e
return cls(v)

def __repr__(self):
@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: str,
handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls.validate,
handler(str),
)

def __repr__(self) -> str:
return f"Function({super().__repr__()})"

def to_python_function(self, preamble: str = None) -> Callable:
def to_python_function(self, preamble: str | None = None) -> Callable:
"""
Return a python function that can be called with a single argument 'x'

Expand All @@ -61,9 +86,7 @@ def to_python_function(self, preamble: str = None) -> Callable:
function_body = f" return {self}"
source_code = preamble + function_def + function_body

with tempfile.NamedTemporaryFile(
suffix="{}.py".format(function_name), delete=False
) as tmp:
with tempfile.NamedTemporaryFile(suffix=f"{function_name}.py", delete=False) as tmp:
# write to a tempory file so we can
# get the source later on using inspect.getsource
# (as long as the file still exists)
Expand All @@ -75,6 +98,15 @@ def to_python_function(self, preamble: str = None) -> Callable:
module = util.module_from_spec(spec)
spec.loader.exec_module(module)

# Delete
tmp.close()
Path(tmp.name).unlink(missing_ok=True)
if module.__cached__:
cached_file = Path(module.__cached__)
cached_path = cached_file.parent
cached_file.unlink(missing_ok=True)
if not any(cached_path.iterdir()):
cached_path.rmdir()

# return the new function object
value = getattr(module, function_name)
return value
return getattr(module, function_name)
18 changes: 10 additions & 8 deletions bpx/interpolated_table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List
from __future__ import annotations

from pydantic import BaseModel, validator
from pydantic import BaseModel, ValidationInfo, field_validator


class InterpolatedTable(BaseModel):
Expand All @@ -9,11 +9,13 @@ class InterpolatedTable(BaseModel):
by two lists of floats, x and y. The function is defined by interpolation.
"""

x: List[float]
y: List[float]
x: list[float]
y: list[float]

@validator("y")
def same_length(cls, v: list, values: dict) -> list:
if "x" in values and len(v) != len(values["x"]):
raise ValueError("x & y should be same length")
@field_validator("y")
@classmethod
def same_length(cls, v: list, info: ValidationInfo) -> list:
if "x" in info.data and len(v) != len(info.data["x"]):
error_msg = "x & y should be same length"
raise ValueError(error_msg)
return v
51 changes: 30 additions & 21 deletions bpx/parsers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from bpx import BPX
from .schema import BPX


def parse_bpx_file(filename: str, v_tol: float = 0.001) -> BPX:
def parse_bpx_obj(bpx: dict, v_tol: float = 0.001) -> BPX:
"""
A convenience function to parse a bpx file into a BPX model.
A convenience function to parse a bpx dict into a BPX model.

Parameters
----------
filename: str
a filepath to a bpx file
bpx: dict
a dict object in bpx format
v_tol: float
absolute tolerance in [V] to validate the voltage limits, 1 mV by default

Expand All @@ -18,21 +18,22 @@ def parse_bpx_file(filename: str, v_tol: float = 0.001) -> BPX:
a parsed BPX model
"""
if v_tol < 0:
raise ValueError("v_tol should not be negative")
error_msg = "v_tol should not be negative"
raise ValueError(error_msg)

BPX.settings.tolerances["Voltage [V]"] = v_tol
BPX.Settings.tolerances["Voltage [V]"] = v_tol

return BPX.parse_file(filename)
return BPX.model_validate(bpx)


def parse_bpx_obj(bpx: dict, v_tol: float = 0.001) -> BPX:
def parse_bpx_file(filename: str, v_tol: float = 0.001) -> BPX:
"""
A convenience function to parse a bpx dict into a BPX model.
A convenience function to parse a bpx file into a BPX model.

Parameters
----------
bpx: dict
a dict object in bpx format
filename: str
a filepath to a bpx file
v_tol: float
absolute tolerance in [V] to validate the voltage limits, 1 mV by default

Expand All @@ -41,12 +42,22 @@ def parse_bpx_obj(bpx: dict, v_tol: float = 0.001) -> BPX:
BPX: :class:`bpx.BPX`
a parsed BPX model
"""
if v_tol < 0:
raise ValueError("v_tol should not be negative")

BPX.settings.tolerances["Voltage [V]"] = v_tol
from pathlib import Path

return BPX.parse_obj(bpx)
bpx = ""
if filename.endswith((".yml", ".yaml")):
import yaml

with Path(filename).open(encoding="utf-8") as f:
bpx = yaml.safe_load(f)
else:
import orjson as json

with Path(filename).open(encoding="utf-8") as f:
bpx = json.loads(f.read())

return parse_bpx_obj(bpx, v_tol)


def parse_bpx_str(bpx: str, v_tol: float = 0.001) -> BPX:
Expand All @@ -66,9 +77,7 @@ def parse_bpx_str(bpx: str, v_tol: float = 0.001) -> BPX:
BPX:
a parsed BPX model
"""
if v_tol < 0:
raise ValueError("v_tol should not be negative")

BPX.settings.tolerances["Voltage [V]"] = v_tol
import orjson as json

return BPX.parse_raw(bpx)
bpx = json.loads(bpx)
return parse_bpx_obj(bpx, v_tol)
Loading
Loading