Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GMSO to work with pydantic 2.0 #745

Merged
merged 30 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2a7a138
first pass using bump-pydantic per https://docs.pydantic.dev/2.0/migr…
daico007 Jul 10, 2023
3c18f36
Further update syntax in abc and core
daico007 Jul 11, 2023
679f462
checkpoint
daico007 Jul 13, 2023
850ae6a
checkpoint
daico007 Jul 28, 2023
3f3bd3d
Merge branch 'main' of https://github.com/mosdef-hub/gmso into pydant…
daico007 Jul 28, 2023
a1e3795
Merge branch 'main' into pydantic2_compatibility
daico007 Jul 31, 2023
755bbe2
Merge branch 'pydantic2_compatibility' of https://github.com/daico007…
daico007 Nov 17, 2023
8c8ef4b
Merge branch 'main' of https://github.com/mosdef-hub/gmso into pydant…
daico007 Nov 17, 2023
c5e55e6
rip out json_encoder since it iss being deprecated, will need to fix …
daico007 Nov 19, 2023
6d96e89
fix up atom related files and tests
daico007 Nov 19, 2023
587da0d
fix bond.py
daico007 Nov 20, 2023
61c7de3
fix up dihedral and improper
daico007 Nov 20, 2023
cabb47b
fix various import in tests, removed json handler import
daico007 Nov 20, 2023
b3e2c7f
re-add __hash__ method for atom type and parametric potential
daico007 Nov 20, 2023
7b1b53d
replace parse_obj with model_validate
daico007 Nov 20, 2023
ec6c457
add serializer for several fields, fix everything but test_serializat…
daico007 Nov 22, 2023
fa2c64a
fix some of the serialization test
daico007 Nov 22, 2023
44a3951
reimplement json_dict, parse_raw (since model_validate_json raised un…
daico007 Nov 22, 2023
90e2643
found workaround to avoid using parse_raw
daico007 Nov 22, 2023
82ab5e6
fix remainder of xml handling test by change atom_type._etree_attrib …
daico007 Nov 24, 2023
a12b107
Merge branch 'main' of https://github.com/mosdef-hub/gmso into pydant…
daico007 Dec 13, 2023
92d0e50
update potential templates to use expected_parameters_dimensions_ for…
daico007 Dec 13, 2023
2b51973
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
9853726
fix typo
daico007 Dec 13, 2023
4e98a72
Merge branch 'pydantic2_compatibility' of https://github.com/daico007…
daico007 Dec 13, 2023
6ee20ca
add the missing alias key
daico007 Dec 13, 2023
6baed27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
3981a49
udpate env yml
daico007 Dec 13, 2023
9f72cb8
remove pydantic from matrx in CI
daico007 Dec 13, 2023
e69f03b
Merge branch 'pydantic2_compatibility' of https://github.com/daico007…
daico007 Dec 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions gmso/abc/abstract_connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Sequence

from pydantic import Field, root_validator
from pydantic import ConfigDict, Field, model_validator

from gmso.abc.abstract_site import Site
from gmso.abc.gmso_base import GMSOBase
Expand Down Expand Up @@ -65,7 +65,8 @@ def _get_members_types_or_classes(self, to_return):
]
return tc if all(tc) else None

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def validate_fields(cls, values):
connection_members = values.get("connection_members")

Expand Down Expand Up @@ -101,10 +102,12 @@ def __repr__(self):
def __str__(self):
return f"<{self.__class__.__name__} {self.name}, id: {id(self)}> "

class Config:
fields = {"name_": "name", "connection_members_": "connection_members"}

alias_to_fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
fields={"name_": "name", "connection_members_": "connection_members"},
alias_to_fields={
"name": "name_",
"connection_members": "connection_members_",
}
},
)
21 changes: 11 additions & 10 deletions gmso/abc/abstract_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, Dict, Iterator, List

from pydantic import Field, validator
from pydantic import ConfigDict, Field, field_validator
Fixed Show fixed Hide fixed

from gmso.abc.gmso_base import GMSOBase
from gmso.utils.expression import PotentialExpression
Expand Down Expand Up @@ -115,7 +115,8 @@ def delete_tag(self, tag: str) -> None:
def pop_tag(self, tag: str) -> Any:
return self.tags.pop(tag, None)

