Skip to content

Commit

Permalink
fix up atom related files and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daico007 committed Nov 19, 2023
1 parent c5e55e6 commit 6d96e89
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 35 deletions.
67 changes: 57 additions & 10 deletions gmso/abc/abstract_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,84 @@ class Site(GMSOBase):
and their meaning the responsibility of the container where the sites will reside.
"""

name: str = Field(
name_: str = Field(
"",
validate_default=True,
description="Name of the site, defaults to class name",
alias="name",
)
label_: str = Field(
"", description="Label to be assigned to the site", alias="label"
)
label: str = Field("", description="Label to be assigned to the site")

group: Optional[StrictStr] = Field(
None, description="Flexible alternative label relative to site"
group_: Optional[StrictStr] = Field(
None,
description="Flexible alternative label relative to site",
alias="group",
)

molecule: Optional[MoleculeType] = Field(
molecule_: Optional[MoleculeType] = Field(
None,
description="Molecule label for the site, format of (molecule_name, molecule_number)",
alias="molecule",
)

residue: Optional[ResidueType] = Field(
residue_: Optional[ResidueType] = Field(
None,
description="Residue label for the site, format of (residue_name, residue_number)",
alias="residue",
)

position: PositionType = Field(
position_: PositionType = Field(
default_factory=default_position,
description="The 3D Cartesian coordinates of the position of the site",
alias="position",
)

model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
validate_assignment=True,
alias_to_fields={
"name": "name_",
"label": "label_",
"group": "group_",
"molecule": "molecule_",
"residue": "residue_",
"position": "position_",
},
populate_by_name=True,
)

@property
def name(self) -> str:
"""Return the name of the site."""
return self.__dict__.get("name_")

@property
def position(self) -> u.unyt_array:
"""Return the 3D Cartesian coordinates of the site."""
return self.__dict__.get("position_")

@property
def label(self) -> str:
"""Return the label assigned to the site."""
return self.__dict__.get("label_")

@property
def group(self) -> str:
"""Return the group of the site."""
return self.__dict__.get("group_")

@property
def molecule(self) -> tuple:
"""Return the molecule of the site."""
return self.__dict__.get("molecule_")

@property
def residue(self):
"""Return the residue assigned to the site."""
return self.__dict__.get("residue_")

def __repr__(self):
"""Return the formatted representation of the site."""
return (
Expand All @@ -104,7 +151,7 @@ def __str__(self):
f"label: {self.label if self.label else None} id: {id(self)}>"
)

@field_validator("position")
@field_validator("position_")
@classmethod
def is_valid_position(cls, position):
"""Validate attribute position."""
Expand Down Expand Up @@ -133,7 +180,7 @@ def is_valid_position(cls, position):

return position

@field_validator("name")
@field_validator("name_")
def inject_name(cls, value):
if value == "" or value is None:
return cls.__name__
Expand Down
12 changes: 12 additions & 0 deletions gmso/abc/gmso_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ def __eq__(self, other):
"""Test if two objects are equivalent."""
return self is other

def __setattr__(self, name: Any, value: Any) -> None:
"""Set the attributes of the object."""
if name in self.model_config.get("alias_to_fields"):
name = self.model_config.get("alias_to_fields")[name]
elif name in self.model_config.get("alias_to_fields").values():
warnings.warn(
"Use of internal fields is discouraged. "
"Please use external fields to set attributes."
)

super().__setattr__(name, value)

@classmethod
def __init_subclass__(cls, **kwargs):
"""Initialize the subclass of the object."""
Expand Down
53 changes: 34 additions & 19 deletions gmso/core/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,33 +38,46 @@ class Atom(Site):
An Abstract Base class for implementing site objects in GMSO. The class Atom bases from
the gmso.abc.abstract site class
"""
charge: Optional[Union[u.unyt_quantity, float]] = Field(
None,
description="Charge of the atom",
charge_: Optional[Union[u.unyt_quantity, float]] = Field(
None, description="Charge of the atom", alias="charge"
)

mass: Optional[Union[u.unyt_quantity, float]] = Field(
None, description="Mass of the atom"
mass_: Optional[Union[u.unyt_quantity, float]] = Field(
None,
description="Mass of the atom",
alias="mass",
)

element: Optional[Element] = Field(
None, description="Element associated with the atom"
element_: Optional[Element] = Field(
None,
description="Element associated with the atom",
alias="element",
)

atom_type: Optional[AtomType] = Field(
None, description="AtomType associated with the atom"
atom_type_: Optional[AtomType] = Field(
None, description="AtomType associated with the atom", alias="atom_type"
)

model_config = ConfigDict(
extra="forbid",
validate_assignment=True,
alias_to_fields=dict(
**Site.model_config["alias_to_fields"],
**{
"charge": "charge_",
"mass": "mass_",
"element": "element_",
"atom_type": "atom_type_",
},
),
populate_by_name=True,
)

@property
def charge(self) -> Union[u.unyt_quantity, None]:
"""Return the charge of the atom."""
charge = self.__dict__.get("charge", None)
atom_type = self.__dict__.get("atom_type", None)
charge = self.__dict__.get("charge_", None)
atom_type = self.__dict__.get("atom_type_", None)
if charge is not None:
return charge
elif atom_type is not None:
Expand All @@ -75,8 +88,8 @@ def charge(self) -> Union[u.unyt_quantity, None]:
@property
def mass(self) -> Union[u.unyt_quantity, None]:
"""Return the mass of the atom."""
mass = self.__dict__.get("mass", None)
atom_type = self.__dict__.get("atom_type", None)
mass = self.__dict__.get("mass_", property)
atom_type = self.__dict__.get("atom_type_", None)
if mass is not None:
return mass
elif atom_type is not None:
Expand All @@ -87,12 +100,12 @@ def mass(self) -> Union[u.unyt_quantity, None]:
@property
def element(self) -> Union[Element, None]:
"""Return the element associated with the atom."""
return self.__dict__.get("element", None)
return self.__dict__.get("element_", None)

@property
def atom_type(self) -> Union[AtomType, None]:
def atom_type(self) -> Union[AtomType, property]:
"""Return the atom_type associated with the atom."""
return self.__dict__.get("atom_type", None)
return self.__dict__.get("atom_type_", None)

def clone(self):
"""Clone this atom."""
Expand All @@ -106,7 +119,9 @@ def clone(self):
charge=self.charge,
mass=self.mass,
element=self.element,
atom_type=None if not self.atom_type else self.atom_type.clone(),
atom_type=property
if not self.atom_type
else self.atom_type.clone(),
)

def __le__(self, other):
Expand All @@ -127,7 +142,7 @@ def __lt__(self, other):
f"Cannot compare equality between {type(self)} and {type(other)}"
)

@field_validator("charge")
@field_validator("charge_")
@classmethod
def is_valid_charge(cls, charge):
"""Ensure that the charge is physically meaningful."""
Expand All @@ -143,7 +158,7 @@ def is_valid_charge(cls, charge):

return charge

@field_validator("mass")
@field_validator("mass_")
@classmethod
def is_valid_mass(cls, mass):
"""Ensure that the mass is physically meaningful."""
Expand Down
8 changes: 2 additions & 6 deletions gmso/tests/test_atom.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import numpy as np
import pytest
import unyt as u
from pydantic import ValidationError

from gmso.core.atom import Atom
from gmso.core.atom_type import AtomType
from gmso.core.element import Lithium, Sulfur
from gmso.exceptions import GMSOError
from gmso.tests.base_test import BaseTest

try:
from pydantic.v1 import ValidationError
except ImportError:
from pydantic import ValidationError


class TestSite(BaseTest):
def test_new_site(self):
Expand All @@ -28,7 +24,7 @@ def test_dtype(self):
assert isinstance(atom.position, np.ndarray)

def test_name_none(self):
atom = Atom(name=None)
atom = Atom()
assert atom.name == "Atom"

def test_setters_and_getters(self):
Expand Down

0 comments on commit 6d96e89

Please sign in to comment.