Skip to content

Commit

Permalink
add serializer for several fields, fix everything but test_serializat…
Browse files Browse the repository at this point in the history
…ion and test_xml_handling
  • Loading branch information
daico007 committed Nov 22, 2023
1 parent 7b1b53d commit ec6c457
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 65 deletions.
21 changes: 20 additions & 1 deletion gmso/abc/abstract_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from abc import abstractmethod
from typing import Any, Dict, Iterator, List

from pydantic import ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field, field_serializer, field_validator

from gmso.abc.gmso_base import GMSOBase
from gmso.abc.serialization_utils import unyt_to_dict
from gmso.utils.expression import PotentialExpression


Expand Down Expand Up @@ -110,6 +111,24 @@ def tag_names(self) -> List[str]:
def tag_names_iter(self) -> Iterator[str]:
return iter(self.__dict__.get("tags_"))

@field_serializer("potential_expression_")
def serialize_expression(self, potential_expression_: PotentialExpression):
expr = str(potential_expression_.expression)
ind = sorted(
list(
str(ind) for ind in potential_expression_.independent_variables
)
)
params = {
param: unyt_to_dict(val)
for param, val in potential_expression_.parameters.items()
}
return {
"expression": expr,
"independent_variables": ind,
"parameters": params,
}

def add_tag(self, tag: str, value: Any, overwrite=True) -> None:
"""Add metadata for a particular tag"""
if self.tags.get(tag) and not overwrite:
Expand Down
6 changes: 6 additions & 0 deletions gmso/abc/abstract_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
Field,
StrictInt,
StrictStr,
field_serializer,
field_validator,
validator,
)

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'validator' is not used.
from unyt.exceptions import InvalidUnitOperation

from gmso.abc.gmso_base import GMSOBase
from gmso.abc.serialization_utils import unyt_to_dict
from gmso.exceptions import GMSOError

PositionType = Union[Sequence[float], np.ndarray, u.unyt_array]
Expand Down Expand Up @@ -133,6 +135,10 @@ def residue(self):
"""Return the residue assigned to the site."""
return self.__dict__.get("residue_")

@field_serializer("position_")
def serialize_position(self, position_: PositionType):
return unyt_to_dict(position_)

