From 5f52bb704b5cefc9cdb7f615011a7ec21518f8d2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 06:18:41 -0500 Subject: [PATCH] move bias Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index b4691bf8a3..2c06644afb 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -448,18 +448,15 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - outs = xp.astype(outs, get_xp_precision(xp, "global")) - for type_i in range(self.ntypes): - outs = outs + self.bias_atom_e[type_i, ...] else: outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) - outs = xp.astype(outs, get_xp_precision(xp, "global")) - outs += xp.reshape( - xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), - [nf, nloc, net_dim_out], - ) + outs = xp.astype(outs, get_xp_precision(xp, "global")) + outs += xp.reshape( + xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0), + [nf, nloc, net_dim_out], + ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) # nf x nloc x nod