Skip to content

Commit

Permalink
fix: add UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 5, 2024
1 parent caf5f78 commit 7d4e49c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 28 deletions.
11 changes: 3 additions & 8 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,13 @@ def get_rcut(self) -> float:
pass

@abstractmethod
def get_type_map(self) -> List[str]:
    """Get the type map.

    Returns
    -------
    List[str]
        The name (e.g. element symbol) of each atom type; implementations
        must return a concrete list, never None.
    """
    pass

def get_ntypes(self) -> int:
    """Get the number of atom types.

    Returns
    -------
    int
        The length of the type map. ``get_type_map`` is guaranteed to
        return a list (it is declared ``-> List[str]``), so no None
        check is needed here.
    """
    return len(self.get_type_map())

@abstractmethod
def get_sel(self) -> List[int]:
Expand Down
12 changes: 4 additions & 8 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_rcut(self) -> float:

@torch.jit.export
def get_type_map(self) -> List[str]:
    """Get the type map.

    Returns
    -------
    List[str]
        The name of each atom type.
    """
    # The explicit `return` is essential: the pre-fix version evaluated
    # `self.type_map` as a bare expression and always returned None.
    return self.type_map

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
Expand Down Expand Up @@ -240,15 +240,11 @@ def remap_atype(
-------
torch.Tensor
"""
assert max(atype) < len(
new_map
), "The input `atype` cannot be handled by the type_map."
idx_2_type = {k: new_map[k] for k in range(len(new_map))}
assert torch.max(atype) < len(new_map), "The input `atype` cannot be handled by the type_map."
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
# this maps the atype in the new map to the original map
mapping = {idx: type_2_idx[idx_2_type[idx]] for idx in range(len(new_map))}
updated_atype = atype.clone()
updated_atype.apply_(mapping.get)
mapping = torch.tensor([type_2_idx[new_map[idx]] for idx in range(len(new_map))]).to(atype.device)
updated_atype = mapping[atype.long()]
return updated_atype

def fitting_output_def(self) -> FittingOutputDef:
Expand Down
23 changes: 11 additions & 12 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,26 +187,25 @@ def test_jit(self):

class TestRemmapMethod(unittest.TestCase):
    """Unit tests for DPZBLLinearAtomicModel.remap_atype."""

    def test_invalid(self):
        """An atype containing indices outside the new map must be rejected."""
        # Types 2..3 exist in `originl` but `commonl` only has one type,
        # so remapping must fail the length assertion.
        atype = torch.randint(2, 4, (2, 5), device=env.DEVICE)
        commonl = ["H"]
        originl = ["Si", "H", "O", "S"]
        with self.assertRaises(AssertionError):
            new_atype = DPZBLLinearAtomicModel.remap_atype(atype, originl, commonl)

    def test_valid(self):
        """Remapped indices must refer to the same element names as before."""
        atype = torch.randint(0, 3, (4, 20), device=env.DEVICE)
        commonl = ["H", "O", "S"]
        originl = ["Si", "H", "O", "S"]
        new_atype = DPZBLLinearAtomicModel.remap_atype(atype, originl, commonl)

        def trans(atype, type_map):
            # Translate a tensor of type indices into element names.
            return [type_map[i] for i in atype.flatten().tolist()]

        # Same atoms, expressed in either map, must name the same elements.
        assert trans(atype, commonl) == trans(new_atype, originl)


if __name__ == "__main__":
Expand Down

0 comments on commit 7d4e49c

Please sign in to comment.