Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 5, 2024
1 parent 7d4e49c commit a30bc35
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
8 changes: 6 additions & 2 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,14 @@ def remap_atype(
-------
torch.Tensor
"""
assert torch.max(atype) < len(new_map), "The input `atype` cannot be handled by the type_map."
assert torch.max(atype) < len(
new_map
), "The input `atype` cannot be handled by the type_map."
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
# this maps the atype in the new map to the original map
mapping = torch.tensor([type_2_idx[new_map[idx]] for idx in range(len(new_map))]).to(atype.device)
mapping = torch.tensor(
[type_2_idx[new_map[idx]] for idx in range(len(new_map))]
).to(atype.device)
updated_atype = mapping[atype.long()]
return updated_atype

Expand Down
15 changes: 8 additions & 7 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,25 +187,26 @@ def test_jit(self):

class TestRemmapMethod(unittest.TestCase):
def test_invalid(self):
atype = torch.randint(2,4, (2,5), device=env.DEVICE)
atype = torch.randint(2, 4, (2, 5), device=env.DEVICE)
commonl = ["H"]
originl = ["Si","H","O", "S"]
originl = ["Si", "H", "O", "S"]
with self.assertRaises(AssertionError):
new_atype = DPZBLLinearAtomicModel.remap_atype(atype, originl, commonl)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable new_atype is not used.

def test_valid(self):
atype = torch.randint(0,3, (4,20), device=env.DEVICE)
atype = torch.randint(0, 3, (4, 20), device=env.DEVICE)
commonl = ["H", "O", "S"]
originl = ["Si","H","O", "S"]
originl = ["Si", "H", "O", "S"]
new_atype = DPZBLLinearAtomicModel.remap_atype(atype, originl, commonl)

def trans(atype, map):
idx = atype.flatten().tolist()
res = []
for i in idx:
res.append(map[i])
return res
assert trans(atype,commonl) == trans(new_atype, originl)

assert trans(atype, commonl) == trans(new_atype, originl)


if __name__ == "__main__":
Expand Down

0 comments on commit a30bc35

Please sign in to comment.