From 95852035615d214d0f1683cab1228cfe73719a1e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 14 Feb 2024 12:58:01 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/pt/model/task/ener.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py
index 00271e1493..4d7f25ff7d 100644
--- a/deepmd/pt/model/task/ener.py
+++ b/deepmd/pt/model/task/ener.py
@@ -86,7 +86,7 @@ def __init__(
         self.activation_function = activation_function
         self.precision = precision
         self.prec = PRECISION_DICT[self.precision]
-
+
         # init constants
         if self.numb_fparam > 0:
             self.register_buffer(
@@ -218,7 +218,6 @@ def _foward_common(
         fparam: Optional[torch.Tensor] = None,
         aparam: Optional[torch.Tensor] = None,
     ):
-
         xx = descriptor
         nf, nloc, nd = xx.shape
         if hasattr(self, "bias_atom_e"):
@@ -282,7 +281,11 @@ def _foward_common(
             for type_i, filter_layer in enumerate(self.filter_layers_old):
                 mask = atype == type_i
                 atom_property = filter_layer(xx)
-                atom_property = atom_property + self.bias_atom_e[type_i] if hasattr(self, "bias_atom_e") else atom_property
+                atom_property = (
+                    atom_property + self.bias_atom_e[type_i]
+                    if hasattr(self, "bias_atom_e")
+                    else atom_property
+                )
                 atom_property = atom_property * mask.unsqueeze(-1)
                 outs = outs + atom_property  # Shape is [nframes, natoms[0], 1]
             return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
@@ -297,11 +300,15 @@ def _foward_common(
                 mask = (atype == type_i).unsqueeze(-1)
                 mask = torch.tile(mask, (1, 1, self.dim_out))
                 atom_property = ll(xx)
-                atom_property = atom_property + self.bias_atom_e[type_i] if hasattr(self, "bias_atom_e") else atom_property
+                atom_property = (
+                    atom_property + self.bias_atom_e[type_i]
+                    if hasattr(self, "bias_atom_e")
+                    else atom_property
+                )
                 atom_property = atom_property * mask
                 outs = outs + atom_property  # Shape is [nframes, natoms[0], 1]
             return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}
-
+

 @fitting_check_output
 class InvarFitting(GeneralFitting):
@@ -330,14 +337,14 @@ def __init__(
         - bias_atom_e: Average enery per atom for each element.
         - resnet_dt: Using time-step in the ResNet construction.
         """
-        super().__init__()
+        super().__init__()
         if bias_atom_e is None:
             bias_atom_e = np.zeros([self.ntypes, self.dim_out])
         bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device)
         bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out])
         if not self.use_tebd:
             assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!"
-        self.register_buffer("bias_atom_e", bias_atom_e)
+        self.register_buffer("bias_atom_e", bias_atom_e)

     def output_def(self) -> FittingOutputDef:
         return FittingOutputDef(
@@ -434,7 +441,6 @@ def forward(
         - `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
         """
         return super().forward(descriptor, atype, gr, g2, h2, fparam, aparam)
-


 @Fitting.register("ener")
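
Note: the hunks above are formatting-only (black/pre-commit rewraps and
whitespace fixes); the per-type bias logic is unchanged. For readers
unfamiliar with the pattern being rewrapped, below is a minimal,
self-contained sketch of the same conditional bias-add. The helper name
`toy_bias_add` and the tensor shapes are illustrative assumptions, not
deepmd-kit API.

import torch

def toy_bias_add(atom_property, atype, bias_atom_e=None):
    # atom_property: [nframes, nloc, dim_out]; atype: [nframes, nloc].
    # Mirrors the patched pattern: add a per-type bias only if one exists,
    # then mask each contribution to atoms of that type.
    outs = torch.zeros_like(atom_property)
    ntypes = int(atype.max()) + 1
    for type_i in range(ntypes):
        prop = atom_property
        prop = (
            prop + bias_atom_e[type_i]
            if bias_atom_e is not None
            else prop
        )
        mask = (atype == type_i).unsqueeze(-1)
        outs = outs + prop * mask
    return outs

# Usage: two frames, three atoms each, scalar per-atom property.
atype = torch.tensor([[0, 1, 1], [1, 0, 0]])
prop = torch.ones(2, 3, 1)
bias = torch.tensor([[10.0], [20.0]])  # one bias per atom type
print(toy_bias_add(prop, atype, bias))  # 11.0 for type 0, 21.0 for type 1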