diff --git a/deepmd/dpmodel/output_def.py b/deepmd/dpmodel/output_def.py index 6cd83fcf28..9e3570d2ff 100644 --- a/deepmd/dpmodel/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import functools +from enum import ( + IntEnum, +) from typing import ( Dict, List, @@ -107,6 +110,38 @@ def __call__( return wrapper +class OutputVariableOperation(IntEnum): + """Defines the operation of the output variable.""" + + _NONE = 0 + """No operation.""" + REDU = 1 + """Reduce the output variable.""" + DERV_R = 2 + """Derivative w.r.t. coordinates.""" + DERV_C = 4 + """Derivative w.r.t. cell.""" + _SEC_DERV_R = 8 + """Second derivative w.r.t. coordinates.""" + + +class OutputVariableCategory(IntEnum): + """Defines the category of the output variable.""" + + OUT = OutputVariableOperation._NONE + """Output variable. (e.g. atom energy)""" + REDU = OutputVariableOperation.REDU + """Reduced output variable. (e.g. system energy)""" + DERV_R = OutputVariableOperation.DERV_R + """Negative derivative w.r.t. coordinates. (e.g. force)""" + DERV_C = OutputVariableOperation.DERV_C + """Atomic component of the virial, see PRB 104, 224202 (2021) """ + DERV_C_REDU = OutputVariableOperation.DERV_C | OutputVariableOperation.REDU + """Virial, the transposed negative gradient with cell tensor times cell tensor, see eq 40 JCP 159, 054801 (2023). """ + DERV_R_DERV_R = OutputVariableOperation.DERV_R | OutputVariableOperation._SEC_DERV_R + """Hession matrix, the second derivative w.r.t. coordinates.""" + + class OutputVariableDef: """Defines the shape and other properties of the one output variable. @@ -129,7 +164,8 @@ class OutputVariableDef: If the variable is differentiated with respect to coordinates of atoms and cell tensor (pbc case). Only reduciable variable are differentiable. - + category : int + The category of the output variable. """ def __init__( @@ -139,6 +175,7 @@ def __init__( reduciable: bool = False, differentiable: bool = False, atomic: bool = True, + category: int = OutputVariableCategory.OUT.value, ): self.name = name self.shape = list(shape) @@ -149,6 +186,7 @@ def __init__( raise ValueError("only reduciable variable are differentiable") if self.reduciable and not self.atomic: raise ValueError("only reduciable variable should be atomic") + self.category = category class FittingOutputDef: @@ -255,6 +293,60 @@ def get_deriv_name(name: str) -> Tuple[str, str]: return name + "_derv_r", name + "_derv_c" +def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int: + """Apply a operation to the category of a variable definition. + + Parameters + ---------- + var_def : OutputVariableDef + The variable definition. + op : OutputVariableOperation + The operation to be applied. + + Returns + ------- + int + The new category of the variable definition. + + Raises + ------ + ValueError + If the operation has been applied to the variable definition, + and exceed the maximum limitation. + """ + if op == OutputVariableOperation.REDU or op == OutputVariableOperation.DERV_C: + if check_operation_applied(var_def, op): + raise ValueError(f"operation {op} has been applied") + elif op == OutputVariableOperation.DERV_R: + if check_operation_applied(var_def, OutputVariableOperation.DERV_R): + op = OutputVariableOperation._SEC_DERV_R + if check_operation_applied(var_def, OutputVariableOperation._SEC_DERV_R): + raise ValueError(f"operation {op} has been applied twice") + else: + raise ValueError(f"operation {op} not supported") + return var_def.category | op.value + + +def check_operation_applied( + var_def: OutputVariableDef, op: OutputVariableOperation +) -> bool: + """Check if a operation has been applied to a variable definition. + + Parameters + ---------- + var_def : OutputVariableDef + The variable definition. + op : OutputVariableOperation + The operation to be checked. + + Returns + ------- + bool + True if the operation has been applied, False otherwise. + """ + return var_def.category & op.value == op.value + + def do_reduce( def_outp_data: Dict[str, OutputVariableDef], ) -> Dict[str, OutputVariableDef]: @@ -263,7 +355,12 @@ def do_reduce( if vv.reduciable: rk = get_reduce_name(kk) def_redu[rk] = OutputVariableDef( - rk, vv.shape, reduciable=False, differentiable=False, atomic=False + rk, + vv.shape, + reduciable=False, + differentiable=False, + atomic=False, + category=apply_operation(vv, OutputVariableOperation.REDU), ) return def_redu @@ -282,6 +379,7 @@ def do_derivative( reduciable=False, differentiable=False, atomic=True, + category=apply_operation(vv, OutputVariableOperation.DERV_R), ) def_derv_c[rkc] = OutputVariableDef( rkc, @@ -289,5 +387,6 @@ def do_derivative( reduciable=True, differentiable=False, atomic=True, + category=apply_operation(vv, OutputVariableOperation.DERV_C), ) return def_derv_r, def_derv_c diff --git a/source/tests/common/dpmodel/test_output_def.py b/source/tests/common/dpmodel/test_output_def.py index aaabdc0ba6..3f7544f597 100644 --- a/source/tests/common/dpmodel/test_output_def.py +++ b/source/tests/common/dpmodel/test_output_def.py @@ -15,6 +15,9 @@ model_check_output, ) from deepmd.dpmodel.output_def import ( + OutputVariableCategory, + OutputVariableOperation, + apply_operation, check_var, ) @@ -103,6 +106,101 @@ def test_model_output_def(self): self.assertEqual(md["energy_derv_r"].atomic, True) self.assertEqual(md["energy_derv_c"].atomic, True) self.assertEqual(md["energy_derv_c_redu"].atomic, False) + # category + self.assertEqual(md["energy"].category, OutputVariableCategory.OUT) + self.assertEqual(md["dos"].category, OutputVariableCategory.OUT) + self.assertEqual(md["foo"].category, OutputVariableCategory.OUT) + self.assertEqual(md["energy_redu"].category, OutputVariableCategory.REDU) + self.assertEqual(md["energy_derv_r"].category, OutputVariableCategory.DERV_R) + self.assertEqual(md["energy_derv_c"].category, OutputVariableCategory.DERV_C) + self.assertEqual( + md["energy_derv_c_redu"].category, OutputVariableCategory.DERV_C_REDU + ) + # flag + self.assertEqual(md["energy"].category & OutputVariableOperation.REDU, 0) + self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_R, 0) + self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_C, 0) + self.assertEqual(md["dos"].category & OutputVariableOperation.REDU, 0) + self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_R, 0) + self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_C, 0) + self.assertEqual(md["foo"].category & OutputVariableOperation.REDU, 0) + self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_R, 0) + self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_C, 0) + self.assertEqual( + md["energy_redu"].category & OutputVariableOperation.REDU, + OutputVariableOperation.REDU, + ) + self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_R, 0) + self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_C, 0) + self.assertEqual(md["energy_derv_r"].category & OutputVariableOperation.REDU, 0) + self.assertEqual( + md["energy_derv_r"].category & OutputVariableOperation.DERV_R, + OutputVariableOperation.DERV_R, + ) + self.assertEqual( + md["energy_derv_r"].category & OutputVariableOperation.DERV_C, 0 + ) + self.assertEqual(md["energy_derv_c"].category & OutputVariableOperation.REDU, 0) + self.assertEqual( + md["energy_derv_c"].category & OutputVariableOperation.DERV_R, 0 + ) + self.assertEqual( + md["energy_derv_c"].category & OutputVariableOperation.DERV_C, + OutputVariableOperation.DERV_C, + ) + self.assertEqual( + md["energy_derv_c_redu"].category & OutputVariableOperation.REDU, + OutputVariableOperation.REDU, + ) + self.assertEqual( + md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_R, 0 + ) + self.assertEqual( + md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_C, + OutputVariableOperation.DERV_C, + ) + + # apply_operation + self.assertEqual( + apply_operation(md["energy"], OutputVariableOperation.REDU), + md["energy_redu"].category, + ) + self.assertEqual( + apply_operation(md["energy"], OutputVariableOperation.DERV_R), + md["energy_derv_r"].category, + ) + self.assertEqual( + apply_operation(md["energy"], OutputVariableOperation.DERV_C), + md["energy_derv_c"].category, + ) + self.assertEqual( + apply_operation(md["energy_derv_c"], OutputVariableOperation.REDU), + md["energy_derv_c_redu"].category, + ) + # raise ValueError + with self.assertRaises(ValueError): + apply_operation(md["energy_redu"], OutputVariableOperation.REDU) + with self.assertRaises(ValueError): + apply_operation(md["energy_derv_c"], OutputVariableOperation.DERV_C) + with self.assertRaises(ValueError): + apply_operation(md["energy_derv_c_redu"], OutputVariableOperation.REDU) + # hession + hession_cat = apply_operation( + md["energy_derv_r"], OutputVariableOperation.DERV_R + ) + self.assertEqual( + hession_cat & OutputVariableOperation.DERV_R, OutputVariableOperation.DERV_R + ) + self.assertEqual( + hession_cat & OutputVariableOperation._SEC_DERV_R, + OutputVariableOperation._SEC_DERV_R, + ) + self.assertEqual(hession_cat, OutputVariableCategory.DERV_R_DERV_R) + hession_vardef = OutputVariableDef( + "energy_derv_r_derv_r", [1], False, False, category=hession_cat + ) + with self.assertRaises(ValueError): + apply_operation(hession_vardef, OutputVariableOperation.DERV_R) def test_raise_no_redu_deriv(self): with self.assertRaises(ValueError) as context: