Skip to content

Commit

Permalink
fix bug of output def: the reduced virial is not defined. (#3219)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 2, 2024
1 parent 701b913 commit 677d936
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
21 changes: 15 additions & 6 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def __init__(
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
raise ValueError("only reduciable variable are differentiable")
if self.reduciable and not self.atomic:
raise ValueError("only reduciable variable should be atomic")


class FittingOutputDef:
Expand Down Expand Up @@ -201,14 +203,16 @@ def __init__(
fit_defs: FittingOutputDef,
):
self.def_outp = fit_defs
self.def_redu = do_reduce(self.def_outp)
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
self.def_redu = do_reduce(self.def_outp.get_data())
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp.get_data())
self.def_derv_c_redu = do_reduce(self.def_derv_c)
self.var_defs: Dict[str, OutputVariableDef] = {}
for ii in [
self.def_outp.get_data(),
self.def_redu,
self.def_derv_c,
self.def_derv_r,
self.def_derv_c_redu,
]:
self.var_defs.update(ii)

Expand Down Expand Up @@ -239,6 +243,9 @@ def keys_derv_r(self):
def keys_derv_c(self):
return self.def_derv_c.keys()

def keys_derv_c_redu(self):
return self.def_derv_c_redu.keys()


def get_reduce_name(name: str) -> str:
return name + "_redu"
Expand All @@ -249,10 +256,10 @@ def get_deriv_name(name: str) -> Tuple[str, str]:


def do_reduce(
def_outp: FittingOutputDef,
def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
def_redu: Dict[str, OutputVariableDef] = {}
for kk, vv in def_outp.get_data().items():
for kk, vv in def_outp_data.items():
if vv.reduciable:
rk = get_reduce_name(kk)
def_redu[rk] = OutputVariableDef(
Expand All @@ -262,23 +269,25 @@ def do_reduce(


def do_derivative(
def_outp: FittingOutputDef,
def_outp_data: Dict[str, OutputVariableDef],
) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]:
def_derv_r: Dict[str, OutputVariableDef] = {}
def_derv_c: Dict[str, OutputVariableDef] = {}
for kk, vv in def_outp.get_data().items():
for kk, vv in def_outp_data.items():
if vv.differentiable:
rkr, rkc = get_deriv_name(kk)
def_derv_r[rkr] = OutputVariableDef(
rkr,
vv.shape + [3], # noqa: RUF005
reduciable=False,
differentiable=False,
atomic=True,
)
def_derv_c[rkc] = OutputVariableDef(
rkc,
vv.shape + [3, 3], # noqa: RUF005
reduciable=True,
differentiable=False,
atomic=True,
)
return def_derv_r, def_derv_c
7 changes: 7 additions & 0 deletions source/tests/common/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_model_output_def(self):
"energy_redu",
"energy_derv_r",
"energy_derv_c",
"energy_derv_c_redu",
"dos_redu",
]
self.assertEqual(
Expand All @@ -93,18 +94,24 @@ def test_model_output_def(self):
self.assertEqual(md["energy_redu"].shape, [1])
self.assertEqual(md["energy_derv_r"].shape, [1, 3])
self.assertEqual(md["energy_derv_c"].shape, [1, 3, 3])
self.assertEqual(md["energy_derv_c_redu"].shape, [1, 3, 3])
# atomic
self.assertEqual(md["energy"].atomic, True)
self.assertEqual(md["dos"].atomic, True)
self.assertEqual(md["foo"].atomic, True)
self.assertEqual(md["energy_redu"].atomic, False)
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, True)
self.assertEqual(md["energy_derv_c_redu"].atomic, False)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
(OutputVariableDef("energy", [1], False, True),)

def test_raise_redu_not_atomic(self):
with self.assertRaises(ValueError) as context:
(OutputVariableDef("energy", [1], True, False, atomic=False),)

def test_model_decorator(self):
nf = 2
nloc = 3
Expand Down

0 comments on commit 677d936

Please sign in to comment.