@validator("potential_expression_", pre=True)
@field_validator("potential_expression_", mode="before")
@classmethod
def validate_potential_expression(cls, v):
if isinstance(v, dict):
v = PotentialExpression(**v)
Expand Down Expand Up @@ -154,17 +155,17 @@ def __str__(self):
f"id: {id(self)}>"
)

class Config:
"""Pydantic configuration for the potential objects."""

fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
fields={
"name_": "name",
"potential_expression_": "potential_expression",
"tags_": "tags",
}

alias_to_fields = {
},
alias_to_fields={
"name": "name_",
"potential_expression": "potential_expression_",
"tags": "tags_",
}
},
)
36 changes: 22 additions & 14 deletions gmso/abc/abstract_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

import numpy as np
import unyt as u
from pydantic import Field, StrictInt, StrictStr, validator
from pydantic import (
ConfigDict,
Field,
StrictInt,
StrictStr,
field_validator,
validator,
)
Fixed Show fixed Hide fixed
from unyt.exceptions import InvalidUnitOperation

from gmso.abc.gmso_base import GMSOBase
Expand Down Expand Up @@ -120,7 +127,8 @@ def __str__(self):
f"label: {self.label if self.label else None} id: {id(self)}>"
)

@validator("position_")
@field_validator("position_")
@classmethod
def is_valid_position(cls, position):
"""Validate attribute position."""
if position is None:
Expand Down Expand Up @@ -148,6 +156,8 @@ def is_valid_position(cls, position):

return position

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("name_", pre=True, always=True)
def inject_name(cls, value):
if value == "" or value is None:
Expand All @@ -162,27 +172,25 @@ def __new__(cls, *args: Any, **kwargs: Any) -> SiteT:
else:
return object.__new__(cls)

class Config:
"""Pydantic configuration for site objects."""

arbitrary_types_allowed = True

fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
arbitrary_types_allowed=True,
fields={
"name_": "name",
"position_": "position",
"label_": "label",
"group_": "group",
"molecule_": "molecule",
"residue_": "residue",
}

alias_to_fields = {
},
alias_to_fields={
"name": "name_",
"position": "position_",
"label": "label_",
"group": "group_",
"molecule": "molecule_",
"residue": "residue_",
}

validate_assignment = True
},
validate_assignment=True,
)
22 changes: 12 additions & 10 deletions gmso/abc/gmso_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from abc import ABC
from typing import Any, ClassVar, Type

from pydantic import BaseModel
from pydantic.validators import dict_validator
from pydantic import BaseModel, ConfigDict, validators

from gmso.abc import GMSOJSONHandler
from gmso.abc.auto_doc import apply_docs
from gmso.abc.serialization_utils import dict_to_unyt

dict_validator = validators.getattr_migration("dict_validator")


class GMSOBase(BaseModel, ABC):
"""A BaseClass to all abstract classes in GMSO."""
Expand Down Expand Up @@ -116,11 +117,12 @@ def __get_validators__(cls) -> "CallableGenerator":
"""Get the validators of the object."""
yield cls.validate

class Config:
"""Pydantic configuration for base object."""

