Skip to content

Commit

Permalink
add definition for the output of fitting and model
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 10, 2024
1 parent 43f9639 commit 4796a86
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 0 deletions.
10 changes: 10 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
save_dp_model,
traverse_model_dict,
)
from .output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
VariableDef,
)
from .se_e2_a import (
DescrptSeA,
)
Expand All @@ -31,4 +37,8 @@
"traverse_model_dict",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
"VariableDef",
]
192 changes: 192 additions & 0 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Tuple,
Union,
)


class VariableDef:
"""Defines the shape and other properties of a variable.
Parameters
----------
name
Name of the output variable. Notice that the xxxx_redu,
xxxx_derv_c, xxxx_derv_r are reserved names that should
not be used to define variables.
shape
The shape of the variable. e.g. energy should be [1],
dipole should be [3], polarizabilty should be [3,3].
atomic
If the variable is defined for each atom.
"""

def __init__(
self,
name: str,
shape: Union[List[int], Tuple[int]],
atomic: bool = True,
):
self.name = name
self.shape = list(shape)
self.atomic = atomic


class OutputVariableDef(VariableDef):
"""Defines the shape and other properties of the one output variable.
It is assume that the fitting network output variables for each
local atom. This class defines one output variable, including its
name, shape, reducibility and differentiability.
Parameters
----------
name
Name of the output variable. Notice that the xxxx_redu,
xxxx_derv_c, xxxx_derv_r are reserved names that should
not be used to define variables.
shape
The shape of the variable. e.g. energy should be [1],
dipole should be [3], polarizabilty should be [3,3].
reduciable
If the variable is reduced.
differentiable
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
are differentiable.
"""

def __init__(
self,
name: str,
shape: Union[List[int], Tuple[int]],
reduciable: bool = False,
differentiable: bool = False,
):
# fitting output must be atomic
super().__init__(name, shape, atomic=True)
self.reduciable = reduciable
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
raise ValueError("only reduciable variable are differentiable")


class FittingOutputDef:
"""Defines the shapes and other properties of the fitting network outputs.
It is assume that the fitting network output variables for each
local atom. This class defines all the outputs.
Parameters
----------
var_defs
List of output variable definitions.
"""

def __init__(
self,
var_defs: List[OutputVariableDef] = [],
):
self.var_defs = {vv.name: vv for vv in var_defs}

def __getitem__(
self,
key,
) -> OutputVariableDef:
return self.var_defs[key]

def get_data(self) -> Dict[str, OutputVariableDef]:
return self.var_defs

def keys(self):
return self.var_defs.keys()


class ModelOutputDef:
"""Defines the shapes and other properties of the model outputs.
The model reduce and differentiate fitting outputs if applicable.
If a variable is named by foo, then the reduced variable is called
foo_redu, the derivative w.r.t. coordinates is called foo_derv_r
and the derivative w.r.t. cell is called foo_derv_c.
Parameters
----------
fit_defs
Definition for the fitting net output
"""

def __init__(
self,
fit_defs: FittingOutputDef,
):
self.def_outp = fit_defs
self.def_redu = do_reduce(self.def_outp)
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
self.var_defs = {}
for ii in [
self.def_outp.get_data(),
self.def_redu,
self.def_derv_c,
self.def_derv_r,
]:
self.var_defs.update(ii)

def __getitem__(self, key) -> VariableDef:
return self.var_defs[key]

def get_data(self, key) -> Dict[str, VariableDef]:
return self.var_defs

Check warning on line 145 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L145

Added line #L145 was not covered by tests

def keys(self):
return self.var_defs.keys()

def keys_outp(self):
return self.def_outp.keys()

Check warning on line 151 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L151

Added line #L151 was not covered by tests

def keys_redu(self):
return self.def_redu.keys()

Check warning on line 154 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L154

Added line #L154 was not covered by tests

def keys_derv_r(self):
return self.def_derv_r.keys()

Check warning on line 157 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L157

Added line #L157 was not covered by tests

def keys_derv_c(self):
return self.def_derv_c.keys()

Check warning on line 160 in deepmd_utils/model_format/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd_utils/model_format/output_def.py#L160

Added line #L160 was not covered by tests


def get_reduce_name(name):
return name + "_redu"


def get_deriv_name(name):
return name + "_derv_r", name + "_derv_c"


def do_reduce(
def_outp,
):
def_redu = {}
for kk, vv in def_outp.get_data().items():
if vv.reduciable:
rk = get_reduce_name(kk)
def_redu[rk] = VariableDef(rk, vv.shape, atomic=False)
return def_redu


def do_derivative(
def_outp,
):
def_derv_r = {}
def_derv_c = {}
for kk, vv in def_outp.get_data().items():
if vv.differentiable:
rkr, rkc = get_deriv_name(kk)
def_derv_r[rkr] = VariableDef(rkr, [*vv.shape, 3], atomic=True)
def_derv_c[rkc] = VariableDef(rkc, [*vv.shape, 3, 3], atomic=False)
return def_derv_r, def_derv_c
75 changes: 75 additions & 0 deletions source/tests/test_output_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

from deepmd_utils.model_format import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)


class TestDef(unittest.TestCase):
def test_model_output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
OutputVariableDef("dos", [10], True, False),
OutputVariableDef("foo", [3], False, False),
]
# fitting definition
fd = FittingOutputDef(defs)
expected_keys = ["energy", "dos", "foo"]
self.assertEqual(
set(expected_keys),
set(fd.keys()),
)
# shape
self.assertEqual(fd["energy"].shape, [1])
self.assertEqual(fd["dos"].shape, [10])
self.assertEqual(fd["foo"].shape, [3])
# atomic
self.assertEqual(fd["energy"].atomic, True)
self.assertEqual(fd["dos"].atomic, True)
self.assertEqual(fd["foo"].atomic, True)
# reduce
self.assertEqual(fd["energy"].reduciable, True)
self.assertEqual(fd["dos"].reduciable, True)
self.assertEqual(fd["foo"].reduciable, False)
# derivative
self.assertEqual(fd["energy"].differentiable, True)
self.assertEqual(fd["dos"].differentiable, False)
self.assertEqual(fd["foo"].differentiable, False)
# model definition
md = ModelOutputDef(fd)
expected_keys = [
"energy",
"dos",
"foo",
"energy_redu",
"energy_derv_r",
"energy_derv_c",
"dos_redu",
]
self.assertEqual(
set(expected_keys),
set(md.keys()),
)
for kk in expected_keys:
self.assertEqual(md[kk].name, kk)
# shape
self.assertEqual(md["energy"].shape, [1])
self.assertEqual(md["dos"].shape, [10])
self.assertEqual(md["foo"].shape, [3])
self.assertEqual(md["energy_redu"].shape, [1])
self.assertEqual(md["energy_derv_r"].shape, [1, 3])
self.assertEqual(md["energy_derv_c"].shape, [1, 3, 3])
# atomic
self.assertEqual(md["energy"].atomic, True)
self.assertEqual(md["dos"].atomic, True)
self.assertEqual(md["foo"].atomic, True)
self.assertEqual(md["energy_redu"].atomic, False)
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, False)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
(OutputVariableDef("energy", [1], False, True),)

0 comments on commit 4796a86

Please sign in to comment.