diff --git a/deepmd_utils/model_format/__init__.py b/deepmd_utils/model_format/__init__.py index d814f56b1a..356eaaf4fa 100644 --- a/deepmd_utils/model_format/__init__.py +++ b/deepmd_utils/model_format/__init__.py @@ -15,6 +15,14 @@ save_dp_model, traverse_model_dict, ) +from .output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, + VariableDef, + fitting_check_output, + model_check_output, +) from .se_e2_a import ( DescrptSeA, ) @@ -31,4 +39,10 @@ "traverse_model_dict", "PRECISION_DICT", "DEFAULT_PRECISION", + "ModelOutputDef", + "FittingOutputDef", + "OutputVariableDef", + "VariableDef", + "model_check_output", + "fitting_check_output", ] diff --git a/deepmd_utils/model_format/output_def.py b/deepmd_utils/model_format/output_def.py new file mode 100644 index 0000000000..f4fcdce3ca --- /dev/null +++ b/deepmd_utils/model_format/output_def.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Tuple, + Union, +) + + +def check_var(var, var_def): + if var_def.atomic: + # var.shape == [nf, nloc, *var_def.shape] + if len(var.shape) != len(var_def.shape) + 2: + raise ValueError(f"{var.shape[2:]} length not matching def {var_def.shape}") + if list(var.shape[2:]) != var_def.shape: + raise ValueError(f"{var.shape[2:]} not matching def {var_def.shape}") + else: + # var.shape == [nf, *var_def.shape] + if len(var.shape) != len(var_def.shape) + 1: + raise ValueError(f"{var.shape[1:]} length not matching def {var_def.shape}") + if list(var.shape[1:]) != var_def.shape: + raise ValueError(f"{var.shape[1:]} not matching def {var_def.shape}") + + +def model_check_output(cls): + """Check if the output of the Model is consistent with the definition. + + Two methods are assumed to be provided by the Model: + 1. Model.output_def that gives the output definition. + 2. Model.forward that defines the forward path of the model. + + """ + + class wrapper(cls): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.md = cls.output_def(self) + + def forward( + self, + *args, + **kwargs, + ): + ret = cls.forward(self, *args, **kwargs) + for kk in self.md.keys_outp(): + dd = self.md[kk] + check_var(ret[kk], dd) + if dd.reduciable: + rk = get_reduce_name(kk) + check_var(ret[rk], self.md[rk]) + if dd.differentiable: + dnr, dnc = get_deriv_name(kk) + check_var(ret[dnr], self.md[dnr]) + check_var(ret[dnc], self.md[dnc]) + return ret + + return wrapper + + +def fitting_check_output(cls): + """Check if the output of the Fitting is consistent with the definition. + + Two methods are assumed to be provided by the Fitting: + 1. Fitting.output_def that gives the output definition. + 2. Fitting.forward defines the forward path of the fitting. + + """ + + class wrapper(cls): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.md = cls.output_def(self) + + def forward( + self, + *args, + **kwargs, + ): + ret = cls.forward(self, *args, **kwargs) + for kk in self.md.keys(): + dd = self.md[kk] + check_var(ret[kk], dd) + return ret + + return wrapper + + +class VariableDef: + """Defines the shape and other properties of a variable. + + Parameters + ---------- + name + Name of the output variable. Notice that the xxxx_redu, + xxxx_derv_c, xxxx_derv_r are reserved names that should + not be used to define variables. + shape + The shape of the variable. e.g. energy should be [1], + dipole should be [3], polarizabilty should be [3,3]. + atomic + If the variable is defined for each atom. + + """ + + def __init__( + self, + name: str, + shape: Union[List[int], Tuple[int]], + atomic: bool = True, + ): + self.name = name + self.shape = list(shape) + self.atomic = atomic + + +class OutputVariableDef(VariableDef): + """Defines the shape and other properties of the one output variable. + + It is assume that the fitting network output variables for each + local atom. This class defines one output variable, including its + name, shape, reducibility and differentiability. + + Parameters + ---------- + name + Name of the output variable. Notice that the xxxx_redu, + xxxx_derv_c, xxxx_derv_r are reserved names that should + not be used to define variables. + shape + The shape of the variable. e.g. energy should be [1], + dipole should be [3], polarizabilty should be [3,3]. + reduciable + If the variable is reduced. + differentiable + If the variable is differentiated with respect to coordinates + of atoms and cell tensor (pbc case). Only reduciable variable + are differentiable. + + """ + + def __init__( + self, + name: str, + shape: Union[List[int], Tuple[int]], + reduciable: bool = False, + differentiable: bool = False, + ): + # fitting output must be atomic + super().__init__(name, shape, atomic=True) + self.reduciable = reduciable + self.differentiable = differentiable + if not self.reduciable and self.differentiable: + raise ValueError("only reduciable variable are differentiable") + + +class FittingOutputDef: + """Defines the shapes and other properties of the fitting network outputs. + + It is assume that the fitting network output variables for each + local atom. This class defines all the outputs. + + Parameters + ---------- + var_defs + List of output variable definitions. + + """ + + def __init__( + self, + var_defs: List[OutputVariableDef] = [], + ): + self.var_defs = {vv.name: vv for vv in var_defs} + + def __getitem__( + self, + key, + ) -> OutputVariableDef: + return self.var_defs[key] + + def get_data(self) -> Dict[str, OutputVariableDef]: + return self.var_defs + + def keys(self): + return self.var_defs.keys() + + +class ModelOutputDef: + """Defines the shapes and other properties of the model outputs. + + The model reduce and differentiate fitting outputs if applicable. + If a variable is named by foo, then the reduced variable is called + foo_redu, the derivative w.r.t. coordinates is called foo_derv_r + and the derivative w.r.t. cell is called foo_derv_c. + + Parameters + ---------- + fit_defs + Definition for the fitting net output + + """ + + def __init__( + self, + fit_defs: FittingOutputDef, + ): + self.def_outp = fit_defs + self.def_redu = do_reduce(self.def_outp) + self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp) + self.var_defs = {} + for ii in [ + self.def_outp.get_data(), + self.def_redu, + self.def_derv_c, + self.def_derv_r, + ]: + self.var_defs.update(ii) + + def __getitem__(self, key) -> VariableDef: + return self.var_defs[key] + + def get_data(self, key) -> Dict[str, VariableDef]: + return self.var_defs + + def keys(self): + return self.var_defs.keys() + + def keys_outp(self): + return self.def_outp.keys() + + def keys_redu(self): + return self.def_redu.keys() + + def keys_derv_r(self): + return self.def_derv_r.keys() + + def keys_derv_c(self): + return self.def_derv_c.keys() + + +def get_reduce_name(name): + return name + "_redu" + + +def get_deriv_name(name): + return name + "_derv_r", name + "_derv_c" + + +def do_reduce( + def_outp, +): + def_redu = {} + for kk, vv in def_outp.get_data().items(): + if vv.reduciable: + rk = get_reduce_name(kk) + def_redu[rk] = VariableDef(rk, vv.shape, atomic=False) + return def_redu + + +def do_derivative( + def_outp, +): + def_derv_r = {} + def_derv_c = {} + for kk, vv in def_outp.get_data().items(): + if vv.differentiable: + rkr, rkc = get_deriv_name(kk) + def_derv_r[rkr] = VariableDef(rkr, [*vv.shape, 3], atomic=True) + def_derv_c[rkc] = VariableDef(rkc, [*vv.shape, 3, 3], atomic=False) + return def_derv_r, def_derv_c diff --git a/source/tests/test_output_def.py b/source/tests/test_output_def.py new file mode 100644 index 0000000000..7f5404ee31 --- /dev/null +++ b/source/tests/test_output_def.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd_utils.model_format import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, + fitting_check_output, + model_check_output, +) + + +class TestDef(unittest.TestCase): + def test_model_output_def(self): + defs = [ + OutputVariableDef("energy", [1], True, True), + OutputVariableDef("dos", [10], True, False), + OutputVariableDef("foo", [3], False, False), + ] + # fitting definition + fd = FittingOutputDef(defs) + expected_keys = ["energy", "dos", "foo"] + self.assertEqual( + set(expected_keys), + set(fd.keys()), + ) + # shape + self.assertEqual(fd["energy"].shape, [1]) + self.assertEqual(fd["dos"].shape, [10]) + self.assertEqual(fd["foo"].shape, [3]) + # atomic + self.assertEqual(fd["energy"].atomic, True) + self.assertEqual(fd["dos"].atomic, True) + self.assertEqual(fd["foo"].atomic, True) + # reduce + self.assertEqual(fd["energy"].reduciable, True) + self.assertEqual(fd["dos"].reduciable, True) + self.assertEqual(fd["foo"].reduciable, False) + # derivative + self.assertEqual(fd["energy"].differentiable, True) + self.assertEqual(fd["dos"].differentiable, False) + self.assertEqual(fd["foo"].differentiable, False) + # model definition + md = ModelOutputDef(fd) + expected_keys = [ + "energy", + "dos", + "foo", + "energy_redu", + "energy_derv_r", + "energy_derv_c", + "dos_redu", + ] + self.assertEqual( + set(expected_keys), + set(md.keys()), + ) + for kk in expected_keys: + self.assertEqual(md[kk].name, kk) + # reduce + self.assertEqual(md["energy"].reduciable, True) + self.assertEqual(md["dos"].reduciable, True) + self.assertEqual(md["foo"].reduciable, False) + # derivative + self.assertEqual(md["energy"].differentiable, True) + self.assertEqual(md["dos"].differentiable, False) + self.assertEqual(md["foo"].differentiable, False) + # shape + self.assertEqual(md["energy"].shape, [1]) + self.assertEqual(md["dos"].shape, [10]) + self.assertEqual(md["foo"].shape, [3]) + self.assertEqual(md["energy_redu"].shape, [1]) + self.assertEqual(md["energy_derv_r"].shape, [1, 3]) + self.assertEqual(md["energy_derv_c"].shape, [1, 3, 3]) + # atomic + self.assertEqual(md["energy"].atomic, True) + self.assertEqual(md["dos"].atomic, True) + self.assertEqual(md["foo"].atomic, True) + self.assertEqual(md["energy_redu"].atomic, False) + self.assertEqual(md["energy_derv_r"].atomic, True) + self.assertEqual(md["energy_derv_c"].atomic, False) + + def test_raise_no_redu_deriv(self): + with self.assertRaises(ValueError) as context: + (OutputVariableDef("energy", [1], False, True),) + + def test_model_decorator(self): + nf = 2 + nloc = 3 + + @model_check_output + class Foo: + def output_def(self): + defs = [ + OutputVariableDef("energy", [1], True, True), + ] + return ModelOutputDef(FittingOutputDef(defs)) + + def forward(self): + return { + "energy": np.zeros([nf, nloc, 1]), + "energy_redu": np.zeros([nf, 1]), + "energy_derv_r": np.zeros([nf, nloc, 1, 3]), + "energy_derv_c": np.zeros([nf, 1, 3, 3]), + } + + ff = Foo() + ff.forward() + + def test_model_decorator_keyerror(self): + nf = 2 + nloc = 3 + + @model_check_output + class Foo: + def output_def(self): + defs = [ + OutputVariableDef("energy", [1], True, True), + ] + return ModelOutputDef(FittingOutputDef(defs)) + + def forward(self): + return { + "energy": np.zeros([nf, nloc, 1]), + "energy_redu": np.zeros([nf, 1]), + "energy_derv_c": np.zeros([nf, 1, 3, 3]), + } + + ff = Foo() + with self.assertRaises(KeyError) as context: + ff.forward() + self.assertIn("energy_derv_r", context.exception) + + def test_model_decorator_shapeerror(self): + nf = 2 + nloc = 3 + + @model_check_output + class Foo: + def __init__( + self, + shape_rd=[nf, 1], + shape_dr=[nf, nloc, 1, 3], + ): + self.shape_rd, self.shape_dr = shape_rd, shape_dr + + def output_def(self): + defs = [ + OutputVariableDef("energy", [1], True, True), + ] + return ModelOutputDef(FittingOutputDef(defs)) + + def forward(self): + return { + "energy": np.zeros([nf, nloc, 1]), + "energy_redu": np.zeros(self.shape_rd), + "energy_derv_r": np.zeros(self.shape_dr), + "energy_derv_c": np.zeros([nf, 1, 3, 3]), + } + + ff = Foo() + ff.forward() + # shape of reduced energy + with self.assertRaises(ValueError) as context: + ff = Foo(shape_rd=[nf, nloc, 1]) + ff.forward() + self.assertIn("not matching", context.exception) + with self.assertRaises(ValueError) as context: + ff = Foo(shape_rd=[nf, 2]) + ff.forward() + self.assertIn("not matching", context.exception) + # shape of dr + with self.assertRaises(ValueError) as context: + ff = Foo(shape_dr=[nf, nloc, 1]) + ff.forward() + self.assertIn("not matching", context.exception) + with self.assertRaises(ValueError) as context: + ff = Foo(shape_dr=[nf, nloc, 1, 3, 3]) + ff.forward() + self.assertIn("not matching", context.exception) + with self.assertRaises(ValueError) as context: + ff = Foo(shape_dr=[nf, nloc, 1, 4]) + ff.forward() + self.assertIn("not matching", context.exception) + + def test_fitting_decorator(self): + nf = 2 + nloc = 3 + + @fitting_check_output + class Foo: + def output_def(self): + defs = [ + OutputVariableDef("energy", [1], True, True), + ] + return FittingOutputDef(defs) + + def forward(self): + return { + "energy": np.zeros([nf, nloc, 1]), + } + + ff = Foo() + ff.forward() + + def test_fitting_decorator_shapeerror(self): + nf = 2 + nloc = 3 + + @fitting_check_output + class Foo: + def __init__( + self, + shape=[nf, nloc, 1], + ): + self.shape = shape + + def output_def(self): + defs = [ + OutputVariableDef("energy", [1], True, True), + ] + return FittingOutputDef(defs) + + def forward(self): + return { + "energy": np.zeros(self.shape), + } + + ff = Foo() + ff.forward() + # shape of reduced energy + with self.assertRaises(ValueError) as context: + ff = Foo(shape=[nf, 1]) + ff.forward() + self.assertIn("not matching", context.exception) + with self.assertRaises(ValueError) as context: + ff = Foo(shape=[nf, nloc, 2]) + ff.forward() + self.assertIn("not matching", context.exception)