Skip to content

Commit

Permalink
Update outdef for mask_mag
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 5, 2024
1 parent 9772fa4 commit f27ae51
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 11 deletions.
27 changes: 27 additions & 0 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def __init__(
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp.get_data())
self.def_hess_r, _ = do_derivative(self.def_derv_r)
self.def_derv_c_redu = do_reduce(self.def_derv_c)
self.def_mask = do_mask(self.def_outp.get_data())
self.var_defs: Dict[str, OutputVariableDef] = {}
for ii in [
self.def_outp.get_data(),
Expand All @@ -289,6 +290,7 @@ def __init__(
self.def_derv_r,
self.def_derv_c_redu,
self.def_hess_r,
self.def_mask,
]:
self.var_defs.update(ii)

Expand Down Expand Up @@ -415,6 +417,31 @@ def do_reduce(
return def_redu


def do_mask(
def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
def_mask: Dict[str, OutputVariableDef] = {}
# for deep eval when has atomic mask
def_mask["mask"] = OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
for kk, vv in def_outp_data.items():
if vv.magnetic:
# for deep eval when has atomic mask for magnetic atoms
def_mask["mask_mag"] = OutputVariableDef(
name="mask_mag",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
return def_mask


def do_derivative(
def_outp_data: Dict[str, OutputVariableDef],
) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]:
Expand Down
16 changes: 11 additions & 5 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,14 @@ def test_ener(
if dp.has_spin:
force_m = ret[5]
force_m = force_m.reshape([numb_test, -1])
avm = ret[6]
avm = avm.reshape([numb_test, -1])
mask_mag = ret[6]
mask_mag = mask_mag.reshape([numb_test, -1])
else:
if dp.has_spin:
force_m = ret[3]
force_m = force_m.reshape([numb_test, -1])
mask_mag = ret[4]
mask_mag = mask_mag.reshape([numb_test, -1])
out_put_spin = dp.get_ntypes_spin() != 0 or dp.has_spin
if out_put_spin:
if dp.get_ntypes_spin() != 0: # old tf support for spin
Expand Down Expand Up @@ -391,8 +393,12 @@ def test_ener(
else: # pt support for spin
force_r = force
test_force_r = test_data["force"][:numb_test]
force_m = force_m
test_force_m = test_data["force_mag"][:numb_test]
force_m = force_m.reshape(-1, 3)[mask_mag.reshape(-1)].reshape(nframes, -1)
test_force_m = (
test_data["force_mag"][:numb_test]
.reshape(-1, 3)[mask_mag.reshape(-1)]
.reshape(nframes, -1)
)

diff_e = energy - test_data["energy"][:numb_test].reshape([-1, 1])
mae_e = mae(diff_e)
Expand Down Expand Up @@ -431,7 +437,7 @@ def test_ener(
log.info(f"Force spin MAE : {mae_fm:e} eV/uB")
log.info(f"Force spin RMSE : {rmse_fm:e} eV/uB")

if data.pbc:
if data.pbc and not out_put_spin:
log.info(f"Virial MAE : {mae_v:e} eV")
log.info(f"Virial RMSE : {rmse_v:e} eV")
log.info(f"Virial MAE/Natoms : {mae_va:e} eV")
Expand Down
2 changes: 2 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class DeepEvalBackend(ABC):
"dipole_derv_c_redu": "virial",
"dos": "atom_dos",
"dos_redu": "dos",
"mask_mag": "mask_mag",
"mask": "mask",
}

@abstractmethod
Expand Down
8 changes: 2 additions & 6 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,8 @@ def eval(
)
if getattr(self.deep_eval, "has_spin_pt", False):
force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3)
result = result + tuple(force_mag)
if atomic:
atomic_virial_mag = results["energy_derv_c_mag"].reshape(
nframes, natoms, 9
)
result = result + tuple(atomic_virial_mag)
mask_mag = results["mask_mag"].reshape(nframes, natoms, 1)
result = (*list(result), force_mag, mask_mag)
return result


Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def _get_request_defs(self, atomic: bool) -> List[OutputVariableDef]:
for x in self.output_def.var_defs.values()
if x.category
in (
OutputVariableCategory.OUT,
OutputVariableCategory.REDU,
OutputVariableCategory.DERV_R,
OutputVariableCategory.DERV_C_REDU,
Expand Down

0 comments on commit f27ae51

Please sign in to comment.