diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index db601d7e5a..48630007d0 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -46,6 +46,9 @@ from deepmd.pt.model.model import ( get_model, ) +from deepmd.pt.model.network.network import ( + TypeEmbedNetConsistent, +) from deepmd.pt.train.wrapper import ( ModelWrapper, ) @@ -61,6 +64,7 @@ RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.utils import ( + to_numpy_array, to_torch_tensor, ) @@ -556,6 +560,36 @@ def _get_output_shape(self, odef, nframes, natoms): else: raise RuntimeError("unknown category") + def eval_typeebd(self) -> np.ndarray: + """Evaluate output of type embedding network by using this model. + + Returns + ------- + np.ndarray + The output of type embedding network. The shape is [ntypes, o_size] or [ntypes + 1, o_size], + where ntypes is the number of types, and o_size is the number of nodes + in the output layer. If there are multiple type embedding networks, + these outputs will be concatenated along the second axis. + + Raises + ------ + KeyError + If the model does not enable type embedding. + + See Also + -------- + deepmd.pt.model.network.network.TypeEmbedNetConsistent : + The type embedding network. + """ + out = [] + for mm in self.dp.model["Default"].modules(): + if mm.original_name == TypeEmbedNetConsistent.__name__: + out.append(mm(DEVICE)) + if not out: + raise KeyError("The model has no type embedding networks.") + typeebd = torch.cat(out, dim=1) + return to_numpy_array(typeebd) + # For tests only def eval_model( diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 7268181c26..8917c62cce 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -110,6 +110,14 @@ def test_uni(self): self.assertIsInstance(dp, DeepPot) # its methods has been tested in test_dp_test + def test_eval_typeebd(self): + dp = DeepPot(str(self.model)) + eval_typeebd = dp.eval_typeebd() + self.assertEqual( + eval_typeebd.shape, (len(self.config["model"]["type_map"]) + 1, 8) + ) + np.testing.assert_allclose(eval_typeebd[-1], np.zeros_like(eval_typeebd[-1])) + class TestDeepPotFrozen(TestDeepPot): def setUp(self):