diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index ed6a10941..d4aaf4b23 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -1,12 +1,18 @@ -import typing from collections import ChainMap +from typing import Iterable +import warnings from lxml import etree -from gmso.utils.ff_utils import (validate, - parse_ff_metadata, - parse_ff_atomtypes, - parse_ff_connection_types) +from gmso.exceptions import MissingPotentialError +from gmso.utils._constants import FF_TOKENS_SEPARATOR +from gmso.utils.ff_utils import ( + parse_ff_atomtypes, + parse_ff_connection_types, + parse_ff_metadata, + validate, +) +from gmso.utils.misc import mask_with, validate_type def _group_by_expression(potential_types): @@ -183,6 +189,225 @@ def group_improper_types_by_expression(self): """ return _group_by_expression(self.improper_types) + def get_potential(self, group, key, warn=False): + """Returns a specific potential by key in this ForceField + + Parameters + ---------- + group: {'atom_types', 'bond_types', 'angle_types', 'dihedral_types', 'improper_types'} + The potential group to perform this search on + key: str or list of str + The key to lookup for this potential group + warn: bool, default=False + If true, raise a warning instead of Error if no match found + + Returns + ------- + gmso.ParametricPotential + The parametric potential requested + + Raises + ------ + MissingPotentialError + When the potential specified by `key` is not found in the ForceField + potential group `group` + """ + group = group.lower() + + potential_extractors = { + "atom_type": self._get_atom_type, + "bond_type": self._get_bond_type, + "angle_type": self._get_angle_type, + "dihedral_type": self._get_dihedral_type, + "improper_type": self._get_improper_type, + } + + if group not in potential_extractors: + raise ValueError(f"Cannot get potential for {group}") + + validate_type( + [key] if isinstance(key, str) or not isinstance(key, Iterable) else key, str + ) + + return potential_extractors[group](key, warn=warn) + + def get_parameters(self, group, key, warn=False, copy=False): + """Returns parameters for a specific potential by key in this ForceField + + This function uses the `get_potential` function to get Parameters + + See Also + -------- + gmso.ForceField.get_potential + Get specific potential/parameters from a forcefield potential group by key + """ + potential = self.get_potential(group, key, warn=warn) + return potential.get_parameters(copy=copy) + + def _get_atom_type(self, atom_type, warn=False): + """Get a particular atom_type with given `atom_type` from this ForceField""" + if isinstance(atom_type, list): + atom_type = atom_type[0] + + if not self.atom_types.get(atom_type): + msg = f"AtomType {atom_type} is not present in the ForceField" + if warn: + warnings.warn(msg) + else: + raise MissingPotentialError(msg) + + return self.atom_types.get(atom_type) + + def _get_bond_type(self, atom_types, warn=False): + """Get a particular bond_type between `atom_types` from this ForceField""" + if len(atom_types) != 2: + raise ValueError( + f"BondType potential can only " + f"be extracted for two atoms. Provided {len(atom_types)}" + ) + + forward = FF_TOKENS_SEPARATOR.join(atom_types) + reverse = FF_TOKENS_SEPARATOR.join(reversed(atom_types)) + if forward in self.bond_types: + return self.bond_types[forward] + if reverse in self.bond_types: + return self.bond_types[reverse] + + msg = f"BondType between atoms {atom_types[0]} and {atom_types[1]} " \ + f"is missing from the ForceField" + if warn: + warnings.warn(msg) + return None + else: + raise MissingPotentialError(msg) + + def _get_angle_type(self, atom_types, warn=False): + """Get a particular angle_type between `atom_types` from this ForceField""" + if len(atom_types) != 3: + raise ValueError( + f"AngleType potential can only " + f"be extracted for three atoms. Provided {len(atom_types)}" + ) + + forward = FF_TOKENS_SEPARATOR.join(atom_types) + reverse = FF_TOKENS_SEPARATOR.join(reversed(atom_types)) + match = None + if forward in self.angle_types: + match = self.angle_types[forward] + if reverse in self.angle_types: + match = self.angle_types[reverse] + + msg = f"AngleType between atoms {atom_types[0]}, {atom_types[1]} " \ + f"and {atom_types[2]} is missing from the ForceField" + + if match: + return match + elif warn: + warnings.warn(msg) + return None + else: + raise MissingPotentialError(msg) + + def _get_dihedral_type(self, atom_types, warn=False): + """Get a particular dihedral_type between `atom_types` from this ForceField""" + if len(atom_types) != 4: + raise ValueError( + f"DihedralType potential can only " + f"be extracted for four atoms. Provided {len(atom_types)}" + ) + + forward = FF_TOKENS_SEPARATOR.join(atom_types) + reverse = FF_TOKENS_SEPARATOR.join(reversed(atom_types)) + + if forward is self.dihedral_types: + return self.dihedral_types[forward] + if reverse in self.dihedral_types: + return self.dihedral_types[reverse] + + match = None + for i in range(1, 5): + forward_patterns = mask_with(atom_types, i) + reverse_patterns = mask_with(reversed(atom_types), i) + + for forward_pattern, reverse_pattern in zip( + forward_patterns, reverse_patterns + ): + forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) + reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) + + if forward_match_key in self.dihedral_types: + match = self.dihedral_types[forward_match_key] + break + + if reverse_match_key in self.dihedral_types: + match = self.dihedral_types[reverse_match_key] + break + + if match: + break + + msg = f"DihedralType between atoms {atom_types[0]}, {atom_types[1]}, "\ + f"{atom_types[2]} and {atom_types[3]} is missing from the ForceField." + if match: + return match + elif warn: + warnings.warn(msg) + return None + else: + raise MissingPotentialError(msg) + + def _get_improper_type(self, atom_types, warn=False): + """Get a particular improper_type between `atom_types` from this ForceField""" + if len(atom_types) != 4: + raise ValueError( + f"ImproperType potential can only " + f"be extracted for four atoms. Provided {len(atom_types)}" + ) + + forward = FF_TOKENS_SEPARATOR.join(atom_types) + reverse = FF_TOKENS_SEPARATOR.join( + [atom_types[0], atom_types[2], atom_types[1], atom_types[3]] + ) + + if forward is self.improper_types: + return self.improper_types[forward] + if reverse in self.improper_types: + return self.improper_types[reverse] + + match = None + for i in range(1, 5): + forward_patterns = mask_with(atom_types, i) + reverse_patterns = mask_with( + [atom_types[0], atom_types[2], atom_types[1], atom_types[3]], i + ) + + for forward_pattern, reverse_pattern in zip( + forward_patterns, reverse_patterns + ): + forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) + reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) + + if forward_match_key in self.dihedral_types: + match = self.dihedral_types[forward_match_key] + break + + if reverse_match_key in self.dihedral_types: + match = self.dihedral_types[reverse_match_key] + break + + if match: + break + + msg = f"ImproperType between atoms {atom_types[0]}, {atom_types[1]}, "\ + f"{atom_types[2]} and {atom_types[3]} is missing from the ForceField." + if match: + return match + elif warn: + warnings.warn(msg) + return None + else: + raise MissingPotentialError(msg) + def __repr__(self): return f" + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gmso/tests/test_forcefield_xml.py b/gmso/tests/test_forcefield.py similarity index 59% rename from gmso/tests/test_forcefield_xml.py rename to gmso/tests/test_forcefield.py index 7c10b29f0..9006b93cc 100644 --- a/gmso/tests/test_forcefield_xml.py +++ b/gmso/tests/test_forcefield.py @@ -1,17 +1,16 @@ import lxml import pytest -from sympy import sympify import unyt as u - from lxml.etree import DocumentInvalid +from sympy import sympify from gmso.core.forcefield import ForceField -from gmso.tests.utils import get_path +from gmso.exceptions import ForceFieldParseError, MissingAtomTypesError, MissingPotentialError from gmso.tests.base_test import BaseTest -from gmso.exceptions import ForceFieldParseError, MissingAtomTypesError +from gmso.tests.utils import allclose_units_mixed, get_path -class TestForceFieldFromXML(BaseTest): +class TestForceField(BaseTest): @pytest.fixture def ff(self): @@ -21,6 +20,10 @@ def ff(self): def named_groups_ff(self): return ForceField(get_path('ff-example1.xml')) + @pytest.fixture + def opls_ethane_foyer(self): + return ForceField(get_path(filename=get_path("oplsaa-ethane_foyer.xml"))) + def test_ff_name_version_from_xml(self, ff): assert ff.name == 'ForceFieldOne' assert ff.version == '0.4.1' @@ -207,3 +210,175 @@ def test_forcefield_missing_atom_types(self): def test_forcefield_missing_atom_types_non_strict(self): ff = ForceField(get_path(filename=get_path('ff_missing_atom_types.xml')), strict=False) + def test_forcefeld_get_potential_atom_type(self, opls_ethane_foyer): + at = opls_ethane_foyer.get_potential("atom_type", key=["opls_135"]) + assert at.expression == sympify( + "ep * ((sigma/r)**12 - (sigma/r)**6) + q / (e0 * r)" + ) + + params = at.parameters + assert "ep" in params + assert "sigma" in params + assert "e0" in params + assert sympify("r") in at.independent_variables + + assert allclose_units_mixed( + params.values(), + [ + 0.276144 * u.kJ / u.mol, + 0.35 * u.nm, + 8.8542e-12 * u.Unit("A**2*s**4/(kg*m**3)"), + -0.18 * u.C, + ], + ) + + def test_forcefield_get_parameters_atom_type(self, opls_ethane_foyer): + params = opls_ethane_foyer.get_parameters("atom_type", key=["opls_140"]) + + assert allclose_units_mixed( + params.values(), + [ + 0.12552 * u.kJ / u.mol, + 0.25 * u.nm, + 8.8542e-12 * u.Unit("A**2*s**4/(kg*m**3)"), + 0.06 * u.C, + ], + ) + + def test_forcefield_get_parameters_atom_type_copy(self, opls_ethane_foyer): + params = opls_ethane_foyer.get_parameters( + "atom_type", key=["opls_140"], copy=False + ) + params_copy = opls_ethane_foyer.get_parameters( + "atom_type", key=["opls_140"], copy=True + ) + assert allclose_units_mixed(params.values(), params_copy.values()) + + def test_forcefield_get_potential_bond_type(self, opls_ethane_foyer): + bt = opls_ethane_foyer.get_potential("bond_type", key=["opls_135", "opls_140"]) + assert bt.name == "BondType-Harmonic-2" + params = bt.parameters + assert "k" in params + assert "r_eq" in params + + assert sympify("r") in bt.independent_variables + + assert allclose_units_mixed( + params.values(), [284512.0 * u.kJ / u.nm ** 2, 0.109 * u.nm] + ) + + def test_forcefield_get_potential_bond_type_reversed(self, opls_ethane_foyer): + assert opls_ethane_foyer.get_potential( + "bond_type", ["opls_135", "opls_140"] + ) == opls_ethane_foyer.get_potential("bond_type", ["opls_140", "opls_135"]) + + def test_forcefield_get_parameters_bond_type(self, opls_ethane_foyer): + params = opls_ethane_foyer.get_parameters( + "bond_type", key=["opls_135", "opls_135"] + ) + + assert allclose_units_mixed( + params.values(), [224262.4 * u.kJ / u.nm ** 2, 0.1529 * u.nm] + ) + + def test_forcefield_get_potential_angle_type(self, opls_ethane_foyer): + at = opls_ethane_foyer.get_potential( + "angle_type", key=["opls_135", "opls_135", "opls_140"] + ) + assert at.name == "AngleType-Harmonic-1" + params = at.parameters + assert "k" in params + assert "theta_eq" in params + + assert sympify("theta") in at.independent_variables + + assert allclose_units_mixed( + params.values(), [313.8 * u.kJ / u.radian ** 2, 1.932079482 * u.radian] + ) + + def test_forcefield_get_potential_angle_type_reversed(self, opls_ethane_foyer): + assert opls_ethane_foyer.get_potential( + "angle_type", ["opls_135", "opls_135", "opls_140"] + ) == opls_ethane_foyer.get_potential( + "angle_type", ["opls_140", "opls_135", "opls_135"] + ) + + def test_forcefield_get_parameters_angle_type(self, opls_ethane_foyer): + params = opls_ethane_foyer.get_parameters( + "angle_type", key=["opls_140", "opls_135", "opls_140"] + ) + + assert allclose_units_mixed( + params.values(), [276.144 * u.kJ / u.radian ** 2, 1.8814649337 * u.radian] + ) + + def test_forcefield_get_potential_dihedral_type(self, opls_ethane_foyer): + dt = opls_ethane_foyer.get_potential( + "dihedral_type", key=["opls_140", "opls_135", "opls_135", "opls_140"] + ) + assert dt.name == "DihedralType-RB-Proper-1" + params = dt.parameters + assert "c0" in params + assert "c1" in params + assert "c2" in params + assert "c3" in params + assert "c4" in params + assert "c5" in params + + assert sympify("phi") in dt.independent_variables + + assert allclose_units_mixed( + params.values(), [0.6276, 1.8828, 0.0, -2.5104, 0.0, 0.0] * u.kJ / u.mol + ) + + def test_forcefield_get_parameters_dihedral_type(self, opls_ethane_foyer): + params = opls_ethane_foyer.get_parameters( + "dihedral_type", key=["opls_140", "opls_135", "opls_135", "opls_140"] + ) + + assert allclose_units_mixed( + params.values(), [0.6276, 1.8828, 0.0, -2.5104, 0.0, 0.0] * u.kJ / u.mol + ) + + def test_forcefield_get_potential_non_exisistent_group(self, opls_ethane_foyer): + with pytest.raises(ValueError): + opls_ethane_foyer.get_potential('non_group', ['a', 'b', 'c']) + + def test_forcefield_get_potential_non_string_key(self, opls_ethane_foyer): + with pytest.raises(TypeError): + opls_ethane_foyer.get_potential('atom_type', key=[111]) + + def test_get_atom_type_missing(self, opls_ethane_foyer): + with pytest.raises(MissingPotentialError): + opls_ethane_foyer._get_atom_type('opls_359', warn=False) + + with pytest.warns(UserWarning): + opls_ethane_foyer._get_atom_type('opls_359', warn=True) + + def test_get_bond_type_missing(self, opls_ethane_foyer): + with pytest.raises(MissingPotentialError): + opls_ethane_foyer._get_bond_type(['opls_359', 'opls_600'], warn=False) + + with pytest.warns(UserWarning): + opls_ethane_foyer._get_bond_type(['opls_359', 'opls_600'], warn=True) + + def test_get_angle_type_missing(self, opls_ethane_foyer): + with pytest.raises(MissingPotentialError): + opls_ethane_foyer._get_angle_type(['opls_359', 'opls_600', 'opls_700'], warn=False) + + with pytest.warns(UserWarning): + opls_ethane_foyer._get_angle_type(['opls_359', 'opls_600', 'opls_700'], warn=True) + + def test_get_dihedral_type_missing(self, opls_ethane_foyer): + with pytest.raises(MissingPotentialError): + opls_ethane_foyer._get_dihedral_type(['opls_359', 'opls_600', 'opls_700', 'opls_800'], warn=False) + + with pytest.warns(UserWarning): + opls_ethane_foyer._get_dihedral_type(['opls_359', 'opls_600', 'opls_700', 'opls_800'], warn=True) + + def test_get_improper_type_missing(self, opls_ethane_foyer): + with pytest.raises(MissingPotentialError): + opls_ethane_foyer._get_improper_type(['opls_359', 'opls_600', 'opls_700', 'opls_800'], warn=False) + + with pytest.warns(UserWarning): + opls_ethane_foyer._get_improper_type(['opls_359', 'opls_600', 'opls_700', 'opls_800'], warn=True) diff --git a/gmso/tests/utils.py b/gmso/tests/utils.py index 788a4184b..5f423b1f6 100644 --- a/gmso/tests/utils.py +++ b/gmso/tests/utils.py @@ -1,8 +1,35 @@ - import os +import unyt as u + def get_path(filename): """Given a test filename return its path""" _path = os.path.join(os.path.split(__file__)[0], 'files', filename) return _path + + +def allclose_units_mixed(u_iter1, u_iter2): + """Check if array of quantities with mixed dimensions are equivalent. + + Notes + ----- + The two iterables provided must contain same number of quantities and + should be able to passed to Python zip function. + + Parameters + ---------- + u_iter1: list or iterable of u.unit_quantity + The first iterable/list of unit quantities + u_iter2: list or iterable of u.unit_quantity + The second iterable/list of unit quantities + + Returns + ------- + bool + True if iter1 is equivalent to iter2 + """ + for q1, q2 in zip(u_iter1, u_iter2): + if not u.allclose_units(q1, q2): + return False + return True diff --git a/gmso/utils/_constants.py b/gmso/utils/_constants.py index 0225fda35..1072207ee 100644 --- a/gmso/utils/_constants.py +++ b/gmso/utils/_constants.py @@ -5,3 +5,4 @@ IMPROPER_TYPE_DICT = 'improper_type_dict' UNIT_WARNING_STRING = '{0} are assumed to be in units of {1}' +FF_TOKENS_SEPARATOR = '~' diff --git a/gmso/utils/ff_utils.py b/gmso/utils/ff_utils.py index c005fa32a..5b1880a51 100644 --- a/gmso/utils/ff_utils.py +++ b/gmso/utils/ff_utils.py @@ -10,15 +10,13 @@ from gmso.core.angle_type import AngleType from gmso.core.dihedral_type import DihedralType from gmso.core.improper_type import ImproperType +from gmso.utils._constants import FF_TOKENS_SEPARATOR from gmso.exceptions import ForceFieldParseError, ForceFieldError, MissingAtomTypesError __all__ = ['validate', 'parse_ff_metadata', 'parse_ff_atomtypes', - 'parse_ff_connection_types', - 'DICT_KEY_SEPARATOR'] - -DICT_KEY_SEPARATOR = '~' + 'parse_ff_connection_types'] # Create a dictionary of units _unyt_dictionary = {} @@ -28,8 +26,8 @@ def _check_valid_string(type_str): - if DICT_KEY_SEPARATOR in type_str: - raise ForceFieldError('Please do not use {} in type string'.format(DICT_KEY_SEPARATOR)) + if FF_TOKENS_SEPARATOR in type_str: + raise ForceFieldError('Please do not use {} in type string'.format(FF_TOKENS_SEPARATOR)) def _parse_param_units(parent_tag): @@ -332,7 +330,7 @@ def parse_ff_connection_types(connectiontypes_el, child_tag='BondType'): valued_param_vars = set(sympify(param) for param in ctor_kwargs['parameters'].keys()) ctor_kwargs['independent_variables'] = sympify(connectiontype_expression).free_symbols - valued_param_vars - this_conn_type_key = DICT_KEY_SEPARATOR.join(ctor_kwargs['member_types']) + this_conn_type_key = FF_TOKENS_SEPARATOR.join(ctor_kwargs['member_types']) this_conn_type = TAG_TO_CLASS_MAP[child_tag](**ctor_kwargs) connectiontypes_dict[this_conn_type_key] = this_conn_type diff --git a/gmso/utils/misc.py b/gmso/utils/misc.py index aa43fb89c..d129c21a4 100644 --- a/gmso/utils/misc.py +++ b/gmso/utils/misc.py @@ -47,3 +47,61 @@ def ensure_valid_dimensions(quantity_1: u.unyt_quantity, quantity_2.units, quantity_2.units.dimensions ) + + +def validate_type(iterator, type_): + """Validate all the elements of the iterable are of a particular type""" + for item in iterator: + if not isinstance(item, type_): + raise TypeError( + f"Expected {item} to be of type {type_.__name__} but got" + f" {type(item).__name__} instead." + ) + + +def mask_with(iterable, window_size=1, mask='*'): + """Mask an iterable with the `mask` in a circular sliding window of size `window_size` + + This method masks an iterable elements with a mask object in a circular sliding window + + Parameters + ---------- + iterable: Iterable + The iterable to mask with + window_size: int, default=1 + The window size for the mask to be applied + mask: Any, default='*' + The mask to apply + Examples + -------- + >>> from gmso.utils.misc import mask_with + >>> list(mask_with(['Ar', 'Ar'], 1)) + [['*', 'Ar'], ['Ar', '*']] + >>> for masked_list in mask_with(['Ar', 'Xe', 'Xm', 'CH'], 2, mask='_'): + ... print('~'.join(masked_list)) + _~_~Xm~CH + Ar~_~_~CH + Ar~Xe~_~_ + _~Xe~Xm~_ + + Yields + ------ + list + The masked iterable + """ + input_list = list(iterable) + idx = 0 + first = None + while idx < len(input_list): + mask_idxes = set((idx + j) % len(input_list) for j in range(window_size)) + to_yield = [ + mask if j in mask_idxes else input_list[j] + for j in range(len(input_list)) + ] + if to_yield == first: + break + if idx == 0: + first = to_yield + + idx += 1 + yield to_yield