diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 8a3a61400d..e2535014d1 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -55,6 +55,7 @@ GLOBAL_PT_FLOAT_PRECISION, ) from deepmd.pt.utils.utils import ( + to_numpy_array, to_torch_tensor, ) @@ -526,6 +527,30 @@ def _get_output_shape(self, odef, nframes, natoms): else: raise RuntimeError("unknown category") + def eval_typeebd(self) -> np.ndarray: + """Evaluate output of type embedding network by using this model. + + Returns + ------- + np.ndarray + The output of type embedding network. The shape is [ntypes + 1, o_size], + where ntypes is the number of types, and o_size is the number of nodes + in the output layer. + + Raises + ------ + KeyError + If the model does not enable type embedding. + """ + model = self.dp.model["Default"] + tebd = None + for item in model.named_parameters(): + if "type_embedding.embedding.weight" in item[0]: + tebd = to_numpy_array(item[1]) + if tebd is None: + raise KeyError("Model has no type embedding!") + return tebd + # For tests only def eval_model( diff --git a/source/tests/pt/model/test_deeppot.py b/source/tests/pt/model/test_deeppot.py index 68b1ff65d5..5bcd98df96 100644 --- a/source/tests/pt/model/test_deeppot.py +++ b/source/tests/pt/model/test_deeppot.py @@ -114,6 +114,12 @@ def test_uni(self): self.assertIsInstance(dp, DeepPot) # its methods has been tested in test_dp_test + def test_eval_typeebd(self): + dp = DeepPot(str(self.model)) + eval_typeebd = dp.eval_typeebd() + self.assertEqual(eval_typeebd.shape, (3, 8)) + np.testing.assert_allclose(eval_typeebd[-1], np.zeros_like(eval_typeebd[-1])) + class TestDeepPotFrozen(TestDeepPot): def setUp(self):