Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 13, 2024
1 parent b54e109 commit ca75551
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging

import torch
from typing import (

Check warning on line 3 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L3

Added line #L3 was not covered by tests
List,
Optional,
Tuple,
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
PRECISION_DICT,
)
import numpy as np

from deepmd.pt.utils import (
env,
)
import torch

from deepmd.pt.model.network.mlp import (

Check warning on line 10 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L10

Added line #L10 was not covered by tests
NetworkCollection,
FittingNet,
NetworkCollection,
)
from deepmd.pt.model.task.fitting import (
Fitting,
)
from deepmd.pt.utils import (

Check warning on line 17 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L17

Added line #L17 was not covered by tests
env,
)
from deepmd.pt.utils.env import (

Check warning on line 20 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L20

Added line #L20 was not covered by tests
DEFAULT_PRECISION,
PRECISION_DICT,
)

dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

Check warning on line 26 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L25-L26

Added lines #L25 - L26 were not covered by tests
Expand Down Expand Up @@ -193,8 +190,8 @@ def forward(

outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion
if self.use_tebd:
atom_dipole = self.filter_layers.networks[0](xx)
outs = outs + atom_dipole # Shape is [nframes, natoms[0], 3]
atom_dipole = self.filter_layers.networks[0](xx)
outs = outs + atom_dipole # Shape is [nframes, natoms[0], 3]

Check warning on line 194 in deepmd/pt/model/task/dipole.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dipole.py#L191-L194

Added lines #L191 - L194 were not covered by tests
else:
for type_i, ll in enumerate(self.filter_layers.networks):
mask = (atype == type_i).unsqueeze(-1)
Expand Down

0 comments on commit ca75551

Please sign in to comment.