Skip to content

Commit

Permalink
remove VariableDef — it does not help with JIT
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 17, 2024
1 parent b8cb289 commit 578e819
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 53 deletions.
2 changes: 0 additions & 2 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
VariableDef,
fitting_check_output,
get_deriv_name,
get_reduce_name,
Expand Down Expand Up @@ -54,7 +53,6 @@
"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
"VariableDef",
"model_check_output",
"fitting_check_output",
"get_reduce_name",
Expand Down
72 changes: 26 additions & 46 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,35 +104,7 @@ def __call__(
return wrapper


class VariableDef:
    """Describe the shape and per-atom locality of a single model variable.

    Parameters
    ----------
    name
        Name of the output variable.  The derived names ``xxxx_redu``,
        ``xxxx_derv_c`` and ``xxxx_derv_r`` are reserved and must not be
        used when defining variables.
    shape
        Shape of the variable, e.g. ``[1]`` for an energy, ``[3]`` for a
        dipole and ``[3, 3]`` for a polarizability.
    atomic
        Whether the variable is defined for each atom.
    """

    def __init__(self, name: str, shape: list[int], atomic: bool = True):
        self.name = name
        # copy so later mutation of the caller's list cannot alter this def
        self.shape = [*shape]
        self.atomic = atomic


class OutputVariableDef(VariableDef):
class OutputVariableDef:
"""Defines the shape and other properties of the one output variable.
It is assumed that the fitting network output variables for each
Expand Down Expand Up @@ -163,15 +135,11 @@ def __init__(
shape: List[int],
reduciable: bool = False,
differentiable: bool = False,
atomic: bool = True,
):
## fitting output must be atomic
## Here we cannot use super because it does not pass jit
# super().__init__(name, shape, atomic=True)
## the work around is the following
self.name = name
self.shape = list(shape)
self.atomic = True
#
self.atomic = atomic
self.reduciable = reduciable
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
Expand Down Expand Up @@ -232,7 +200,7 @@ def __init__(
self.def_outp = fit_defs
self.def_redu = do_reduce(self.def_outp)
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
self.var_defs = {}
self.var_defs: Dict[str, OutputVariableDef] = {}
for ii in [
self.def_outp.get_data(),
self.def_redu,
Expand All @@ -244,13 +212,13 @@ def __init__(
def __getitem__(
self,
key: str,
) -> VariableDef:
) -> OutputVariableDef:
return self.var_defs[key]

def get_data(
self,
key: str,
) -> Dict[str, VariableDef]:
) -> Dict[str, OutputVariableDef]:
return self.var_defs

def keys(self):
Expand Down Expand Up @@ -279,23 +247,35 @@ def get_deriv_name(name: str) -> Tuple[str, str]:

def do_reduce(
    def_outp: FittingOutputDef,
) -> Dict[str, OutputVariableDef]:
    """Build the definitions of the reduced (per-frame) output variables.

    For every reduciable variable in ``def_outp``, a non-atomic variable
    named ``<name>_redu`` with the same shape is created.

    Parameters
    ----------
    def_outp
        The fitting output definitions to derive the reduced variables from.

    Returns
    -------
    Dict[str, OutputVariableDef]
        Mapping from reduced variable name to its definition.
    """
    def_redu: Dict[str, OutputVariableDef] = {}
    for kk, vv in def_outp.get_data().items():
        if vv.reduciable:
            rk = get_reduce_name(kk)
            # Reduced variables live per frame (atomic=False) and cannot be
            # further reduced or differentiated.
            def_redu[rk] = OutputVariableDef(
                rk, vv.shape, reduciable=False, differentiable=False, atomic=False
            )
    return def_redu


def do_derivative(
    def_outp: FittingOutputDef,
) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]:
    """Build the definitions of the derivative output variables.

    For every differentiable variable in ``def_outp``, two variables are
    created: ``<name>_derv_r`` with shape ``shape + [3]`` (per-atom
    coordinate derivative) and ``<name>_derv_c`` with shape
    ``shape + [3, 3]`` (cell derivative, itself reduciable).

    Parameters
    ----------
    def_outp
        The fitting output definitions to derive the derivatives from.

    Returns
    -------
    Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]
        The coordinate-derivative and cell-derivative definitions.
    """
    def_derv_r: Dict[str, OutputVariableDef] = {}
    def_derv_c: Dict[str, OutputVariableDef] = {}
    for kk, vv in def_outp.get_data().items():
        if vv.differentiable:
            rkr, rkc = get_deriv_name(kk)
            def_derv_r[rkr] = OutputVariableDef(
                rkr,
                vv.shape + [3],  # noqa: RUF005
                reduciable=False,
                differentiable=False,
            )
            # NOTE: the cell derivative is reduciable (e.g. atomic virial
            # reduces to the frame virial).
            def_derv_c[rkc] = OutputVariableDef(
                rkc,
                vv.shape + [3, 3],  # noqa: RUF005
                reduciable=True,
                differentiable=False,
            )
    return def_derv_r, def_derv_c
21 changes: 16 additions & 5 deletions source/tests/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,22 @@
model_check_output,
)
from deepmd_utils.model_format.output_def import (
VariableDef,
check_var,
)


class VariableDef:
    """Lightweight test stand-in describing one output variable.

    Records the variable's name, its shape (copied to a fresh list), and
    whether it is defined per atom.
    """

    def __init__(
        self,
        name: str,
        shape: list[int],
        atomic: bool = True,
    ):
        # store a defensive copy of the shape sequence
        self.name, self.shape, self.atomic = name, [*shape], atomic


class TestDef(unittest.TestCase):
def test_model_output_def(self):
defs = [
Expand Down Expand Up @@ -85,7 +96,7 @@ def test_model_output_def(self):
self.assertEqual(md["foo"].atomic, True)
self.assertEqual(md["energy_redu"].atomic, False)
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, False)
self.assertEqual(md["energy_derv_c"].atomic, True)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
Expand All @@ -108,7 +119,7 @@ def call(self):
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros([nf, 1]),
"energy_derv_r": np.zeros([nf, nloc, 1, 3]),
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
"energy_derv_c": np.zeros([nf, nloc, 1, 3, 3]),
}

ff = Foo()
Expand All @@ -133,7 +144,7 @@ def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros([nf, 1]),
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
"energy_derv_c": np.zeros([nf, nloc, 1, 3, 3]),
}

ff = Foo()
Expand Down Expand Up @@ -165,7 +176,7 @@ def call(self):
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros(self.shape_rd),
"energy_derv_r": np.zeros(self.shape_dr),
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
"energy_derv_c": np.zeros([nf, nloc, 1, 3, 3]),
}

ff = Foo()
Expand Down

0 comments on commit 578e819

Please sign in to comment.