Skip to content

Commit

Permalink
Add support for modifying template expression via set expression, and…
Browse files Browse the repository at this point in the history
… testing for handling accepted_potentials as scaled versions of template potentials
  • Loading branch information
CalCraven committed Oct 1, 2024
1 parent c020fd1 commit ffb53c8
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 58 deletions.
7 changes: 5 additions & 2 deletions gmso/lib/potential_templates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module supporting template potential objects."""

import json
from copy import deepcopy
from pathlib import Path
from typing import Dict

Expand Down Expand Up @@ -113,9 +114,11 @@ def expected_parameters_dimensions(self):
"""Return the expected dimensions of the parameters for this template."""
return self.__dict__.get("expected_parameters_dimensions_")

def set_expression(self, *args, **kwargs):
def set_expression(self, new_expression):

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

Overriding method 'set_expression' has signature mismatch with
overridden method
.
"""Set the expression of the PotentialTemplate."""
raise NotImplementedError
copied_template = deepcopy(self)
copied_template.expression = new_expression
return copied_template

def assert_can_parameterize_with(
self, parameters: Dict[str, u.unyt_quantity]
Expand Down
18 changes: 15 additions & 3 deletions gmso/tests/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import sympy
import unyt as u
from sympy import sympify

from gmso.tests.base_test import BaseTest
from gmso.utils.conversions import (
Expand All @@ -17,10 +18,21 @@ def _convert_potential_types(top, connStr, expected_units_dim, base_units):
return potentials


class TestKelvinToEnergy(BaseTest):
def test_convert_potential_styles(self, typed_ethane):
from sympy import sympify
class TestConversions(BaseTest):
def test_rescale_potentials(self, typed_ethane):
from gmso.lib.potential_templates import PotentialTemplateLibrary

library = PotentialTemplateLibrary()
template = library["LennardJonesPotential"]
template = template.set_expression(
template.expression / 4
) # use setter to not set in place
atype = list(typed_ethane.atom_types)[0]
assert atype.expression == sympify("4*epsilon*((sigma/r)**12 - (sigma/r)**6)")
typed_ethane.convert_potential_styles({"sites": template})
assert atype.expression == sympify("epsilon*((sigma/r)**12 - (sigma/r)**6)")

def test_convert_potential_styles(self, typed_ethane):
rb_expr = sympify(
"c0 * cos(phi)**0 + c1 * cos(phi)**1 + c2 * cos(phi)**2 + c3 * cos(phi)**3 + c4 * cos(phi)**4 + c5 * cos(phi)**5"
)
Expand Down
2 changes: 1 addition & 1 deletion gmso/tests/test_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_zero_charges(self):
@pytest.mark.skipif(not has_hoomd, reason="hoomd is not installed")
@pytest.mark.skipif(not has_mbuild, reason="mbuild not installed")
@pytest.mark.skipif(
int(hoomd_version[0]) <= 3.8, reason="Deprecated features in HOOMD 4"
int(hoomd_version[0]) < 4.5, reason="No periodic impropers in hoomd < 4.5"
)
def test_gaff_sim(self, gaff_forcefield):
base_units = {
Expand Down
6 changes: 4 additions & 2 deletions gmso/tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def test_template_set_expression(self):
independent_variables={"x"},
expected_parameters_dimensions={"a": "length", "b": "length"},
)
with pytest.raises(NotImplementedError):
template.set_expression(expression="a*y+b")
with pytest.raises(ValueError):
template.set_expression(new_expression="a*y+b")
template2 = template.set_expression(new_expression="3*a*x+b")
assert template2.expression != template.expression

def test_parameterization_non_dict_expected_dimensions(self):
template = PotentialTemplate(
Expand Down
159 changes: 109 additions & 50 deletions gmso/utils/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from unyt.dimensions import length, mass, time

import gmso
from gmso.exceptions import GMSOError
from gmso.lib.potential_templates import PotentialTemplateLibrary
from gmso.exceptions import EngineIncompatibilityError, GMSOError
from gmso.lib.potential_templates import (
PotentialTemplate,
PotentialTemplateLibrary,
)

templates = PotentialTemplateLibrary()

Expand Down Expand Up @@ -52,6 +55,71 @@ def _try_sympy_conversions(pot1, pot2):
return None


def _conversion_from_template_name(
top, connStr: str, conn_typeStr: str, convStr: str
) -> "Topology":
"""Use the name of convStr to identify function to convert sympy expressions."""
conversions_map = { # these are predefined between template types
# More functions, and `(to, from)` key pairs added to this dictionary
(
"OPLSTorsionPotential",
"RyckaertBellemansTorsionPotential",
): convert_opls_to_ryckaert,
(
"RyckaertBellemansTorsionPotential",
"OPLSTorsionPotential",
): convert_ryckaert_to_opls,
(
"RyckaertBellemansTorsionPotential",
"FourierTorsionPotential",
): convert_ryckaert_to_opls,
} # map of all accessible conversions currently supported

# check all connections with these types for compatibility
for conn in getattr(top, connStr):
current_expression = getattr(conn, conn_typeStr[:-1]) # strip off extra s
# convert it using pre-defined names with conversion functions
conversion_from_conversion_toTuple = (current_expression.name, convStr)
if (
conversion_from_conversion_toTuple in conversions_map
): # Try mapped conversions
new_conn_type = conversions_map.get(conversion_from_conversion_toTuple)(
current_expression
)
setattr(conn, conn_typeStr[:-1], new_conn_type)
continue

# convert it using sympy expression conversion (handles constant multipliers)
new_potential = templates[convStr]
modified_connection_parametersDict = _try_sympy_conversions(
current_expression, new_potential
)
if modified_connection_parametersDict: # try sympy conversions
current_expression.name = new_potential.name
current_expression.expression = new_potential.expression
current_expression.parameters.update(modified_connection_parametersDict)


def _conversion_from_template_obj(
top: "Topology",
connStr: str,
conn_typeStr: str,
potential_template: "ParametricPotential",
):
"""Use a passed template object to identify conversion between expressions."""
for conn in getattr(top, connStr):
current_expression = getattr(conn, conn_typeStr[:-1]) # strip off extra s

# convert it using sympy expression conversion (handles constant multipliers)
modified_connection_parametersDict = _try_sympy_conversions(
current_expression, potential_template
)
if modified_connection_parametersDict: # try sympy conversions
current_expression.name = potential_template.name
current_expression.expression = potential_template.expression
current_expression.parameters.update(modified_connection_parametersDict)


def convert_topology_expressions(top, expressionMap={}):
"""Convert from one parameter form to another.
Expand All @@ -68,55 +136,46 @@ def convert_topology_expressions(top, expressionMap={}):
# TODO: Raise errors

