Skip to content

Commit

Permalink
pt: rename atomic_virial to atom_virial in the model output (#3226)
Browse files Browse the repository at this point in the history
To be consistent with TF, as discussed in
#3213 (comment).
Old PT models are expected to be incompatible.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 5, 2024
1 parent cd77429 commit 17f2c35
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
12 changes: 6 additions & 6 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def _eval_model(
batch_output["force"].reshape(nframes, natoms, 3).detach().cpu().numpy()
)
virial_out = batch_output["virial"].reshape(nframes, 9).detach().cpu().numpy()
if "atomic_virial" in batch_output:
if "atom_virial" in batch_output:
atomic_virial_out = (
batch_output["atomic_virial"]
batch_output["atom_virial"]
.reshape(nframes, natoms, 9)
.detach()
.cpu()
Expand Down Expand Up @@ -326,9 +326,9 @@ def eval_model(
force_out.append(batch_output["force"].detach().cpu().numpy())
if "virial" in batch_output:
virial_out.append(batch_output["virial"].detach().cpu().numpy())
if "atomic_virial" in batch_output:
if "atom_virial" in batch_output:
atomic_virial_out.append(
batch_output["atomic_virial"].detach().cpu().numpy()
batch_output["atom_virial"].detach().cpu().numpy()
)
if "updated_coord" in batch_output:
updated_coord_out.append(
Expand All @@ -345,8 +345,8 @@ def eval_model(
force_out.append(batch_output["force"])
if "virial" in batch_output:
virial_out.append(batch_output["virial"])
if "atomic_virial" in batch_output:
atomic_virial_out.append(batch_output["atomic_virial"])
if "atom_virial" in batch_output:
atomic_virial_out.append(batch_output["atom_virial"])
if "updated_coord" in batch_output:
updated_coord_out.append(batch_output["updated_coord"])
if "logits" in batch_output:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def forward(
if self.do_grad("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if do_atomic_virial:
model_predict["atomic_virial"] = model_ret["energy_derv_c"].squeeze(
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-3)
Expand Down
10 changes: 5 additions & 5 deletions source/tests/pt/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_intermediate_state(self, num_steps=1):
"energy": model_pred["energy"],
"force": model_pred["force"],
"virial": model_pred["virial"],
"atomic_virial": model_pred["atom_virial"],
"atom_virial": model_pred["atom_virial"],
}

# Get statistics of each component
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_consistency(self):
model_predict["energy"],
model_predict["force"],
model_predict["virial"],
model_predict["atomic_virial"],
model_predict["atom_virial"],
)
cur_lr = my_lr.value(self.wanted_step)
model_pred = {
Expand Down Expand Up @@ -395,10 +395,10 @@ def test_consistency(self):
.detach()
.numpy(),
)
self.assertIsNone(model_predict_1.get("atomic_virial", None))
self.assertIsNone(model_predict_1.get("atom_virial", None))
np.testing.assert_allclose(
head_dict["atomic_virial"],
p_atomic_virial.view(*head_dict["atomic_virial"].shape)
head_dict["atom_virial"],
p_atomic_virial.view(*head_dict["atom_virial"].shape)
.cpu()
.detach()
.numpy(),
Expand Down

0 comments on commit 17f2c35

Please sign in to comment.