Skip to content

Commit

Permalink
breaking: pt: remove data preprocess from data stat (#3261)
Browse files Browse the repository at this point in the history
This PR:

- Remove data preprocessing from data stat.

- Clean up the dependency on data preprocessing in the dataset and dataloader.


Note that:

- `DeepmdDataSystem` still depends on PyTorch; this is left for
@CaRoLZhangxy to clean up.

- The denoise part in `DeepmdDataSystem` still needs further cleanup, which
is left for @Chengqian-Zhang.

---------

Signed-off-by: Duo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
iProzd and pre-commit-ci[bot] authored Feb 13, 2024
1 parent 398eb7a commit e41b091
Show file tree
Hide file tree
Showing 20 changed files with 325 additions and 510 deletions.
8 changes: 3 additions & 5 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
Expand All @@ -18,8 +17,6 @@
DescrptBlockSeAtten,
)

log = logging.getLogger(__name__)


@Descriptor.register("dpa1")
@Descriptor.register("se_atten")
Expand Down Expand Up @@ -112,7 +109,7 @@ def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False
return self.se_atten.distinguish_types()

@property
def dim_out(self):
Expand All @@ -128,7 +125,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
Expand All @@ -141,6 +138,7 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["dpa1", "se_atten"]
assert all(x is not None for x in [rcut, rcut_smth, sel])
return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
Expand Down
11 changes: 4 additions & 7 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
Expand Down Expand Up @@ -27,8 +26,6 @@
DescrptBlockSeAtten,
)

log = logging.getLogger(__name__)


@Descriptor.register("dpa2")
class DescrptDPA2(Descriptor):
Expand Down Expand Up @@ -316,7 +313,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
for ii, descrpt in enumerate([self.repinit, self.repformers]):
stat_dict_ii = {
"sumr": sumr[ii],
Expand Down Expand Up @@ -346,8 +343,8 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["dpa2"]
assert True not in [
x is None
assert all(
x is not None
for x in [
repinit_rcut,
repinit_rcut_smth,
Expand All @@ -356,7 +353,7 @@ def get_stat_name(
repformer_rcut_smth,
repformer_nsel,
]
]
)
return (
f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}"
f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz"
Expand Down
10 changes: 9 additions & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def get_dim_in(self) -> int:
def get_dim_emb(self):
return self.dim_emb

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return any(
descriptor.distinguish_types() for descriptor in self.descriptor_list
)

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -170,7 +178,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
for ii, descrpt in enumerate(self.descriptor_list):
stat_dict_ii = {
"sumr": sumr[ii],
Expand Down
46 changes: 27 additions & 19 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
env,
)
from deepmd.pt.utils.nlist import (
build_neighbor_list,
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
Expand Down Expand Up @@ -178,6 +178,12 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension g2."""
return self.g2_dim

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -272,44 +278,46 @@ def compute_input_stats(self, merged):
suma2 = []
mixed_type = "real_natoms_vec" in merged[0]
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
index = system["mapping"]
extended_atype = torch.gather(system["atype"], dim=1, index=index)
nloc = system["atype"].shape[-1]
#######################################################
# dirty hack here! the interface of dataload should be
# redesigned to support descriptors like dpa2
#######################################################
nlist = build_neighbor_list(
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
nloc,
self.rcut,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=False,
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
nlist,
system["atype"],
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
if not mixed_type:
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), ndescrpt, natoms
)
else:
real_natoms_vec = system["real_natoms_vec"]
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(),
ndescrpt,
system["real_natoms_vec"],
real_natoms_vec,
mixed_type=mixed_type,
real_atype=system["atype"].detach().cpu().numpy(),
real_atype=atype.detach().cpu().numpy(),
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
46 changes: 34 additions & 12 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
ClassVar,
List,
Expand Down Expand Up @@ -37,8 +36,9 @@
from deepmd.pt.model.network.network import (
TypeFilter,
)

log = logging.getLogger(__name__)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)


@Descriptor.register("se_e2_a")
Expand Down Expand Up @@ -100,7 +100,7 @@ def distinguish_types(self):
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True
return self.sea.distinguish_types()

@property
def dim_out(self):
Expand All @@ -114,7 +114,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
Expand All @@ -127,7 +127,7 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["se_e2_a"]
assert True not in [x is None for x in [rcut, rcut_smth, sel]]
assert all(x is not None for x in [rcut, rcut_smth, sel])
return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
Expand Down Expand Up @@ -347,6 +347,12 @@ def get_dim_in(self) -> int:
"""Returns the input dimension."""
return self.dim_in

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -381,20 +387,36 @@ def compute_input_stats(self, merged):
sumr2 = []
suma2 = []
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
system["nlist"],
system["atype"],
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), self.ndescrpt, natoms
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
42 changes: 34 additions & 8 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)


@DescriptorBlock.register("se_atten")
Expand Down Expand Up @@ -161,6 +164,12 @@ def get_dim_emb(self) -> int:
"""Returns the output dimension of embedding."""
return self.filter_neuron[-1]

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand All @@ -185,29 +194,46 @@ def compute_input_stats(self, merged):
suma2 = []
mixed_type = "real_natoms_vec" in merged[0]
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
system["nlist"],
system["atype"],
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
if not mixed_type:
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), self.ndescrpt, natoms
)
else:
real_natoms_vec = system["real_natoms_vec"]
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(),
self.ndescrpt,
system["real_natoms_vec"],
real_natoms_vec,
mixed_type=mixed_type,
real_atype=system["atype"].detach().cpu().numpy(),
real_atype=atype.detach().cpu().numpy(),
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
Loading

0 comments on commit e41b091

Please sign in to comment.