# Apply from predefined conversions or easy sympy conversions
conversions_map = {
(
"OPLSTorsionPotential",
"RyckaertBellemansTorsionPotential",
): convert_opls_to_ryckaert,
(
"RyckaertBellemansTorsionPotential",
"OPLSTorsionPotential",
): convert_ryckaert_to_opls,
(
"RyckaertBellemansTorsionPotential",
"FourierTorsionPotential",
): convert_ryckaert_to_opls,
} # map of all accessible conversions currently supported

for conv in expressionMap:
# check all connections with these types for compatibility
for conn in getattr(top, conv):
current_expression = getattr(conn, conv[:-1] + "_type")
if (
current_expression.name == expressionMap[conv]
): # check to see if we can skip this one
# TODO: Do something instead of just comparing the names
continue

# convert it using pre-defined conversion functions
conversion_from_conversion_toTuple = (
current_expression.name,
expressionMap[conv],
)
if (
conversion_from_conversion_toTuple in conversions_map
): # Try mapped conversions
new_conn_type = conversions_map.get(conversion_from_conversion_toTuple)(
current_expression
)
setattr(conn, conv[:-1] + "_type", new_conn_type)
continue

# convert it using sympy expression conversion
new_potential = templates[expressionMap[conv]]
modified_connection_parametersDict = _try_sympy_conversions(
current_expression, new_potential
)
if modified_connection_parametersDict: # try sympy conversions
current_expression.name = new_potential.name
current_expression.expression = new_potential.expression
current_expression.parameters.update(modified_connection_parametersDict)

# handler for various keys passed to expressionMap for conversion
for connStr, conv in expressionMap.items():
possible_connections = ["bond", "angle", "dihedral", "improper"]
possible_endings = ["", "s", "_atypes"]
possible_connection_labels = [

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable possible_connection_labels is not used.
connection + ending
for connection in possible_connections
for ending in possible_endings
]
if connStr.lower() in [
"sites",
"site",
"atom",
"atoms",
"atom_types",
"atom_type",
"atomtype",
"atomtypes",
]:
# handle renaming type names in relationship to the site or connection
conn_typeStr = "atom_types"
connStr = "sites"
for possible_connection in possible_connections:
if possible_connection in connStr.lower():
connStr = possible_connection + "s"
conn_typeStr = possible_connection + "_types"
break

if isinstance(conv, str):
_conversion_from_template_name(top, connStr, conn_typeStr, conv)
elif isinstance(conv, PotentialTemplate):
_conversion_from_template_obj(top, connStr, conn_typeStr, conv)
else:
connType = list(getattr(top, conn_typeStr))[0]
errormsg = f"""
Failed to convert {top} for {connStr} components, with conversion
of {connType.name}: {connType} to {conv} as {type(conv)}.
"""
raise EngineIncompatibilityError(errormsg)
return top


Expand Down

0 comments on commit ffb53c8

Please sign in to comment.