Skip to content

Commit

Permalink
feat(pt): support eval_typeebd for DeepEval
Browse files Browse the repository at this point in the history
Closes deepmodeling#3608.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 7, 2024
1 parent c3ba728 commit c2eac9d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
34 changes: 34 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.model.network.network import (
TypeEmbedNetConsistent,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
Expand All @@ -61,6 +64,7 @@
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)

Expand Down Expand Up @@ -556,6 +560,36 @@ def _get_output_shape(self, odef, nframes, natoms):
else:
raise RuntimeError("unknown category")

def eval_typeebd(self) -> np.ndarray:
"""Evaluate output of type embedding network by using this model.
Returns
-------
np.ndarray
The output of type embedding network. The shape is [ntypes, o_size] or [ntypes + 1, o_size],
where ntypes is the number of types, and o_size is the number of nodes
in the output layer. If there are multiple type embedding networks,
these outputs will be concatenated along the second axis.
Raises
------
KeyError
If the model does not enable type embedding.
See Also
--------
deepmd.pt.model.network.network.TypeEmbedNetConsistent :
The type embedding network.
"""
out = []
for mm in self.dp.model["Default"].modules():
if mm.original_name == TypeEmbedNetConsistent.__name__:
out.append(mm(DEVICE))
if not out:
raise KeyError("The model has no type embedding networks.")
typeebd = torch.cat(out, dim=1)
return to_numpy_array(typeebd)


# For tests only
def eval_model(
Expand Down
8 changes: 8 additions & 0 deletions source/tests/pt/model/test_deeppot.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def test_uni(self):
self.assertIsInstance(dp, DeepPot)
# its methods has been tested in test_dp_test

def test_eval_typeebd(self):
dp = DeepPot(str(self.model))
eval_typeebd = dp.eval_typeebd()
self.assertEqual(
eval_typeebd.shape, (len(self.config["model"]["type_map"]) + 1, 8)
)
np.testing.assert_allclose(eval_typeebd[-1], np.zeros_like(eval_typeebd[-1]))


class TestDeepPotFrozen(TestDeepPot):
def setUp(self):
Expand Down

0 comments on commit c2eac9d

Please sign in to comment.