arbitrary_types_allowed = True
alias_to_fields = dict()
extra = "forbid"
json_encoders = GMSOJSONHandler.json_encoders
allow_population_by_field_name = True
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
arbitrary_types_allowed=True,
alias_to_fields=dict(),
Fixed Show fixed Hide fixed
extra="forbid",
json_encoders=GMSOJSONHandler.json_encoders,
populate_by_name=True,
)
17 changes: 9 additions & 8 deletions gmso/core/angle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Support for 3-partner connections between gmso.core.Atoms."""
from typing import Callable, ClassVar, Optional, Tuple

from pydantic import Field
from pydantic import ConfigDict, Field

from gmso.abc.abstract_connection import Connection
from gmso.core.angle_type import AngleType
Expand Down Expand Up @@ -83,16 +83,17 @@ def __setattr__(self, key, value):
else:
super(Angle, self).__setattr__(key, value)

class Config:
"""Support pydantic configuration for attributes and behavior."""

fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
fields={
"connection_members_": "connection_members",
"angle_type_": "angle_type",
"restraint_": "restraint",
}
alias_to_fields = {
},
alias_to_fields={
"connection_members": "connection_members_",
"angle_type": "angle_type_",
"restraint": "restraint_",
}
},
)
16 changes: 9 additions & 7 deletions gmso/core/angle_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Tuple

import unyt as u
from pydantic import Field
from pydantic import ConfigDict, Field
Fixed Show fixed Hide fixed

from gmso.core.parametric_potential import ParametricPotential
from gmso.utils.expression import PotentialExpression
Expand Down Expand Up @@ -77,13 +77,15 @@ def member_types(self):
def member_classes(self):
return self.__dict__.get("member_classes_")

class Config:
fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
fields={
"member_types_": "member_types",
"member_classes_": "member_classes",
}

alias_to_fields = {
},
alias_to_fields={
"member_types": "member_types_",
"member_classes": "member_classes_",
}
},
)
30 changes: 15 additions & 15 deletions gmso/core/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Union

import unyt as u
from pydantic import Field, validator
from pydantic import ConfigDict, Field, field_validator

from gmso.abc.abstract_site import Site
from gmso.core.atom_type import AtomType
Expand Down Expand Up @@ -122,7 +122,8 @@ def __lt__(self, other):
f"Cannot compare equality between {type(self)} and {type(other)}"
)

@validator("charge_")
@field_validator("charge_")
@classmethod
def is_valid_charge(cls, charge):
"""Ensure that the charge is physically meaningful."""
if charge is None:
Expand All @@ -137,7 +138,8 @@ def is_valid_charge(cls, charge):

return charge

@validator("mass_")
@field_validator("mass_")
@classmethod
def is_valid_mass(cls, mass):
"""Ensure that the mass is physically meaningful."""
if mass is None:
Expand All @@ -150,23 +152,21 @@ def is_valid_mass(cls, mass):
ensure_valid_dimensions(mass, default_mass_units)
return mass

class Config:
"""Pydantic configuration for the atom class."""

extra = "forbid"

fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
extra="forbid",
fields={
"charge_": "charge",
"mass_": "mass",
"element_": "element",
"atom_type_": "atom_type",
}

alias_to_fields = {
},
alias_to_fields={
"charge": "charge_",
"mass": "mass_",
"element": "element_",
"atom_type": "atom_type_",
}

validate_assignment = True
},
validate_assignment=True,
)
24 changes: 13 additions & 11 deletions gmso/core/atom_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Set

import unyt as u
from pydantic import Field, validator
from pydantic import ConfigDict, Field, field_validator
Fixed Show fixed Hide fixed

from gmso.core.parametric_potential import ParametricPotential
from gmso.utils._constants import UNIT_WARNING_STRING
Expand Down Expand Up @@ -186,7 +186,8 @@ def __repr__(self):
)
return desc

@validator("mass_", pre=True)
@field_validator("mass_", mode="before")
@classmethod
def validate_mass(cls, mass):
"""Check to see that a mass is a unyt array of the right dimension."""
default_mass_units = u.gram / u.mol
Expand All @@ -198,7 +199,8 @@ def validate_mass(cls, mass):

return mass

@validator("charge_", pre=True)
@field_validator("charge_", mode="before")
@classmethod
def validate_charge(cls, charge):
"""Check to see that a charge is a unyt array of the right dimension."""
if not isinstance(charge, u.unyt_array):
Expand All @@ -222,25 +224,25 @@ def _default_potential_expr():
},
)

class Config:
"""Pydantic configuration of the attributes for an atom_type."""

fields = {
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
fields={
"mass_": "mass",
"charge_": "charge",
"atomclass_": "atomclass",
"overrides_": "overrides",
"doi_": "doi",
"description_": "description",
"definition_": "definition",
}

alias_to_fields = {
},
alias_to_fields={
"mass": "mass_",
"charge": "charge_",
"atomclass": "atomclass_",
"overrides": "overrides_",
"doi": "doi_",
"description": "description_",
"definition": "definition_",
}
},
)
Loading