From 4fa87547a6635810c7327b93fedd3eda795f07f4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 1 Feb 2024 16:44:48 -0500 Subject: [PATCH] fix FrozenModel Signed-off-by: Jinzhe Zeng --- deepmd/tf/model/frozen.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/deepmd/tf/model/frozen.py b/deepmd/tf/model/frozen.py index 4151b1b0e4..f06ae954d1 100644 --- a/deepmd/tf/model/frozen.py +++ b/deepmd/tf/model/frozen.py @@ -130,14 +130,26 @@ def build( ) if self.model_type == "ener": return { - "energy": tf.identity(self.model.t_energy, name="o_energy" + suffix), - "force": tf.identity(self.model.t_force, name="o_force" + suffix), - "virial": tf.identity(self.model.t_virial, name="o_virial" + suffix), + # must visit the backend class + "energy": tf.identity( + self.model.deep_eval.output_tensors["energy_redu"], + name="o_energy" + suffix, + ), + "force": tf.identity( + self.model.deep_eval.output_tensors["energy_derv_r"], + name="o_force" + suffix, + ), + "virial": tf.identity( + self.model.deep_eval.output_tensors["energy_derv_c_redu"], + name="o_virial" + suffix, + ), "atom_ener": tf.identity( - self.model.t_ae, name="o_atom_energy" + suffix + self.model.deep_eval.output_tensors["energy"], + name="o_atom_energy" + suffix, ), "atom_virial": tf.identity( - self.model.t_av, name="o_atom_virial" + suffix + self.model.deep_eval.output_tensors["energy_derv_c"], + name="o_atom_virial" + suffix, ), "coord": coord_, "atype": atype_,