Add data_requirement for dataloader
iProzd committed Feb 28, 2024
1 parent 7f573ab · commit f9265d5
Showing 25 changed files with 327 additions and 158 deletions.
4 changes: 1 addition & 3 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -84,9 +84,7 @@ def mixed_types(self) -> bool:
         """
         pass
 
-    def compute_input_stats(
-        self, merged: List[dict], path: Optional[DPPath] = None
-    ):
+    def compute_input_stats(self, merged: callable, path: Optional[DPPath] = None):
         """Update mean and stddev for descriptor elements."""
         raise NotImplementedError
 
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/base_model.py
@@ -92,6 +92,10 @@ def is_aparam_nall(self) -> bool:
     def model_output_type(self) -> str:
         """Get the output type for the model."""
 
+    @abstractmethod
+    def data_requirement(self) -> dict:
+        """Get the data requirement for the model."""
+
     @abstractmethod
     def serialize(self) -> dict:
         """Serialize the model.
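The hunk above is the commit's namesake: every model must now declare which labels the dataloader should supply. A minimal sketch of what an implementation might return; the field names (ndof, atomic, must, high_prec) mirror deepmd-kit's data-requirement conventions elsewhere, but the schema here is an assumption, not taken from this commit:

    # Hypothetical sketch only; the exact schema is an assumption.
    def data_requirement(self) -> dict:
        return {
            # label name -> properties the dataloader needs to know
            "energy": {"ndof": 1, "atomic": False, "must": False, "high_prec": True},
            "force": {"ndof": 3, "atomic": True, "must": False, "high_prec": False},
        }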
41 changes: 1 addition & 40 deletions deepmd/pt/entrypoints/main.py
@@ -47,9 +47,6 @@
 from deepmd.pt.utils.multi_task import (
     preprocess_shared_params,
 )
-from deepmd.pt.utils.stat import (
-    make_stat_input,
-)
 from deepmd.utils.path import (
     DPPath,
 )
@@ -83,7 +80,6 @@ def get_trainer(
         multi_task=multi_task,
         model_branch=model_branch,
     )
-    config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
     shared_links = None
     if multi_task:
         config["model"], shared_links = preprocess_shared_params(config["model"])
@@ -98,24 +94,6 @@ def prepare_trainer_input_single(
         validation_dataset_params = data_dict_single["validation_data"]
         training_systems = training_dataset_params["systems"]
         validation_systems = validation_dataset_params["systems"]
-
-        # noise params
-        noise_settings = None
-        if loss_dict_single.get("type", "ener") == "denoise":
-            noise_settings = {
-                "noise_type": loss_dict_single.pop("noise_type", "uniform"),
-                "noise": loss_dict_single.pop("noise", 1.0),
-                "noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
-                "mask_num": loss_dict_single.pop("mask_num", 8),
-                "mask_prob": loss_dict_single.pop("mask_prob", 0.15),
-                "same_mask": loss_dict_single.pop("same_mask", False),
-                "mask_coord": loss_dict_single.pop("mask_coord", False),
-                "mask_type": loss_dict_single.pop("mask_type", False),
-                "max_fail_num": loss_dict_single.pop("max_fail_num", 10),
-                "mask_type_idx": len(model_params_single["type_map"]) - 1,
-            }
-        # noise_settings = None
-
         # stat files
         stat_file_path_single = data_dict_single.get("stat_file", None)
         if stat_file_path_single is not None:
@@ -140,48 +118,32 @@ def prepare_trainer_input_single(
                 training_dataset_params["batch_size"],
                 model_params_single,
             )
-            sampled_single = None
         else:
             train_data_single = DpLoaderSet(
                 training_systems,
                 training_dataset_params["batch_size"],
                 model_params_single,
             )
-            data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
-            sampled_single = make_stat_input(
-                train_data_single.systems,
-                train_data_single.dataloaders,
-                data_stat_nbatch,
-            )
-            if noise_settings is not None:
-                train_data_single = DpLoaderSet(
-                    training_systems,
-                    training_dataset_params["batch_size"],
-                    model_params_single,
-                )
         return (
             train_data_single,
             validation_data_single,
-            sampled_single,
             stat_file_path_single,
         )
 
     if not multi_task:
         (
             train_data,
             validation_data,
-            sampled,
             stat_file_path,
         ) = prepare_trainer_input_single(
             config["model"], config["training"], config["loss"]
         )
     else:
-        train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
+        train_data, validation_data, stat_file_path = {}, {}, {}
         for model_key in config["model"]["model_dict"]:
             (
                 train_data[model_key],
                 validation_data[model_key],
-                sampled[model_key],
                 stat_file_path[model_key],
             ) = prepare_trainer_input_single(
                 config["model"]["model_dict"][model_key],
@@ -193,7 +155,6 @@ def prepare_trainer_input_single(
     trainer = training.Trainer(
         config,
         train_data,
-        sampled=sampled,
         stat_file_path=stat_file_path,
         validation_data=validation_data,
         init_model=init_model,
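With the eager make_stat_input call and all the sampled plumbing gone from the entrypoint, statistics sampling now has to happen lazily further down the stack. A sketch of the kind of closure the trainer could build instead, assuming make_stat_input keeps the (systems, dataloaders, nbatches) signature visible in the deleted code:

    from deepmd.pt.utils.stat import make_stat_input

    def build_lazy_sampler(train_data, nbatches=10):
        # Sketch only: data is loaded when statistics are actually
        # computed, not unconditionally at trainer construction time.
        def sample():
            return make_stat_input(
                train_data.systems, train_data.dataloaders, nbatches
            )
        return sample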
17 changes: 5 additions & 12 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -18,9 +18,6 @@
 from deepmd.pt.model.task.base_fitting import (
     BaseFitting,
 )
-from deepmd.pt.utils.utils import (
-    dict_to_device,
-)
 from deepmd.utils.path import (
     DPPath,
 )
@@ -170,7 +167,7 @@ def forward_atomic(
 
     def compute_or_load_stat(
         self,
-        sampled,
+        sampled_func,
         stat_file_path: Optional[DPPath] = None,
     ):
         """
@@ -183,22 +180,18 @@ def compute_or_load_stat(
         Parameters
         ----------
-        sampled
-            The sampled data frames from different data systems.
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
         stat_file_path
             The dictionary of paths to the statistics files.
         """
         if stat_file_path is not None and self.type_map is not None:
             # descriptors and fitting net with different type_map
             # should not share the same parameters
             stat_file_path /= " ".join(self.type_map)
-        for data_sys in sampled:
-            dict_to_device(data_sys)
-        if sampled is None:
-            sampled = []
-        self.descriptor.compute_input_stats(sampled, stat_file_path)
+        self.descriptor.compute_input_stats(sampled_func, stat_file_path)
         if self.fitting_net is not None:
-            self.fitting_net.compute_output_stats(sampled, stat_file_path)
+            self.fitting_net.compute_output_stats(sampled_func, stat_file_path)
 
     @torch.jit.export
     def get_dim_fparam(self) -> int:
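From the caller's perspective, the new contract looks roughly like this; a hedged usage sketch with illustrative variable names:

    # The callable is only invoked if statistics are not already
    # cached at stat_file_path.
    model.compute_or_load_stat(
        sampled_func=lambda: make_stat_input(
            train_data.systems, train_data.dataloaders, 10
        ),
        stat_file_path=stat_file_path,
    )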
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
@@ -5,9 +5,11 @@
     abstractmethod,
 )
 from typing import (
+    Callable,
     Dict,
     List,
     Optional,
+    Union,
 )
 
 import torch
@@ -86,7 +88,9 @@ def get_dim_emb(self) -> int:
         """Returns the embedding dimension."""
         pass
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for DescriptorBlock elements."""
         raise NotImplementedError
 
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/dpa1.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )
 
 import torch
@@ -128,7 +130,9 @@ def dim_out(self):
     def dim_emb(self):
         return self.get_dim_emb()
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         return self.se_atten.compute_input_stats(merged, path)
 
     def serialize(self) -> dict:
15 changes: 6 additions & 9 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )
 
 import torch
@@ -295,16 +297,11 @@ def dim_emb(self):
         """Returns the embedding dimension g2."""
         return self.get_dim_emb()
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         for ii, descrpt in enumerate([self.repinit, self.repformers]):
-            merged_tmp = [
-                {
-                    key: item[key] if not isinstance(item[key], list) else item[key][ii]
-                    for key in item
-                }
-                for item in merged
-            ]
-            descrpt.compute_input_stats(merged_tmp, path)
+            descrpt.compute_input_stats(merged, path)
 
     def serialize(self) -> dict:
         """Serialize the obj to dict."""
16 changes: 7 additions & 9 deletions deepmd/pt/model/descriptor/hybrid.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     List,
     Optional,
+    Union,
 )
 
 import torch
@@ -157,17 +159,13 @@ def share_params(self, base_class, shared_level, resume=False):
         else:
             raise NotImplementedError
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         for ii, descrpt in enumerate(self.descriptor_list):
-            merged_tmp = [
-                {
-                    key: item[key] if not isinstance(item[key], list) else item[key][ii]
-                    for key in item
-                }
-                for item in merged
-            ]
-            descrpt.compute_input_stats(merged_tmp, path)
+            # need support for hybrid descriptors
+            descrpt.compute_input_stats(merged, path)
 
     def forward(
         self,
16 changes: 14 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     Dict,
     List,
     Optional,
+    Union,
 )
 
 import torch
@@ -278,12 +280,22 @@ def forward(
 
         return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         env_mat_stat = EnvMatStatSe(self)
         if path is not None:
             path = path / env_mat_stat.get_hash()
-        env_mat_stat.load_or_compute_stats(merged, path)
+        if path is None or not path.is_dir():
+            if callable(merged):
+                # only get data for once
+                sampled = merged()
+            else:
+                sampled = merged
+        else:
+            sampled = []
+        env_mat_stat.load_or_compute_stats(sampled, path)
         self.stats = env_mat_stat.stats
         mean, stddev = env_mat_stat()
         if not self.set_davg_zero:
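The same resolve-then-load block recurs in se_a.py and se_atten.py below. Distilled into a helper, it is at-most-once lazy evaluation; a sketch, not code from this commit:

    def resolve_samples(merged, stats_cached):
        # If stats are already cached on disk, skip data loading entirely;
        # otherwise call the lazy sampler exactly once, still accepting
        # a pre-sampled list for backward compatibility.
        if stats_cached:
            return []
        return merged() if callable(merged) else merged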
16 changes: 14 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
@@ -1,11 +1,13 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import itertools
 from typing import (
+    Callable,
     ClassVar,
     Dict,
     List,
     Optional,
     Tuple,
+    Union,
 )
 
 import numpy as np
@@ -387,12 +389,22 @@ def __getitem__(self, key):
         else:
             raise KeyError(key)
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         env_mat_stat = EnvMatStatSe(self)
         if path is not None:
             path = path / env_mat_stat.get_hash()
-        env_mat_stat.load_or_compute_stats(merged, path)
+        if path is None or not path.is_dir():
+            if callable(merged):
+                # only get data for once
+                sampled = merged()
+            else:
+                sampled = merged
+        else:
+            sampled = []
+        env_mat_stat.load_or_compute_stats(sampled, path)
         self.stats = env_mat_stat.stats
         mean, stddev = env_mat_stat()
         if not self.set_davg_zero:
16 changes: 14 additions & 2 deletions deepmd/pt/model/descriptor/se_atten.py
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
+    Callable,
     Dict,
     List,
     Optional,
+    Union,
 )
 
 import numpy as np
@@ -200,12 +202,22 @@ def dim_emb(self):
         """Returns the output dimension of embedding."""
         return self.get_dim_emb()
 
-    def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
+    def compute_input_stats(
+        self, merged: Union[Callable, List], path: Optional[DPPath] = None
+    ):
         """Update mean and stddev for descriptor elements."""
         env_mat_stat = EnvMatStatSe(self)
         if path is not None:
             path = path / env_mat_stat.get_hash()
-        env_mat_stat.load_or_compute_stats(merged, path)
+        if path is None or not path.is_dir():
+            if callable(merged):
+                # only get data for once
+                sampled = merged()
+            else:
+                sampled = merged
+        else:
+            sampled = []
+        env_mat_stat.load_or_compute_stats(sampled, path)
         self.stats = env_mat_stat.stats
         mean, stddev = env_mat_stat()
         if not self.set_davg_zero:
(The remaining 14 changed files are not shown here.)
