From 7d4e49c22e2a6dcb60de17c1fce88593bc89c203 Mon Sep 17 00:00:00 2001
From: anyangml
Date: Tue, 5 Mar 2024 08:11:37 +0000
Subject: [PATCH] fix: add UTs

---
 .../atomic_model/make_base_atomic_model.py    | 11 +++------
 .../model/atomic_model/linear_atomic_model.py | 12 ++++------
 .../pt/model/test_linear_atomic_model.py      | 23 +++++++++----------
 3 files changed, 18 insertions(+), 28 deletions(-)

diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
index 5548147d54..ce1a6708e6 100644
--- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
+++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -54,18 +54,13 @@ def get_rcut(self) -> float:
             pass
 
         @abstractmethod
-        def get_type_map(self) -> Optional[List[str]]:
+        def get_type_map(self) -> List[str]:
             """Get the type map."""
+            pass
 
         def get_ntypes(self) -> int:
             """Get the number of atom types."""
-            tmap = self.get_type_map()
-            if tmap is not None:
-                return len(tmap)
-            else:
-                raise ValueError(
-                    "cannot infer the number of types from a None type map"
-                )
+            return len(self.get_type_map())
 
         @abstractmethod
         def get_sel(self) -> List[int]:
diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py
index 9440b8276c..446b88da46 100644
--- a/deepmd/pt/model/atomic_model/linear_atomic_model.py
+++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -96,7 +96,7 @@ def get_rcut(self) -> float:
 
     @torch.jit.export
     def get_type_map(self) -> List[str]:
-        self.type_map
+        return self.type_map
 
     def get_model_rcuts(self) -> List[float]:
         """Get the cut-off radius for each individual models."""
@@ -240,15 +240,11 @@ def remap_atype(
         -------
         torch.Tensor
         """
-        assert max(atype) < len(
-            new_map
-        ), "The input `atype` cannot be handled by the type_map."
-        idx_2_type = {k: new_map[k] for k in range(len(new_map))}
+        assert torch.max(atype) < len(new_map), "The input `atype` cannot be handled by the type_map."
         type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
         # this maps the atype in the new map to the original map
-        mapping = {idx: type_2_idx[idx_2_type[idx]] for idx in range(len(new_map))}
-        updated_atype = atype.clone()
-        updated_atype.apply_(mapping.get)
+        mapping = torch.tensor([type_2_idx[new_map[idx]] for idx in range(len(new_map))]).to(atype.device)
+        updated_atype = mapping[atype.long()]
         return updated_atype
 
     def fitting_output_def(self) -> FittingOutputDef:
diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py
index f168e1613c..e6fbbb0304 100644
--- a/source/tests/pt/model/test_linear_atomic_model.py
+++ b/source/tests/pt/model/test_linear_atomic_model.py
@@ -187,26 +187,25 @@ def test_jit(self):
 
 class TestRemmapMethod(unittest.TestCase):
     def test_invalid(self):
-        atype = torch.randint(2, 4, (2, 5))
+        atype = torch.randint(2, 4, (2, 5), device=env.DEVICE)
         commonl = ["H"]
-        originl = ["Si", "H", "O", "S"]
-        with self.assertRaises():
-            new_atype = remap_atype(atype, originl, commonl)
-
+        originl = ["Si", "H", "O", "S"]
+        with self.assertRaises(AssertionError):
+            new_atype = DPZBLLinearAtomicModel.remap_atype(atype, originl, commonl)
+
     def test_valid(self):
-        atype = torch.randint(0, 3, (4, 20))
-        nl = ["H", "O", "S"]
-        ol = ["Si", "H", "O", "S"]
-        new_atype = remap_atype(atype, originl, commonl)
-
+        atype = torch.randint(0, 3, (4, 20), device=env.DEVICE)
+        commonl = ["H", "O", "S"]
+        originl = ["Si", "H", "O", "S"]
+        new_atype = DPZBLLinearAtomicModel.remap_atype(atype, originl, commonl)
         def trans(atype, map):
             idx = atype.flatten().tolist()
            res = []
             for i in idx:
                 res.append(map[i])
             return res
-
-        assert trans(atype, nl) == trans(new_atype, ol)
+
+        assert trans(atype, commonl) == trans(new_atype, originl)
 
 
 if __name__ == "__main__":
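For readers who want to exercise the remapping change outside the test suite, here is a minimal standalone sketch. The helper name remap_atype_sketch and the plain-list type maps are illustrative assumptions, not part of the patch or of the deepmd API; the body simply mirrors the new tensor-based remap_atype above: build a lookup tensor once from ori_map/new_map, then remap the whole atype tensor with a single indexing operation instead of a per-element dict apply_.

    # Standalone sketch (hypothetical helper, not part of the patch).
    from typing import List

    import torch


    def remap_atype_sketch(
        atype: torch.Tensor, ori_map: List[str], new_map: List[str]
    ) -> torch.Tensor:
        """Translate type indices given w.r.t. new_map into indices w.r.t. ori_map."""
        # Guard against indices that new_map cannot describe.
        assert torch.max(atype) < len(new_map), "The input `atype` cannot be handled by the type_map."
        type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
        # Lookup tensor: entry i is the ori_map index of element new_map[i],
        # so one fancy-indexing call replaces a per-element dict lookup.
        mapping = torch.tensor([type_2_idx[t] for t in new_map], device=atype.device)
        return mapping[atype.long()]


    if __name__ == "__main__":
        atype = torch.randint(0, 3, (4, 20))
        remapped = remap_atype_sketch(atype, ["Si", "H", "O", "S"], ["H", "O", "S"])
        # Each index named under new_map must name the same element under ori_map.
        assert [["H", "O", "S"][i] for i in atype.flatten().tolist()] == [
            ["Si", "H", "O", "S"][i] for i in remapped.flatten().tolist()
        ]

A precomputed lookup tensor stays on the same device as atype, whereas Tensor.apply_ is an in-place, CPU-only operation, which is likely why the patch moves away from the dict-based version.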