Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 14, 2024
1 parent 86a6713 commit 9585203
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.activation_function = activation_function
self.precision = precision
self.prec = PRECISION_DICT[self.precision]

# init constants
if self.numb_fparam > 0:
self.register_buffer(
Expand Down Expand Up @@ -218,7 +218,6 @@ def _foward_common(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):

xx = descriptor
nf, nloc, nd = xx.shape
if hasattr(self, "bias_atom_e"):
Expand Down Expand Up @@ -282,7 +281,11 @@ def _foward_common(
for type_i, filter_layer in enumerate(self.filter_layers_old):
mask = atype == type_i
atom_property = filter_layer(xx)
atom_property = atom_property + self.bias_atom_e[type_i] if hasattr(self, "bias_atom_e") else atom_property
atom_property = (

Check warning on line 284 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L283-L284

Added lines #L283 - L284 were not covered by tests
atom_property + self.bias_atom_e[type_i]
if hasattr(self, "bias_atom_e")
else atom_property
)
atom_property = atom_property * mask.unsqueeze(-1)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}

Check warning on line 291 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L289-L291

Added lines #L289 - L291 were not covered by tests
Expand All @@ -297,11 +300,15 @@ def _foward_common(
mask = (atype == type_i).unsqueeze(-1)
mask = torch.tile(mask, (1, 1, self.dim_out))
atom_property = ll(xx)
atom_property = atom_property + self.bias_atom_e[type_i] if hasattr(self, "bias_atom_e") else atom_property
atom_property = (

Check warning on line 303 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L302-L303

Added lines #L302 - L303 were not covered by tests
atom_property + self.bias_atom_e[type_i]
if hasattr(self, "bias_atom_e")
else atom_property
)
atom_property = atom_property * mask
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]

Check warning on line 309 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L308-L309

Added lines #L308 - L309 were not covered by tests
return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)}


@fitting_check_output
class InvarFitting(GeneralFitting):
Expand Down Expand Up @@ -330,14 +337,14 @@ def __init__(
- bias_atom_e: Average energy per atom for each element.
- resnet_dt: Using time-step in the ResNet construction.
"""
super().__init__()
super().__init__()

Check failure

Code scanning / CodeQL

Wrong number of arguments in a call Error

Call to
method GeneralFitting.__init__
with too few arguments; should be no fewer than 4.

Check failure

Code scanning / CodeQL

Wrong number of arguments in a class instantiation Error

Call to
GeneralFitting.__init__
with too few arguments; should be no fewer than 4.
if bias_atom_e is None:
bias_atom_e = np.zeros([self.ntypes, self.dim_out])
bias_atom_e = torch.tensor(bias_atom_e, dtype=self.prec, device=device)
bias_atom_e = bias_atom_e.view([self.ntypes, self.dim_out])
if not self.use_tebd:
assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!"
self.register_buffer("bias_atom_e", bias_atom_e)
self.register_buffer("bias_atom_e", bias_atom_e)

Check warning on line 347 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L340-L347

Added lines #L340 - L347 were not covered by tests

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(

Check warning on line 350 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L349-L350

Added lines #L349 - L350 were not covered by tests
Expand Down Expand Up @@ -434,7 +441,6 @@ def forward(
- `torch.Tensor`: Total energy with shape [nframes, natoms[0]].
"""
return super().forward(descriptor, atype, gr, g2, h2, fparam, aparam)

Check warning on line 443 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L443

Added line #L443 was not covered by tests



@Fitting.register("ener")
Expand Down

0 comments on commit 9585203

Please sign in to comment.