def __repr__(self):
"""Return the formatted representation of the site."""
return (
Expand Down
21 changes: 16 additions & 5 deletions gmso/abc/gmso_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Base model all classes extend."""
import json
import warnings
from abc import ABC
from typing import Any, ClassVar, Type
Expand All @@ -21,8 +20,6 @@ class GMSOBase(BaseModel, ABC):

__docs_generated__: ClassVar[bool] = False

# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_assignment=True,
Expand Down Expand Up @@ -71,9 +68,18 @@ def model_validate(cls: Type["Model"], obj: Any) -> "Model":
dict_to_unyt(obj)
return super(GMSOBase, cls).model_validate(obj)

def dict(self, **kwargs) -> "DictStrAny":
def model_dump(self, **kwargs) -> "DictStrAny":
kwargs["by_alias"] = True
super_dict = super(GMSOBase, self).dict(**kwargs)

additional_excludes = set()
if "exclude" in kwargs:
for term in kwargs["exclude"]:
if term in self.model_config["alias_to_fields"]:
additional_excludes.add(
self.model_config["alias_to_fields"][term]
)
kwargs["exclude"] = kwargs["exclude"].union(additional_excludes)
super_dict = super(GMSOBase, self).model_dump(**kwargs)
return super_dict

def _iter(self, **kwargs) -> "TupleGenerator":
Expand All @@ -100,6 +106,11 @@ def _iter(self, **kwargs) -> "TupleGenerator":

yield from super()._iter(**kwargs)

def model_dump_json(self, **kwargs):
kwargs["by_alias"] = True

return super(GMSOBase, self).model_dump_json(**kwargs)

@classmethod
def validate(cls, value):
"""Ensure that the object is validated before use."""
Expand Down
19 changes: 15 additions & 4 deletions gmso/core/angle_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,27 @@ class AngleType(ParametricPotential):
__eq__, _validate functions
"""

member_types: Optional[Tuple[str, str, str]] = Field(
member_types_: Optional[Tuple[str, str, str]] = Field(
None,
description="List-like of gmso.AtomType.name "
"defining the members of this angle type",
alias="member_types",
)

member_classes: Optional[Tuple[str, str, str]] = Field(
member_classes_: Optional[Tuple[str, str, str]] = Field(
None,
description="List-like of gmso.AtomType.atomclass "
"defining the members of this angle type",
alias="member_classes",
)
model_config = ConfigDict(
alias_to_fields=dict(
**ParametricPotential.model_config["alias_to_fields"],
**{
"member_types": "member_types_",
"member_classes": "member_classes_",
},
),
)

def __init__(
Expand Down Expand Up @@ -71,8 +82,8 @@ def _default_potential_expr():

@property
def member_types(self):
return self.__dict__.get("member_types")
return self.__dict__.get("member_types_")

@property
def member_classes(self):
return self.__dict__.get("member_classes")
return self.__dict__.get("member_classes_")
18 changes: 16 additions & 2 deletions gmso/core/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from typing import Optional, Union

import unyt as u
from pydantic import ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field, field_serializer, field_validator

from gmso.abc.abstract_site import Site
from gmso.abc.serialization_utils import unyt_to_dict
from gmso.core.atom_type import AtomType
from gmso.core.element import Element
from gmso.utils._constants import UNIT_WARNING_STRING
Expand Down Expand Up @@ -68,7 +69,6 @@ class Atom(Site):
"atom_type": "atom_type_",
},
),
# =True,
)

@property
Expand Down Expand Up @@ -105,6 +105,20 @@ def atom_type(self) -> Union[AtomType, property]:
"""Return the atom_type associated with the atom."""
return self.__dict__.get("atom_type_", None)

@field_serializer("charge_")
def serialize_charge(self, charge_: Union[u.unyt_quantity, None]):
if charge_ is None:
return None
else:
return unyt_to_dict(charge_)

@field_serializer("mass_")
def serialize_mass(self, mass_: Union[u.unyt_quantity, None]):
if mass_ is None:
return None
else:
return unyt_to_dict(mass_)

def clone(self):
"""Clone this atom."""
return Atom(
Expand Down
83 changes: 60 additions & 23 deletions gmso/core/atom_type.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Support non-bonded interactions between sites."""
import warnings
from typing import Optional, Set
from typing import Optional, Set, Union

import unyt as u
from pydantic import ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field, field_serializer, field_validator

from gmso.abc.serialization_utils import unyt_to_dict
from gmso.core.parametric_potential import ParametricPotential
from gmso.utils._constants import UNIT_WARNING_STRING
from gmso.utils.expression import PotentialExpression
Expand All @@ -28,34 +29,56 @@ class AtomType(ParametricPotential):
are stored explicitly.
"""

mass: Optional[u.unyt_array] = Field(
0.0 * u.gram / u.mol, description="The mass of the atom type"
mass_: Optional[u.unyt_array] = Field(
0.0 * u.gram / u.mol,
description="The mass of the atom type",
alias="mass",
)

charge: Optional[u.unyt_array] = Field(
0.0 * u.elementary_charge, description="The charge of the atom type"
charge_: Optional[u.unyt_array] = Field(
0.0 * u.elementary_charge,
description="The charge of the atom type",
alias="charge",
)

atomclass: Optional[str] = Field(
"", description="The class of the atomtype"
atomclass_: Optional[str] = Field(
"", description="The class of the atomtype", alias="atomclass"
)

doi: Optional[str] = Field(
doi_: Optional[str] = Field(
"",
description="Digital Object Identifier of publication where this atom type was introduced",
alias="doi",
)

overrides: Optional[Set[str]] = Field(
overrides_: Optional[Set[str]] = Field(
set(),
description="Set of other atom types that this atom type overrides",
alias="overrides",
)

definition: Optional[str] = Field(
"", description="SMARTS string defining this atom type"
definition_: Optional[str] = Field(
"",
description="SMARTS string defining this atom type",
alias="definition",
)

description: Optional[str] = Field(
"", description="Description for the AtomType"
description_: Optional[str] = Field(
"", description="Description for the AtomType", alias="description"
)
model_config = ConfigDict(
alias_to_fields=dict(
**ParametricPotential.model_config["alias_to_fields"],
**{
"mass": "mass_",
"charge": "charge_",
"atomclass": "atomclass_",
"doi": "doi_",
"overrides": "overrides_",
"definition": "definition_",
"description": "description_",
},
),
)

def __init__(
Expand Down Expand Up @@ -96,37 +119,51 @@ def __init__(
@property
def charge(self):
"""Return the charge of the atom_type."""
return self.__dict__.get("charge")
return self.__dict__.get("charge_")

@property
def mass(self):
"""Return the mass of the atom_type."""
return self.__dict__.get("mass")
return self.__dict__.get("mass_")

@property
def atomclass(self):
"""Return the atomclass of the atom_type."""
return self.__dict__.get("atomclass")
return self.__dict__.get("atomclass_")

@property
def doi(self):
"""Return the doi of the atom_type."""
return self.__dict__.get("doi")
return self.__dict__.get("doi_")

@property
def overrides(self):
"""Return the overrides of the atom_type."""
return self.__dict__.get("overrides")
return self.__dict__.get("overrides_")

@property
def description(self):
"""Return the description of the atom_type."""
return self.__dict__.get("description")
return self.__dict__.get("description_")

@property
def definition(self):
"""Return the SMARTS string of the atom_type."""
return self.__dict__.get("definition")
return self.__dict__.get("definition_")

@field_serializer("charge_")
def serialize_charge(self, charge_: Union[u.unyt_quantity, None]):
if charge_ is None:
return None
else:
return unyt_to_dict(charge_)

@field_serializer("mass_")
def serialize_mass(self, mass_: Union[u.unyt_quantity, None]):
if mass_ is None:
return None
else:
return unyt_to_dict(mass_)

def clone(self, fast_copy=False):
"""Clone this AtomType, faster alternative to deepcopying."""
Expand Down Expand Up @@ -190,7 +227,7 @@ def __repr__(self):
)
return desc

@field_validator("mass", mode="before")
@field_validator("mass_", mode="before")
@classmethod
def validate_mass(cls, mass):
"""Check to see that a mass is a unyt array of the right dimension."""
Expand All @@ -203,7 +240,7 @@ def validate_mass(cls, mass):

return mass

@field_validator("charge", mode="before")
@field_validator("charge_", mode="before")
@classmethod
def validate_charge(cls, charge):
"""Check to see that a charge is a unyt array of the right dimension."""
Expand Down
19 changes: 15 additions & 4 deletions gmso/core/bond_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,27 @@ class BondType(ParametricPotential):
__eq__, _validate functions
"""

member_types: Optional[Tuple[str, str]] = Field(
member_types_: Optional[Tuple[str, str]] = Field(
None,
description="List-like of of gmso.AtomType.name "
"defining the members of this bond type",
alias="member_types",
)

member_classes: Optional[Tuple[str, str]] = Field(
member_classes_: Optional[Tuple[str, str]] = Field(
None,
description="List-like of of gmso.AtomType.atomclass "
"defining the members of this bond type",
alias="member_classes",
)
model_config = ConfigDict(
alias_to_fields=dict(
**ParametricPotential.model_config["alias_to_fields"],
**{
"member_types": "member_types_",
"member_classes": "member_classes_",
},
),
)

def __init__(
Expand Down Expand Up @@ -62,11 +73,11 @@ def __init__(
@property
def member_types(self):
"""Return the members involved in this bondtype."""
return self.__dict__.get("member_types")
return self.__dict__.get("member_types_")

@property
def member_classes(self):
return self.__dict__.get("member_classes")
return self.__dict__.get("member_classes_")

@staticmethod
def _default_potential_expr():
Expand Down
Loading

0 comments on commit ec6c457

Please sign in to comment.