Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 5, 2024
1 parent 3610b3d commit caf5f78
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
30 changes: 17 additions & 13 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def forward_atomic(

for i, model in enumerate(self.models):
ori_map = model.get_type_map()
updated_atype = self.remap_atype(extended_atype,ori_map,self.type_map)
updated_atype = self.remap_atype(extended_atype, ori_map, self.type_map)

Check warning on line 191 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L190-L191

Added lines #L190 - L191 were not covered by tests
ener_list.append(
model.forward_atomic(
extended_coord,
Expand Down Expand Up @@ -219,30 +219,34 @@ def forward_atomic(
return fit_ret

@staticmethod
def remap_atype(atype:torch.Tensor, ori_map: List[str], new_map: List[str]) -> torch.Tensor:
"""
This method is used to map the atype from the common type_map to the original type_map of
def remap_atype(

Check warning on line 222 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L221-L222

Added lines #L221 - L222 were not covered by tests
atype: torch.Tensor, ori_map: List[str], new_map: List[str]
) -> torch.Tensor:
"""
This method is used to map the atype from the common type_map to the original type_map of
indivial AtomicModels.
Parameters
----------
atype: torch.Tensor
atype : torch.Tensor
The atom type tensor being updated, shape of (nframes, natoms)
ori_map: List[str]
ori_map : List[str]
The original type map of an AtomicModel.
new_map: List[str]
new_map : List[str]
The common type map of the DPZBLLinearAtomicModel, created by the `get_type_map` method,
must be a subset of the ori_map.
Return
Returns
-------
torch.Tensor
"""
assert max(atype) < len(new_map), "The input `atype` cannot be handled by the type_map."
assert max(atype) < len(

Check warning on line 243 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L243

Added line #L243 was not covered by tests
new_map
), "The input `atype` cannot be handled by the type_map."
idx_2_type = {k: new_map[k] for k in range(len(new_map))}
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}

Check warning on line 247 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L246-L247

Added lines #L246 - L247 were not covered by tests
# this maps the atype in the new map to the original map
mapping = {idx: type_2_idx[idx_2_type[idx]] for idx in range(len(new_map))}
mapping = {idx: type_2_idx[idx_2_type[idx]] for idx in range(len(new_map))}
updated_atype = atype.clone()
updated_atype.apply_(mapping.get)
return updated_atype

Check warning on line 252 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L249-L252

Added lines #L249 - L252 were not covered by tests
Expand Down Expand Up @@ -374,7 +378,7 @@ def __init__(

# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

def compute_or_load_stat(

Check warning on line 382 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L382

Added line #L382 was not covered by tests
self,
sampled_func,
Expand Down
17 changes: 9 additions & 8 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,28 +184,29 @@ def test_jit(self):
self.assertEqual(md3.get_rcut(), self.rcut)
self.assertEqual(md3.get_type_map(), ["foo", "bar"])

class TestRemmapMethod(unittest.TestCase):
    """Tests for remap_atype.

    NOTE(review): the class name's 'Remmap' spelling is kept as-is so any
    external test selection by name keeps working.
    """

    def test_invalid(self):
        """An atype outside the common map must make remap_atype raise."""
        atype = torch.randint(2, 4, (2, 5))
        commonl = ["H"]
        originl = ["Si", "H", "O", "S"]
        # Fix: assertRaises requires the expected exception class; calling it
        # with no arguments is a TypeError (flagged by CodeQL). The guard in
        # remap_atype is an `assert`, so Exception covers AssertionError.
        with self.assertRaises(Exception):
            remap_atype(atype, originl, commonl)

    def test_valid(self):
        """Remapped indices must denote the same element names."""
        atype = torch.randint(0, 3, (4, 20))
        nl = ["H", "O", "S"]
        ol = ["Si", "H", "O", "S"]
        # Fix: the original called remap_atype(atype, originl, commonl) with
        # names that are undefined in this method (NameError); the locals
        # here are `ol` (original map) and `nl` (common map).
        new_atype = remap_atype(atype, ol, nl)

        def trans(atype, type_map):
            # Translate a tensor of indices into element names via type_map.
            res = []
            for i in atype.flatten().tolist():
                res.append(type_map[i])
            return res

        assert trans(atype, nl) == trans(new_atype, ol)


if __name__ == "__main__":
Expand Down

0 comments on commit caf5f78

Please sign in to comment.