Skip to content

Commit

Permalink
change name for process_input
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 13, 2024
1 parent bf0e58f commit 0b8f0f5
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 21 deletions.
9 changes: 7 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
env,
)
from deepmd.pt.utils.nlist import (
process_input,
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
Expand Down Expand Up @@ -284,7 +284,12 @@ def compute_input_stats(self, merged):
system["box"],
system["natoms"],
)
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TypeFilter,
)
from deepmd.pt.utils.nlist import (
process_input,
extend_input_and_build_neighbor_list,
)


Expand Down Expand Up @@ -393,7 +393,12 @@ def compute_input_stats(self, merged):
system["box"],
system["natoms"],
)
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
env,
)
from deepmd.pt.utils.nlist import (
process_input,
extend_input_and_build_neighbor_list,
)


Expand Down Expand Up @@ -200,7 +200,12 @@ def compute_input_stats(self, merged):
system["box"],
system["natoms"],
)
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
fit_output_to_model_output,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
nlist_distinguish_types,
process_input,
)


Expand Down Expand Up @@ -89,7 +89,12 @@ def forward_common(
The keys are defined by the `ModelOutputDef`.
"""
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)


def process_input(
def extend_input_and_build_neighbor_list(
coord,
atype,
rcut: float,
Expand Down
9 changes: 7 additions & 2 deletions source/tests/pt/model/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.pt.utils.nlist import (
process_input,
extend_input_and_build_neighbor_list,
)
from deepmd.tf.common import (
expand_sys_str,
Expand Down Expand Up @@ -145,7 +145,12 @@ def test_consistency(self):
pt_coord = self.pt_batch["coord"].to(env.DEVICE)
atype = self.pt_batch["atype"].to(env.DEVICE)
pt_coord.requires_grad_(True)
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
pt_coord,
self.pt_batch["atype"].to(env.DEVICE),
self.rcut,
Expand Down
16 changes: 13 additions & 3 deletions source/tests/pt/model/test_descriptor_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
env,
)
from deepmd.pt.utils.nlist import (
process_input,
extend_input_and_build_neighbor_list,
)

dtype = torch.float64
Expand Down Expand Up @@ -251,7 +251,12 @@ def test_descriptor_block(self):
## to save model parameters
# torch.save(des.state_dict(), 'model_weights.pth')
# torch.save(type_embedding.state_dict(), 'model_weights.pth')
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
des.get_rcut(),
Expand Down Expand Up @@ -300,7 +305,12 @@ def test_descriptor(self):
coord = self.coord
atype = self.atype
box = self.cell
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
des.get_rcut(),
Expand Down
16 changes: 13 additions & 3 deletions source/tests/pt/model/test_descriptor_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
)
from deepmd.pt.utils.nlist import (
build_multiple_neighbor_list,
extend_input_and_build_neighbor_list,
get_multiple_nlist_key,
process_input,
)

dtype = torch.float64
Expand Down Expand Up @@ -142,7 +142,12 @@ def test_descriptor_hyb(self):
## to save model parameters
# torch.save(des.state_dict(), 'model_weights.pth')
# torch.save(type_embedding.state_dict(), 'model_weights.pth')
extended_coord, extended_atype, mapping, nlist_max = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist_max,
) = extend_input_and_build_neighbor_list(
coord,
atype,
rcut_max,
Expand Down Expand Up @@ -200,7 +205,12 @@ def test_descriptor(self):
coord = self.coord
atype = self.atype
box = self.cell
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
des.get_rcut(),
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from deepmd.pt.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
process_input,
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
Expand Down Expand Up @@ -434,7 +434,7 @@ def test_self_consistency(self):
to_numpy_array(ret0["atom_virial"]),
to_numpy_array(ret1["atom_virial"]),
)
coord_ext, atype_ext, mapping, nlist = process_input(
coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
to_torch_tensor(self.coord),
to_torch_tensor(self.atype),
self.rcut,
Expand Down
9 changes: 7 additions & 2 deletions source/tests/pt/model/test_embedding_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.pt.utils.nlist import (
process_input,
extend_input_and_build_neighbor_list,
)
from deepmd.tf.common import (
expand_sys_str,
Expand Down Expand Up @@ -180,7 +180,12 @@ def test_consistency(self):

pt_coord = self.torch_batch["coord"].to(env.DEVICE)
pt_coord.requires_grad_(True)
extended_coord, extended_atype, mapping, nlist = process_input(
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
pt_coord,
self.torch_batch["atype"].to(env.DEVICE),
self.rcut,
Expand Down

0 comments on commit 0b8f0f5

Please sign in to comment.