Skip to content

Commit

Permalink
fix FrozenModel
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 1, 2024
1 parent 02b3a6d commit 4fa8754
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions deepmd/tf/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,26 @@ def build(
)
if self.model_type == "ener":
return {
"energy": tf.identity(self.model.t_energy, name="o_energy" + suffix),
"force": tf.identity(self.model.t_force, name="o_force" + suffix),
"virial": tf.identity(self.model.t_virial, name="o_virial" + suffix),
# must visit the backend class
"energy": tf.identity(
self.model.deep_eval.output_tensors["energy_redu"],
name="o_energy" + suffix,
),
"force": tf.identity(
self.model.deep_eval.output_tensors["energy_derv_r"],
name="o_force" + suffix,
),
"virial": tf.identity(
self.model.deep_eval.output_tensors["energy_derv_c_redu"],
name="o_virial" + suffix,
),
"atom_ener": tf.identity(
self.model.t_ae, name="o_atom_energy" + suffix
self.model.deep_eval.output_tensors["energy"],
name="o_atom_energy" + suffix,
),
"atom_virial": tf.identity(
self.model.t_av, name="o_atom_virial" + suffix
self.model.deep_eval.output_tensors["energy_derv_c"],
name="o_atom_virial" + suffix,
),
"coord": coord_,
"atype": atype_,
Expand Down

0 comments on commit 4fa8754

Please sign in to comment.