Skip to content

Commit

Permalink
(dpa3 alpha) add skip stat (deepmodeling#4501)
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Dec 24, 2024
2 parents 4e65d8b + 76f28e9 commit 8ac8180
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
update_style: str = "res_residual",
update_residual: float = 0.1,
update_residual_init: str = "const",
skip_stat: bool = False,
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
self.update_style = update_style
self.update_residual = update_residual
self.update_residual_init = update_residual_init
self.skip_stat = skip_stat

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def share_params(self, base_class, shared_level, resume=False) -> None:
if shared_level == 0:
# link buffers
if hasattr(self, "mean"):
if not resume:
if not resume and not (getattr(self, "skip_stat", False)):
# in case of change params during resume
base_env = EnvMatStatSe(base_class)
base_env.stats = base_class.stats
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def init_subclass_params(sub_data, sub_class):
update_style=self.repflow_args.update_style,
update_residual=self.repflow_args.update_residual,
update_residual_init=self.repflow_args.update_residual_init,
skip_stat=self.repflow_args.skip_stat,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
exclude_types: list[tuple[int, int]] = [],
env_protection: float = 0.0,
precision: str = "float64",
skip_stat: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
r"""
Expand Down Expand Up @@ -182,6 +183,7 @@ def __init__(
self.a_compress_rate = a_compress_rate
self.axis_neuron = axis_neuron
self.set_davg_zero = set_davg_zero
self.skip_stat = skip_stat

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -238,6 +240,8 @@ def __init__(
wanted_shape = (self.ntypes, self.nnei, 4)
mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
if self.skip_stat:
stddev = stddev * 0.3
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)
self.stats = None
Expand Down Expand Up @@ -528,6 +532,8 @@ def compute_input_stats(
The path to the stat file.
"""
if self.skip_stat and self.set_davg_zero:
return
env_mat_stat = EnvMatStatSe(self)
if path is not None:
path = path / env_mat_stat.get_hash()
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,12 @@ def dpa3_repflow_args():
default=4,
doc=doc_axis_neuron,
),
Argument(
"skip_stat",
bool,
optional=True,
default=False,
),
Argument(
"update_angle",
bool,
Expand Down

0 comments on commit 8ac8180

Please sign in to comment.