Skip to content

Commit

Permalink
pt: support eval_typeebd for deep_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 26, 2024
1 parent 964f02d commit e454b8c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
25 changes: 25 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)

Expand Down Expand Up @@ -526,6 +527,30 @@ def _get_output_shape(self, odef, nframes, natoms):
else:
raise RuntimeError("unknown category")

def eval_typeebd(self) -> np.ndarray:
"""Evaluate output of type embedding network by using this model.
Returns
-------
np.ndarray
The output of type embedding network. The shape is [ntypes + 1, o_size],
where ntypes is the number of types, and o_size is the number of nodes
in the output layer.
Raises
------
KeyError
If the model does not enable type embedding.
"""
model = self.dp.model["Default"]
tebd = None
for item in model.named_parameters():
if "type_embedding.embedding.weight" in item[0]:
tebd = to_numpy_array(item[1])
if tebd is None:
raise KeyError("Model has no type embedding!")
return tebd


# For tests only
def eval_model(
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def test_uni(self):
self.assertIsInstance(dp, DeepPot)
# its methods has been tested in test_dp_test

def test_eval_typeebd(self):
dp = DeepPot(str(self.model))
eval_typeebd = dp.eval_typeebd()
self.assertEqual(eval_typeebd.shape, (3, 8))
np.testing.assert_allclose(eval_typeebd[-1], np.zeros_like(eval_typeebd[-1]))


class TestDeepPotFrozen(TestDeepPot):
def setUp(self):
Expand Down

0 comments on commit e454b8c

Please sign in to comment.