diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index ce40909352..26467124b8 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -21,7 +21,7 @@ env, ) from deepmd.pt.utils.nlist import ( - process_input, + extend_input_and_build_neighbor_list, ) from deepmd.pt.utils.utils import ( get_activation_fn, @@ -284,7 +284,12 @@ def compute_input_stats(self, merged): system["box"], system["natoms"], ) - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, self.get_rcut(), diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index ddd2b7ca73..700bf6d59b 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -37,7 +37,7 @@ TypeFilter, ) from deepmd.pt.utils.nlist import ( - process_input, + extend_input_and_build_neighbor_list, ) @@ -393,7 +393,12 @@ def compute_input_stats(self, merged): system["box"], system["natoms"], ) - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, self.get_rcut(), diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 875d235e80..d4dc0cd054 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -22,7 +22,7 @@ env, ) from deepmd.pt.utils.nlist import ( - process_input, + extend_input_and_build_neighbor_list, ) @@ -200,7 +200,12 @@ def compute_input_stats(self, merged): system["box"], system["natoms"], ) - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, self.get_rcut(), diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index c5e4b3705b..e45c993286 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -14,8 +14,8 @@ fit_output_to_model_output, ) from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, nlist_distinguish_types, - process_input, ) @@ -89,7 +89,12 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, self.get_rcut(), diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index 9d43ca8191..963c9bc9b6 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -17,7 +17,7 @@ ) -def process_input( +def extend_input_and_build_neighbor_list( coord, atype, rcut: float, diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index aa5b6017bf..a4493b5b51 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -29,7 +29,7 @@ GLOBAL_PT_FLOAT_PRECISION, ) from deepmd.pt.utils.nlist import ( - process_input, + extend_input_and_build_neighbor_list, ) from deepmd.tf.common import ( expand_sys_str, @@ -145,7 +145,12 @@ def test_consistency(self): pt_coord = self.pt_batch["coord"].to(env.DEVICE) atype = self.pt_batch["atype"].to(env.DEVICE) pt_coord.requires_grad_(True) - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( pt_coord, self.pt_batch["atype"].to(env.DEVICE), self.rcut, diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index 57adef0784..07d4d34449 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -19,7 +19,7 @@ env, ) from deepmd.pt.utils.nlist import ( - process_input, + extend_input_and_build_neighbor_list, ) dtype = torch.float64 @@ -251,7 +251,12 @@ def test_descriptor_block(self): ## to save model parameters # torch.save(des.state_dict(), 'model_weights.pth') # torch.save(type_embedding.state_dict(), 'model_weights.pth') - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, des.get_rcut(), @@ -300,7 +305,12 @@ def test_descriptor(self): coord = self.coord atype = self.atype box = self.cell - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, des.get_rcut(), diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index c80c92f3d2..6b80eb89a2 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -20,8 +20,8 @@ ) from deepmd.pt.utils.nlist import ( build_multiple_neighbor_list, + extend_input_and_build_neighbor_list, get_multiple_nlist_key, - process_input, ) dtype = torch.float64 @@ -142,7 +142,12 @@ def test_descriptor_hyb(self): ## to save model parameters # torch.save(des.state_dict(), 'model_weights.pth') # torch.save(type_embedding.state_dict(), 'model_weights.pth') - extended_coord, extended_atype, mapping, nlist_max = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist_max, + ) = extend_input_and_build_neighbor_list( coord, atype, rcut_max, @@ -200,7 +205,12 @@ def test_descriptor(self): coord = self.coord atype = self.atype box = self.cell - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( coord, atype, des.get_rcut(), diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 8085898986..d970c8a542 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -23,7 +23,7 @@ from deepmd.pt.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, - process_input, + extend_input_and_build_neighbor_list, ) from deepmd.pt.utils.utils import ( to_numpy_array, @@ -434,7 +434,7 @@ def test_self_consistency(self): to_numpy_array(ret0["atom_virial"]), to_numpy_array(ret1["atom_virial"]), ) - coord_ext, atype_ext, mapping, nlist = process_input( + coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( to_torch_tensor(self.coord), to_torch_tensor(self.atype), self.rcut, diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 1f457a4271..2621b5d135 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -32,7 +32,7 @@ GLOBAL_NP_FLOAT_PRECISION, ) from deepmd.pt.utils.nlist import ( - process_input, + extend_input_and_build_neighbor_list, ) from deepmd.tf.common import ( expand_sys_str, @@ -180,7 +180,12 @@ def test_consistency(self): pt_coord = self.torch_batch["coord"].to(env.DEVICE) pt_coord.requires_grad_(True) - extended_coord, extended_atype, mapping, nlist = process_input( + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( pt_coord, self.torch_batch["atype"].to(env.DEVICE), self.rcut,