From 549392119cd1704cfdb02c26b8cb8b9afb695935 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Thu, 18 Jan 2024 09:26:10 +0800 Subject: [PATCH] fix: some issue of the output def (#3152) - strict type hint - allow the last dim to be variable (by setting the dim to -1) - remove variable def, which is not very useful. - _derv_c should be defined for each atom --------- Co-authored-by: Han Wang Co-authored-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd_utils/model_format/__init__.py | 6 +- deepmd_utils/model_format/output_def.py | 115 ++++++++++++------------ source/tests/test_output_def.py | 71 +++++++++++++-- 3 files changed, 128 insertions(+), 64 deletions(-) diff --git a/deepmd_utils/model_format/__init__.py b/deepmd_utils/model_format/__init__.py index 72dd7b59ee..253bca3507 100644 --- a/deepmd_utils/model_format/__init__.py +++ b/deepmd_utils/model_format/__init__.py @@ -24,8 +24,9 @@ FittingOutputDef, ModelOutputDef, OutputVariableDef, - VariableDef, fitting_check_output, + get_deriv_name, + get_reduce_name, model_check_output, ) from .se_e2_a import ( @@ -52,7 +53,8 @@ "ModelOutputDef", "FittingOutputDef", "OutputVariableDef", - "VariableDef", "model_check_output", "fitting_check_output", + "get_reduce_name", + "get_deriv_name", ] diff --git a/deepmd_utils/model_format/output_def.py b/deepmd_utils/model_format/output_def.py index 7feb24a145..268dc21ea6 100644 --- a/deepmd_utils/model_format/output_def.py +++ b/deepmd_utils/model_format/output_def.py @@ -3,23 +3,34 @@ Dict, List, Tuple, - Union, ) +def check_shape( + shape: List[int], + def_shape: List[int], +): + """Check if the shape satisfies the defined shape.""" + assert len(shape) == len(def_shape) + if def_shape[-1] == -1: + if list(shape[:-1]) != def_shape[:-1]: + raise ValueError(f"{shape[:-1]} shape not matching def {def_shape[:-1]}") + else: + if list(shape) != def_shape: + raise ValueError(f"{shape} shape not matching def {def_shape}") + + def check_var(var, var_def): if var_def.atomic: # var.shape == [nf, nloc, *var_def.shape] if len(var.shape) != len(var_def.shape) + 2: raise ValueError(f"{var.shape[2:]} length not matching def {var_def.shape}") - if list(var.shape[2:]) != var_def.shape: - raise ValueError(f"{var.shape[2:]} not matching def {var_def.shape}") + check_shape(list(var.shape[2:]), var_def.shape) else: # var.shape == [nf, *var_def.shape] if len(var.shape) != len(var_def.shape) + 1: raise ValueError(f"{var.shape[1:]} length not matching def {var_def.shape}") - if list(var.shape[1:]) != var_def.shape: - raise ValueError(f"{var.shape[1:]} not matching def {var_def.shape}") + check_shape(list(var.shape[1:]), var_def.shape) def model_check_output(cls): @@ -38,7 +49,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - self.md = cls.output_def(self) + self.md = self.output_def() def __call__( self, @@ -77,7 +88,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - self.md = cls.output_def(self) + self.md = self.output_def() def __call__( self, @@ -93,35 +104,7 @@ def __call__( return wrapper -class VariableDef: - """Defines the shape and other properties of a variable. - - Parameters - ---------- - name - Name of the output variable. Notice that the xxxx_redu, - xxxx_derv_c, xxxx_derv_r are reserved names that should - not be used to define variables. - shape - The shape of the variable. e.g. energy should be [1], - dipole should be [3], polarizabilty should be [3,3]. - atomic - If the variable is defined for each atom. - - """ - - def __init__( - self, - name: str, - shape: Union[List[int], Tuple[int]], - atomic: bool = True, - ): - self.name = name - self.shape = list(shape) - self.atomic = atomic - - -class OutputVariableDef(VariableDef): +class OutputVariableDef: """Defines the shape and other properties of the one output variable. It is assume that the fitting network output variables for each @@ -149,12 +132,14 @@ class OutputVariableDef(VariableDef): def __init__( self, name: str, - shape: Union[List[int], Tuple[int]], + shape: List[int], reduciable: bool = False, differentiable: bool = False, + atomic: bool = True, ): - # fitting output must be atomic - super().__init__(name, shape, atomic=True) + self.name = name + self.shape = list(shape) + self.atomic = atomic self.reduciable = reduciable self.differentiable = differentiable if not self.reduciable and self.differentiable: @@ -176,13 +161,13 @@ class FittingOutputDef: def __init__( self, - var_defs: List[OutputVariableDef] = [], + var_defs: List[OutputVariableDef], ): self.var_defs = {vv.name: vv for vv in var_defs} def __getitem__( self, - key, + key: str, ) -> OutputVariableDef: return self.var_defs[key] @@ -215,7 +200,7 @@ def __init__( self.def_outp = fit_defs self.def_redu = do_reduce(self.def_outp) self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp) - self.var_defs = {} + self.var_defs: Dict[str, OutputVariableDef] = {} for ii in [ self.def_outp.get_data(), self.def_redu, @@ -224,10 +209,16 @@ def __init__( ]: self.var_defs.update(ii) - def __getitem__(self, key) -> VariableDef: + def __getitem__( + self, + key: str, + ) -> OutputVariableDef: return self.var_defs[key] - def get_data(self, key) -> Dict[str, VariableDef]: + def get_data( + self, + key: str, + ) -> Dict[str, OutputVariableDef]: return self.var_defs def keys(self): @@ -246,33 +237,45 @@ def keys_derv_c(self): return self.def_derv_c.keys() -def get_reduce_name(name): +def get_reduce_name(name: str) -> str: return name + "_redu" -def get_deriv_name(name): +def get_deriv_name(name: str) -> Tuple[str, str]: return name + "_derv_r", name + "_derv_c" def do_reduce( - def_outp, -): - def_redu = {} + def_outp: FittingOutputDef, +) -> Dict[str, OutputVariableDef]: + def_redu: Dict[str, OutputVariableDef] = {} for kk, vv in def_outp.get_data().items(): if vv.reduciable: rk = get_reduce_name(kk) - def_redu[rk] = VariableDef(rk, vv.shape, atomic=False) + def_redu[rk] = OutputVariableDef( + rk, vv.shape, reduciable=False, differentiable=False, atomic=False + ) return def_redu def do_derivative( - def_outp, -): - def_derv_r = {} - def_derv_c = {} + def_outp: FittingOutputDef, +) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]: + def_derv_r: Dict[str, OutputVariableDef] = {} + def_derv_c: Dict[str, OutputVariableDef] = {} for kk, vv in def_outp.get_data().items(): if vv.differentiable: rkr, rkc = get_deriv_name(kk) - def_derv_r[rkr] = VariableDef(rkr, [*vv.shape, 3], atomic=True) - def_derv_c[rkc] = VariableDef(rkc, [*vv.shape, 3, 3], atomic=False) + def_derv_r[rkr] = OutputVariableDef( + rkr, + vv.shape + [3], # noqa: RUF005 + reduciable=False, + differentiable=False, + ) + def_derv_c[rkc] = OutputVariableDef( + rkc, + vv.shape + [3, 3], # noqa: RUF005 + reduciable=True, + differentiable=False, + ) return def_derv_r, def_derv_c diff --git a/source/tests/test_output_def.py b/source/tests/test_output_def.py index e0c56784da..82d1b13a80 100644 --- a/source/tests/test_output_def.py +++ b/source/tests/test_output_def.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +from typing import ( + List, +) import numpy as np @@ -11,6 +14,21 @@ fitting_check_output, model_check_output, ) +from deepmd_utils.model_format.output_def import ( + check_var, +) + + +class VariableDef: + def __init__( + self, + name: str, + shape: List[int], + atomic: bool = True, + ): + self.name = name + self.shape = list(shape) + self.atomic = atomic class TestDef(unittest.TestCase): @@ -81,7 +99,7 @@ def test_model_output_def(self): self.assertEqual(md["foo"].atomic, True) self.assertEqual(md["energy_redu"].atomic, False) self.assertEqual(md["energy_derv_r"].atomic, True) - self.assertEqual(md["energy_derv_c"].atomic, False) + self.assertEqual(md["energy_derv_c"].atomic, True) def test_raise_no_redu_deriv(self): with self.assertRaises(ValueError) as context: @@ -90,6 +108,7 @@ def test_raise_no_redu_deriv(self): def test_model_decorator(self): nf = 2 nloc = 3 + nall = 4 @model_check_output class Foo(NativeOP): @@ -103,8 +122,8 @@ def call(self): return { "energy": np.zeros([nf, nloc, 1]), "energy_redu": np.zeros([nf, 1]), - "energy_derv_r": np.zeros([nf, nloc, 1, 3]), - "energy_derv_c": np.zeros([nf, 1, 3, 3]), + "energy_derv_r": np.zeros([nf, nall, 1, 3]), + "energy_derv_c": np.zeros([nf, nall, 1, 3, 3]), } ff = Foo() @@ -113,6 +132,7 @@ def call(self): def test_model_decorator_keyerror(self): nf = 2 nloc = 3 + nall = 4 @model_check_output class Foo(NativeOP): @@ -129,7 +149,7 @@ def call(self): return { "energy": np.zeros([nf, nloc, 1]), "energy_redu": np.zeros([nf, 1]), - "energy_derv_c": np.zeros([nf, 1, 3, 3]), + "energy_derv_c": np.zeros([nf, nall, 1, 3, 3]), } ff = Foo() @@ -140,13 +160,14 @@ def call(self): def test_model_decorator_shapeerror(self): nf = 2 nloc = 3 + nall = 4 @model_check_output class Foo(NativeOP): def __init__( self, shape_rd=[nf, 1], - shape_dr=[nf, nloc, 1, 3], + shape_dr=[nf, nall, 1, 3], ): self.shape_rd, self.shape_dr = shape_rd, shape_dr @@ -161,7 +182,7 @@ def call(self): "energy": np.zeros([nf, nloc, 1]), "energy_redu": np.zeros(self.shape_rd), "energy_derv_r": np.zeros(self.shape_dr), - "energy_derv_c": np.zeros([nf, 1, 3, 3]), + "energy_derv_c": np.zeros([nf, nall, 1, 3, 3]), } ff = Foo() @@ -192,6 +213,7 @@ def call(self): def test_fitting_decorator(self): nf = 2 nloc = 3 + nall = 4 @fitting_check_output class Foo(NativeOP): @@ -243,3 +265,40 @@ def call(self): ff = Foo(shape=[nf, nloc, 2]) ff() self.assertIn("not matching", context.exception) + + def test_check_var(self): + var_def = VariableDef("foo", [2, 3], atomic=True) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4, 5, 6]), var_def) + self.assertIn("length not matching", context.exception) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4, 5]), var_def) + self.assertIn("shape not matching", context.exception) + check_var(np.zeros([2, 3, 2, 3]), var_def) + + var_def = VariableDef("foo", [2, 3], atomic=False) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4, 5]), var_def) + self.assertIn("length not matching", context.exception) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4]), var_def) + self.assertIn("shape not matching", context.exception) + check_var(np.zeros([2, 2, 3]), var_def) + + var_def = VariableDef("foo", [2, -1], atomic=True) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4, 5, 6]), var_def) + self.assertIn("length not matching", context.exception) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4, 5]), var_def) + self.assertIn("shape not matching", context.exception) + check_var(np.zeros([2, 3, 2, 8]), var_def) + + var_def = VariableDef("foo", [2, -1], atomic=False) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4, 5]), var_def) + self.assertIn("length not matching", context.exception) + with self.assertRaises(ValueError) as context: + check_var(np.zeros([2, 3, 4]), var_def) + self.assertIn("shape not matching", context.exception) + check_var(np.zeros([2, 2, 8]), var_def)