From baa281af7d6645b062743decbf9bfe19181be9a5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 4 Feb 2024 01:29:16 -0500 Subject: [PATCH] pt: rename atomic_virial to atom_virial To be consistent with TF. Signed-off-by: Jinzhe Zeng --- deepmd/pt/infer/deep_eval.py | 12 ++++++------ deepmd/pt/model/model/ener.py | 2 +- source/tests/pt/test_model.py | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index b5d596ed0f..4ba9e17b52 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -208,9 +208,9 @@ def _eval_model( batch_output["force"].reshape(nframes, natoms, 3).detach().cpu().numpy() ) virial_out = batch_output["virial"].reshape(nframes, 9).detach().cpu().numpy() - if "atomic_virial" in batch_output: + if "atom_virial" in batch_output: atomic_virial_out = ( - batch_output["atomic_virial"] + batch_output["atom_virial"] .reshape(nframes, natoms, 9) .detach() .cpu() @@ -326,9 +326,9 @@ def eval_model( force_out.append(batch_output["force"].detach().cpu().numpy()) if "virial" in batch_output: virial_out.append(batch_output["virial"].detach().cpu().numpy()) - if "atomic_virial" in batch_output: + if "atom_virial" in batch_output: atomic_virial_out.append( - batch_output["atomic_virial"].detach().cpu().numpy() + batch_output["atom_virial"].detach().cpu().numpy() ) if "updated_coord" in batch_output: updated_coord_out.append( @@ -345,8 +345,8 @@ def eval_model( force_out.append(batch_output["force"]) if "virial" in batch_output: virial_out.append(batch_output["virial"]) - if "atomic_virial" in batch_output: - atomic_virial_out.append(batch_output["atomic_virial"]) + if "atom_virial" in batch_output: + atomic_virial_out.append(batch_output["atom_virial"]) if "updated_coord" in batch_output: updated_coord_out.append(batch_output["updated_coord"]) if "logits" in batch_output: diff --git a/deepmd/pt/model/model/ener.py b/deepmd/pt/model/model/ener.py index a408689d8d..1930936336 100644 --- a/deepmd/pt/model/model/ener.py +++ b/deepmd/pt/model/model/ener.py @@ -45,7 +45,7 @@ def forward( if self.do_grad("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) if do_atomic_virial: - model_predict["atomic_virial"] = model_ret["energy_derv_c"].squeeze( + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( -3 ) model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-3) diff --git a/source/tests/pt/test_model.py b/source/tests/pt/test_model.py index e87a53969c..bb99759d16 100644 --- a/source/tests/pt/test_model.py +++ b/source/tests/pt/test_model.py @@ -146,7 +146,7 @@ def get_intermediate_state(self, num_steps=1): "energy": model_pred["energy"], "force": model_pred["force"], "virial": model_pred["virial"], - "atomic_virial": model_pred["atom_virial"], + "atom_virial": model_pred["atom_virial"], } # Get statistics of each component @@ -359,7 +359,7 @@ def test_consistency(self): model_predict["energy"], model_predict["force"], model_predict["virial"], - model_predict["atomic_virial"], + model_predict["atom_virial"], ) cur_lr = my_lr.value(self.wanted_step) model_pred = { @@ -395,10 +395,10 @@ def test_consistency(self): .detach() .numpy(), ) - self.assertIsNone(model_predict_1.get("atomic_virial", None)) + self.assertIsNone(model_predict_1.get("atom_virial", None)) np.testing.assert_allclose( - head_dict["atomic_virial"], - p_atomic_virial.view(*head_dict["atomic_virial"].shape) + head_dict["atom_virial"], + p_atomic_virial.view(*head_dict["atom_virial"].shape) .cpu() .detach() .numpy(),