From c55ceacf89c811f14ba6f3f164fd929cfa86425a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 07:46:02 -0500 Subject: [PATCH] pt: refactor data stat (#3285) Signed-off-by: Jinzhe Zeng --- deepmd/common.py | 14 ++ .../descriptor/make_base_descriptor.py | 12 +- deepmd/dpmodel/descriptor/se_e2_a.py | 10 +- deepmd/pt/entrypoints/main.py | 36 ++- .../pt/model/atomic_model/dp_atomic_model.py | 41 ++-- deepmd/pt/model/descriptor/__init__.py | 2 - deepmd/pt/model/descriptor/descriptor.py | 166 ++------------ deepmd/pt/model/descriptor/dpa1.py | 26 +-- deepmd/pt/model/descriptor/dpa2.py | 69 +----- deepmd/pt/model/descriptor/gaussian_lcc.py | 13 +- deepmd/pt/model/descriptor/hybrid.py | 34 +-- deepmd/pt/model/descriptor/repformers.py | 124 ++-------- deepmd/pt/model/descriptor/se_a.py | 136 +++-------- deepmd/pt/model/descriptor/se_atten.py | 119 ++-------- deepmd/pt/model/model/model.py | 16 +- deepmd/pt/model/task/ener.py | 37 +-- deepmd/pt/model/task/fitting.py | 94 -------- deepmd/pt/train/training.py | 3 +- deepmd/pt/utils/env_mat_stat.py | 206 +++++++++++++++++ deepmd/pt/utils/stat.py | 30 --- deepmd/utils/env_mat_stat.py | 213 ++++++++++++++++++ deepmd/utils/path.py | 150 ++++++++++-- source/tests/pt/test_stat.py | 153 ++++++++++--- 23 files changed, 879 insertions(+), 825 deletions(-) create mode 100644 deepmd/pt/utils/env_mat_stat.py create mode 100644 deepmd/utils/env_mat_stat.py diff --git a/deepmd/common.py b/deepmd/common.py index 05d02234b4..691cc262df 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -5,6 +5,9 @@ import platform import shutil import warnings +from hashlib import ( + sha1, +) from pathlib import ( Path, ) @@ -299,3 +302,14 @@ def symlink_prefix_files(old_prefix: str, new_prefix: str): os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff) else: shutil.copyfile(ori_ff, new_ff) + + +def get_hash(obj) -> str: + """Get hash of object. 
+ + Parameters + ---------- + obj + object to hash + """ + return sha1(json.dumps(obj).encode("utf-8")).hexdigest() diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 29d3ad6d92..b7a8bfebcf 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -9,6 +9,10 @@ Optional, ) +from deepmd.utils.path import ( + DPPath, +) + def make_base_descriptor( t_tensor, @@ -69,14 +73,12 @@ def distinguish_types(self) -> bool: """ pass - def compute_input_stats(self, merged): + def compute_input_stats( + self, merged: List[dict], path: Optional[DPPath] = None + ): """Update mean and stddev for descriptor elements.""" raise NotImplementedError - def init_desc_stat(self, **kwargs): - """Initialize the model bias by the statistics.""" - raise NotImplementedError - @abstractmethod def fwd( self, diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 4e26afa729..3b98f9dc67 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import numpy as np +from deepmd.utils.path import ( + DPPath, +) + try: from deepmd._version import version as __version__ except ImportError: @@ -229,14 +233,10 @@ def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" raise NotImplementedError - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): - """Initialize the model bias by the statistics.""" - raise NotImplementedError - def cal_g( self, ss, diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 29ef8761ff..b260000d87 100644 --- a/deepmd/pt/entrypoints/main.py +++ 
b/deepmd/pt/entrypoints/main.py @@ -12,6 +12,7 @@ Union, ) +import h5py import torch import torch.distributed as dist import torch.version @@ -46,12 +47,6 @@ from deepmd.pt.infer import ( inference, ) -from deepmd.pt.model.descriptor import ( - Descriptor, -) -from deepmd.pt.model.task import ( - Fitting, -) from deepmd.pt.train import ( training, ) @@ -69,7 +64,9 @@ ) from deepmd.pt.utils.stat import ( make_stat_input, - process_stat_path, +) +from deepmd.utils.path import ( + DPPath, ) from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter @@ -134,19 +131,16 @@ def prepare_trainer_input_single( # noise_settings = None # stat files - hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid" - if not hybrid_descrpt: - stat_file_path_single, has_stat_file_path = process_stat_path( - data_dict_single.get("stat_file", None), - data_dict_single.get("stat_file_dir", f"stat_files{suffix}"), - model_params_single, - Descriptor, - Fitting, - ) - else: ### TODO hybrid descriptor not implemented - raise NotImplementedError( - "data stat for hybrid descriptor is not implemented!" 
- ) + stat_file_path_single = data_dict_single.get("stat_file", None) + if stat_file_path_single is not None: + if Path(stat_file_path_single).is_dir(): + raise ValueError( + f"stat_file should be a file, not a directory: {stat_file_path_single}" + ) + if not Path(stat_file_path_single).is_file(): + with h5py.File(stat_file_path_single, "w") as f: + pass + stat_file_path_single = DPPath(stat_file_path_single, "a") # validation and training data validation_data_single = DpLoaderSet( @@ -156,7 +150,7 @@ def prepare_trainer_input_single( type_split=type_split, noise_settings=noise_settings, ) - if ckpt or finetune_model or has_stat_file_path: + if ckpt or finetune_model: train_data_single = DpLoaderSet( training_systems, training_dataset_params["batch_size"], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 17b70e4701..aafd2831b3 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy import logging -import os import sys from typing import ( Dict, List, Optional, - Union, ) import torch @@ -24,6 +22,9 @@ from deepmd.pt.utils.utils import ( dict_to_device, ) +from deepmd.utils.path import ( + DPPath, +) from .base_atomic_model import ( BaseAtomicModel, @@ -160,9 +161,8 @@ def forward_atomic( def compute_or_load_stat( self, - type_map: Optional[List[str]] = None, - sampled=None, - stat_file_path_dict: Optional[Dict[str, Union[str, List[str]]]] = None, + sampled, + stat_file_path: Optional[DPPath] = None, ): """ Compute or load the statistics parameters of the model, @@ -174,31 +174,22 @@ def compute_or_load_stat( Parameters ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. sampled The sampled data frames from different data systems. 
- stat_file_path_dict + stat_file_path The dictionary of paths to the statistics files. """ - if sampled is not None: # move data to device - for data_sys in sampled: - dict_to_device(data_sys) - if stat_file_path_dict is not None: - if not isinstance(stat_file_path_dict["descriptor"], list): - stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"]) - else: - stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"][0]) - if not os.path.exists(stat_file_dir): - os.mkdir(stat_file_dir) - self.descriptor.compute_or_load_stat( - type_map, sampled, stat_file_path_dict["descriptor"] - ) + if stat_file_path is not None and self.type_map is not None: + # descriptors and fitting net with different type_map + # should not share the same parameters + stat_file_path /= " ".join(self.type_map) + for data_sys in sampled: + dict_to_device(data_sys) + if sampled is None: + sampled = [] + self.descriptor.compute_input_stats(sampled, stat_file_path) if self.fitting_net is not None: - self.fitting_net.compute_or_load_stat( - type_map, sampled, stat_file_path_dict["fitting_net"] - ) + self.fitting_net.compute_output_stats(sampled, stat_file_path) @torch.jit.export def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 4252e34905..1c2e943369 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -2,7 +2,6 @@ from .descriptor import ( Descriptor, DescriptorBlock, - compute_std, make_default_type_embedding, ) from .dpa1 import ( @@ -32,7 +31,6 @@ __all__ = [ "Descriptor", "DescriptorBlock", - "compute_std", "make_default_type_embedding", "DescrptBlockSeA", "DescrptBlockSeAtten", diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index bd6839834e..16659e444d 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -6,20 +6,31 @@ ) from typing import ( 
Callable, + Dict, List, Optional, - Union, ) -import numpy as np import torch from deepmd.pt.model.network.network import ( TypeEmbedNet, ) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, +) from deepmd.pt.utils.plugin import ( Plugin, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) from .base_descriptor import ( BaseDescriptor, @@ -59,19 +70,6 @@ class SomeDescript(Descriptor): """ return Descriptor.__plugins.register(key) - @classmethod - def get_stat_name(cls, ntypes, type_name, **kwargs): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. - """ - if cls is not Descriptor: - raise NotImplementedError("get_stat_name is not implemented!") - descrpt_type = type_name - return Descriptor.__plugins.plugins[descrpt_type].get_stat_name( - ntypes, type_name, **kwargs - ) - @classmethod def get_data_process_key(cls, config): """ @@ -92,98 +90,6 @@ def data_stat_key(self): """ raise NotImplementedError("data_stat_key is not implemented!") - def compute_or_load_stat( - self, - type_map: List[str], - sampled=None, - stat_file_path: Optional[Union[str, List[str]]] = None, - ): - """ - Compute or load the statistics parameters of the descriptor. - Calculate and save the mean and standard deviation of the descriptor to `stat_file_path` - if `sampled` is not None, otherwise load them from `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - sampled - The sampled data frames from different data systems. - stat_file_path - The path to the statistics files. 
- """ - # TODO support hybrid descriptor - descrpt_stat_key = self.data_stat_key - if sampled is not None: # compute the statistics results - tmp_dict = self.compute_input_stats(sampled) - result_dict = {key: tmp_dict[key] for key in descrpt_stat_key} - result_dict["type_map"] = type_map - if stat_file_path is not None: - self.save_stats(result_dict, stat_file_path) - else: # load the statistics results - assert stat_file_path is not None, "No stat file to load!" - result_dict = self.load_stats(type_map, stat_file_path) - self.init_desc_stat(**result_dict) - - def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): - """ - Save the statistics results to `stat_file_path`. - - Parameters - ---------- - result_dict - The dictionary of statistics results. - stat_file_path - The path to the statistics file(s). - """ - if not isinstance(stat_file_path, list): - log.info(f"Saving stat file to {stat_file_path}") - np.savez_compressed(stat_file_path, **result_dict) - else: # TODO hybrid descriptor not implemented - raise NotImplementedError( - "save_stats for hybrid descriptor is not implemented!" - ) - - def load_stats(self, type_map, stat_file_path: Union[str, List[str]]): - """ - Load the statistics results to `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - stat_file_path - The path to the statistics file(s). - - Returns - ------- - result_dict - The dictionary of statistics results. - """ - descrpt_stat_key = self.data_stat_key - target_type_map = type_map - if not isinstance(stat_file_path, list): - log.info(f"Loading stat file from {stat_file_path}") - stats = np.load(stat_file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [i for i in target_type_map if i not in stat_type_map] - assert not missing_type, ( - f"These type are not in stat file {stat_file_path}: {missing_type}! " - f"Please change the stat file path!" 
- ) - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats[descrpt_stat_key[0]].size: # not empty - result_dict = {key: stats[key][idx_map] for key in descrpt_stat_key} - else: - result_dict = {key: [] for key in descrpt_stat_key} - else: # TODO hybrid descriptor not implemented - raise NotImplementedError( - "load_stats for hybrid descriptor is not implemented!" - ) - return result_dict - def __new__(cls, *args, **kwargs): if cls is Descriptor: try: @@ -275,12 +181,12 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension.""" pass - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for DescriptorBlock elements.""" raise NotImplementedError - def init_desc_stat(self, **kwargs): - """Initialize mean and stddev by the statistics.""" + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" raise NotImplementedError def share_params(self, base_class, shared_level, resume=False): @@ -291,28 +197,14 @@ def share_params(self, base_class, shared_level, resume=False): # link buffers if hasattr(self, "mean") and not resume: # in case of change params during resume - sumr_base, suma_base, sumn_base, sumr2_base, suma2_base = ( - base_class.sumr, - base_class.suma, - base_class.sumn, - base_class.sumr2, - base_class.suma2, - ) - sumr, suma, sumn, sumr2, suma2 = ( - self.sumr, - self.suma, - self.sumn, - self.sumr2, - self.suma2, - ) - stat_dict = { - "sumr": sumr_base + sumr, - "suma": suma_base + suma, - "sumn": sumn_base + sumn, - "sumr2": sumr2_base + sumr2, - "suma2": suma2_base + suma2, - } - base_class.init_desc_stat(**stat_dict) + base_env = EnvMatStatSeA(base_class) + base_env.stats = base_class.stats + for kk in base_class.get_stats(): + base_env.stats[kk] += self.get_stats()[kk] + mean, stddev = base_env() + if not base_class.set_davg_zero: + base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + 
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) self.mean = base_class.mean self.stddev = base_class.stddev # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model @@ -335,16 +227,6 @@ def forward( pass -def compute_std(sumv2, sumv, sumn, rcut_r): - """Compute standard deviation.""" - if sumn == 0: - return 1.0 / rcut_r - val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn)) - if np.abs(val) < 1e-2: - val = 1e-2 - return val - - def make_default_type_embedding( ntypes, ): diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 76cff174af..6bdb5c2cb3 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -12,6 +12,9 @@ from deepmd.pt.model.network.network import ( TypeEmbedNet, ) +from deepmd.utils.path import ( + DPPath, +) from .se_atten import ( DescrptBlockSeAtten, @@ -119,27 +122,8 @@ def dim_out(self): def dim_emb(self): return self.get_dim_emb() - def compute_input_stats(self, merged): - return self.se_atten.compute_input_stats(merged) - - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2) - - @classmethod - def get_stat_name( - cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs - ): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. 
- """ - descrpt_type = type_name - assert descrpt_type in ["dpa1", "se_atten"] - assert all(x is not None for x in [rcut, rcut_smth, sel]) - return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + return self.se_atten.compute_input_stats(merged, path) @classmethod def get_data_process_key(cls, config): diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 6cefaf6f38..0122dcacb8 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -18,6 +18,9 @@ build_multiple_neighbor_list, get_multiple_nlist_key, ) +from deepmd.utils.path import ( + DPPath, +) from .repformers import ( DescrptBlockRepformers, @@ -286,8 +289,7 @@ def dim_emb(self): """Returns the embedding dimension g2.""" return self.get_dim_emb() - def compute_input_stats(self, merged): - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): for ii, descrpt in enumerate([self.repinit, self.repformers]): merged_tmp = [ { @@ -296,68 +298,7 @@ def compute_input_stats(self, merged): } for item in merged ] - tmp_stat_dict = descrpt.compute_input_stats(merged_tmp) - sumr.append(tmp_stat_dict["sumr"]) - suma.append(tmp_stat_dict["suma"]) - sumn.append(tmp_stat_dict["sumn"]) - sumr2.append(tmp_stat_dict["sumr2"]) - suma2.append(tmp_stat_dict["suma2"]) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - for ii, descrpt in enumerate([self.repinit, self.repformers]): - stat_dict_ii = { - "sumr": sumr[ii], - "suma": suma[ii], - "sumn": sumn[ii], - "sumr2": sumr2[ii], - "suma2": suma2[ii], - } - 
descrpt.init_desc_stat(**stat_dict_ii) - - @classmethod - def get_stat_name( - cls, - ntypes, - type_name, - repinit_rcut=None, - repinit_rcut_smth=None, - repinit_nsel=None, - repformer_rcut=None, - repformer_rcut_smth=None, - repformer_nsel=None, - **kwargs, - ): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. - """ - descrpt_type = type_name - assert descrpt_type in ["dpa2"] - assert all( - x is not None - for x in [ - repinit_rcut, - repinit_rcut_smth, - repinit_nsel, - repformer_rcut, - repformer_rcut_smth, - repformer_nsel, - ] - ) - return ( - f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}" - f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz" - ) + descrpt.compute_input_stats(merged_tmp) @classmethod def get_data_process_key(cls, config): diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py index 0972b90279..72c9f27b2a 100644 --- a/deepmd/pt/model/descriptor/gaussian_lcc.py +++ b/deepmd/pt/model/descriptor/gaussian_lcc.py @@ -1,4 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, +) + import torch import torch.nn as nn @@ -13,6 +18,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.utils.path import ( + DPPath, +) class DescrptGaussianLcc(Descriptor): @@ -154,11 +162,8 @@ def dim_emb(self): """Returns the output dimension of pair representation.""" return self.pair_embed_dim - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - return [], [], [], [], [] - - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): pass def forward( diff --git a/deepmd/pt/model/descriptor/hybrid.py 
b/deepmd/pt/model/descriptor/hybrid.py index c5c08c760d..d6678f2a4b 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -13,6 +13,9 @@ Identity, Linear, ) +from deepmd.utils.path import ( + DPPath, +) @DescriptorBlock.register("hybrid") @@ -150,9 +153,8 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] for ii, descrpt in enumerate(self.descriptor_list): merged_tmp = [ { @@ -161,33 +163,7 @@ def compute_input_stats(self, merged): } for item in merged ] - tmp_stat_dict = descrpt.compute_input_stats(merged_tmp) - sumr.append(tmp_stat_dict["sumr"]) - suma.append(tmp_stat_dict["suma"]) - sumn.append(tmp_stat_dict["sumn"]) - sumr2.append(tmp_stat_dict["sumr2"]) - suma2.append(tmp_stat_dict["suma2"]) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - for ii, descrpt in enumerate(self.descriptor_list): - stat_dict_ii = { - "sumr": sumr[ii], - "suma": suma[ii], - "sumn": sumn[ii], - "sumr2": sumr2[ii], - "suma2": suma2[ii], - } - descrpt.init_desc_stat(**stat_dict_ii) + descrpt.compute_input_stats(merged_tmp, path) def forward( self, diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 26467124b8..76051a52ed 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Dict, List, Optional, ) -import numpy as np import torch from 
deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, - compute_std, ) from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, @@ -20,19 +19,22 @@ from deepmd.pt.utils import ( env, ) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, ) from deepmd.pt.utils.utils import ( get_activation_fn, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) from .repformer_layer import ( RepformerLayer, ) -from .se_atten import ( - analyze_descrpt, -) mydtype = env.GLOBAL_PT_FLOAT_PRECISION mydev = env.DEVICE @@ -149,6 +151,7 @@ def __init__( stddev = torch.ones(sshape, dtype=mydtype, device=mydev) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + self.stats = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -268,99 +271,22 @@ def forward( return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - ndescrpt = self.nnei * 4 - sumr = [] - suma = [] - sumn = [] - sumr2 = [] - suma2 = [] - mixed_type = "real_natoms_vec" in merged[0] - for system in merged: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.get_sel(), - distinguish_types=self.distinguish_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - self.mean, - self.stddev, - self.rcut, - self.rcut_smth, - ) - if not mixed_type: - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), ndescrpt, natoms - ) - else: - real_natoms_vec = system["real_natoms_vec"] - 
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), - ndescrpt, - real_natoms_vec, - mixed_type=mixed_type, - real_atype=atype.detach().cpu().numpy(), - ) - sumr.append(sysr) - suma.append(sysa) - sumn.append(sysn) - sumr2.append(sysr2) - suma2.append(sysa2) - sumr = np.sum(sumr, axis=0) - suma = np.sum(suma, axis=0) - sumn = np.sum(sumn, axis=0) - sumr2 = np.sum(sumr2, axis=0) - suma2 = np.sum(suma2, axis=0) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): - all_davg = [] - all_dstd = [] - for type_i in range(self.ntypes): - davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] - dstdunit = [ - [ - compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - ] - ] - davg = np.tile(davgunit, [self.nnei, 1]) - dstd = np.tile(dstdunit, [self.nnei, 1]) - all_davg.append(davg) - all_dstd.append(dstd) - self.sumr = sumr - self.suma = suma - self.sumn = sumn - self.sumr2 = sumr2 - self.suma2 = suma2 + env_mat_stat = EnvMatStatSeA(self) + if path is not None: + path = path / env_mat_stat.get_hash() + env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() if not self.set_davg_zero: - mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - stddev = np.stack(all_dstd) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." 
+ ) + return self.stats diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index c086fe1cc2..33cc3ee9e2 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( ClassVar, + Dict, List, Optional, Tuple, @@ -12,7 +13,6 @@ from deepmd.pt.model.descriptor import ( Descriptor, DescriptorBlock, - compute_std, prod_env_mat_se_a, ) from deepmd.pt.utils import ( @@ -22,6 +22,15 @@ PRECISION_DICT, RESERVED_PRECISON_DICT, ) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) try: from typing import ( @@ -41,9 +50,6 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, -) @Descriptor.register("se_e2_a") @@ -114,28 +120,9 @@ def dim_out(self): """Returns the output dimension of this descriptor.""" return self.sea.dim_out - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - return self.sea.compute_input_stats(merged) - - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2) - - @classmethod - def get_stat_name( - cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs - ): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. 
- """ - descrpt_type = type_name - assert descrpt_type in ["se_e2_a"] - assert all(x is not None for x in [rcut, rcut_smth, sel]) - return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" + return self.sea.compute_input_stats(merged, path) @classmethod def get_data_process_key(cls, config): @@ -330,6 +317,7 @@ def __init__( resnet_dt=self.resnet_dt, ) self.filter_layers = filter_layers + self.stats = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -391,91 +379,29 @@ def __getitem__(self, key): else: raise KeyError(key) - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - sumr = [] - suma = [] - sumn = [] - sumr2 = [] - suma2 = [] - for system in merged: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.get_sel(), - distinguish_types=self.distinguish_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - self.mean, - self.stddev, - self.rcut, - self.rcut_smth, - ) - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), self.ndescrpt, natoms - ) - sumr.append(sysr) - suma.append(sysa) - sumn.append(sysn) - sumr2.append(sysr2) - suma2.append(sysa2) - sumr = np.sum(sumr, axis=0) - suma = np.sum(suma, axis=0) - sumn = np.sum(sumn, axis=0) - sumr2 = np.sum(sumr2, axis=0) - suma2 = np.sum(suma2, axis=0) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): - all_davg = [] - all_dstd = [] - for type_i in range(self.ntypes): - davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] - 
dstdunit = [ - [ - compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - ] - ] - davg = np.tile(davgunit, [self.nnei, 1]) - dstd = np.tile(dstdunit, [self.nnei, 1]) - all_davg.append(davg) - all_dstd.append(dstd) - self.sumr = sumr - self.suma = suma - self.sumn = sumn - self.sumr2 = sumr2 - self.suma2 = suma2 + env_mat_stat = EnvMatStatSeA(self) + if path is not None: + path = path / env_mat_stat.get_hash() + env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) if not self.set_davg_zero: - mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - stddev = np.stack(all_dstd) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." 
+ ) + return self.stats + def forward( self, nlist: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index d4dc0cd054..410a2039aa 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Dict, List, Optional, ) @@ -9,7 +10,6 @@ from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, - compute_std, ) from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, @@ -21,8 +21,14 @@ from deepmd.pt.utils import ( env, ) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, ) @@ -135,6 +141,7 @@ def __init__( ) filter_layers.append(one) self.filter_layers = torch.nn.ModuleList(filter_layers) + self.stats = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -185,102 +192,26 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.get_dim_emb() - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - sumr = [] - suma = [] - sumn = [] - sumr2 = [] - suma2 = [] - mixed_type = "real_natoms_vec" in merged[0] - for system in merged: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.get_sel(), - distinguish_types=self.distinguish_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - self.mean, - self.stddev, - self.rcut, - self.rcut_smth, - ) - if not mixed_type: - sysr, sysr2, sysa, 
sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), self.ndescrpt, natoms - ) - else: - real_natoms_vec = system["real_natoms_vec"] - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), - self.ndescrpt, - real_natoms_vec, - mixed_type=mixed_type, - real_atype=atype.detach().cpu().numpy(), - ) - sumr.append(sysr) - suma.append(sysa) - sumn.append(sysn) - sumr2.append(sysr2) - suma2.append(sysa2) - sumr = np.sum(sumr, axis=0) - suma = np.sum(suma, axis=0) - sumn = np.sum(sumn, axis=0) - sumr2 = np.sum(sumr2, axis=0) - suma2 = np.sum(suma2, axis=0) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): - all_davg = [] - all_dstd = [] - for type_i in range(self.ntypes): - davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] - dstdunit = [ - [ - compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - ] - ] - davg = np.tile(davgunit, [self.nnei, 1]) - dstd = np.tile(dstdunit, [self.nnei, 1]) - all_davg.append(davg) - all_dstd.append(dstd) - self.sumr = sumr - self.suma = suma - self.sumn = sumn - self.sumr2 = sumr2 - self.suma2 = suma2 + env_mat_stat = EnvMatStatSeA(self) + if path is not None: + path = path / env_mat_stat.get_hash() + env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() if not self.set_davg_zero: - mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - stddev = np.stack(all_dstd) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The 
statistics of the descriptor has not been computed." + ) + return self.stats + def forward( self, nlist: torch.Tensor, diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 51c5fcf123..d98d25d539 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -1,6 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + import torch +from deepmd.utils.path import ( + DPPath, +) + class BaseModel(torch.nn.Module): def __init__(self): @@ -9,9 +17,8 @@ def __init__(self): def compute_or_load_stat( self, - type_map=None, - sampled=None, - stat_file_path=None, + sampled, + stat_file_path: Optional[DPPath] = None, ): """ Compute or load the statistics parameters of the model, @@ -23,9 +30,6 @@ def compute_or_load_stat( Parameters ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. sampled The sampled data frames from different data systems. 
stat_file_path diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index b6ca12b9d8..2f5afaf26e 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -30,6 +30,9 @@ from deepmd.pt.utils.stat import ( compute_output_bias, ) +from deepmd.utils.path import ( + DPPath, +) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -152,17 +155,21 @@ def data_stat_key(self): """ return ["bias_atom_e"] - def compute_output_stats(self, merged): + def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None): energy = [item["energy"] for item in merged] mixed_type = "real_natoms_vec" in merged[0] if mixed_type: input_natoms = [item["real_natoms_vec"] for item in merged] else: input_natoms = [item["natoms"] for item in merged] - bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond) - return {"bias_atom_e": bias_atom_e} - - def init_fitting_stat(self, bias_atom_e=None, **kwargs): + if stat_file_path is not None: + stat_file_path = stat_file_path / "bias_atom_e" + if stat_file_path is not None and stat_file_path.is_file(): + bias_atom_e = stat_file_path.load_numpy() + else: + bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond) + if stat_file_path is not None: + stat_file_path.save_numpy(bias_atom_e) assert all(x is not None for x in [bias_atom_e]) self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view( @@ -225,16 +232,6 @@ def __init__( **kwargs, ) - @classmethod - def get_stat_name(cls, ntypes, type_name="ener", **kwargs): - """ - Get the name for the statistic file of the fitting. - Usually use the combination of fitting net name and ntypes as the statistic file name. 
- """ - fitting_type = type_name - assert fitting_type in ["ener"] - return f"stat_file_fitting_ener_ntypes{ntypes}.npz" - @Fitting.register("direct_force") @Fitting.register("direct_force_ener") @@ -327,16 +324,6 @@ def serialize(self) -> dict: def deserialize(cls) -> "EnergyFittingNetDirect": raise NotImplementedError - @classmethod - def get_stat_name(cls, ntypes, type_name="ener", **kwargs): - """ - Get the name for the statistic file of the fitting. - Usually use the combination of fitting net name and ntypes as the statistic file name. - """ - fitting_type = type_name - assert fitting_type in ["direct_force", "direct_force_ener"] - return f"stat_file_fitting_direct_ntypes{ntypes}.npz" - def forward( self, inputs: torch.Tensor, diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index db8daff802..f8f6e3f5dc 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -8,7 +8,6 @@ Callable, List, Optional, - Union, ) import numpy as np @@ -127,19 +126,6 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - @classmethod - def get_stat_name(cls, ntypes, type_name="ener", **kwargs): - """ - Get the name for the statistic file of the fitting. - Usually use the combination of fitting net name and ntypes as the statistic file name. - """ - if cls is not Fitting: - raise NotImplementedError("get_stat_name is not implemented!") - fitting_type = type_name - return Fitting.__plugins.plugins[fitting_type].get_stat_name( - ntypes, type_name, **kwargs - ) - @property def data_stat_key(self): """ @@ -148,86 +134,6 @@ def data_stat_key(self): """ raise NotImplementedError("data_stat_key is not implemented!") - def compute_or_load_stat( - self, - type_map: List[str], - sampled=None, - stat_file_path: Optional[Union[str, List[str]]] = None, - ): - """ - Compute or load the statistics parameters of the fitting net. 
- Calculate and save the output bias to `stat_file_path` - if `sampled` is not None, otherwise load them from `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - sampled - The sampled data frames from different data systems. - stat_file_path - The path to the statistics files. - """ - fitting_stat_key = self.data_stat_key - if sampled is not None: - tmp_dict = self.compute_output_stats(sampled) - result_dict = {key: tmp_dict[key] for key in fitting_stat_key} - result_dict["type_map"] = type_map - self.save_stats(result_dict, stat_file_path) - else: # load the statistics results - assert stat_file_path is not None, "No stat file to load!" - result_dict = self.load_stats(type_map, stat_file_path) - self.init_fitting_stat(**result_dict) - - def save_stats(self, result_dict, stat_file_path: str): - """ - Save the statistics results to `stat_file_path`. - - Parameters - ---------- - result_dict - The dictionary of statistics results. - stat_file_path - The path to the statistics file(s). - """ - log.info(f"Saving stat file to {stat_file_path}") - np.savez_compressed(stat_file_path, **result_dict) - - def load_stats(self, type_map, stat_file_path: str): - """ - Load the statistics results to `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - stat_file_path - The path to the statistics file(s). - - Returns - ------- - result_dict - The dictionary of statistics results. - """ - fitting_stat_key = self.data_stat_key - target_type_map = type_map - log.info(f"Loading stat file from {stat_file_path}") - stats = np.load(stat_file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [i for i in target_type_map if i not in stat_type_map] - assert not missing_type, ( - f"These type are not in stat file {stat_file_path}: {missing_type}! 
" - f"Please change the stat file path!" - ) - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats[fitting_stat_key[0]].size: # not empty - result_dict = {key: stats[key][idx_map] for key in fitting_stat_key} - else: - result_dict = {key: [] for key in fitting_stat_key} - return result_dict - def change_energy_bias( self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10 ): diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index b2cac5a5eb..8537be6e12 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -185,9 +185,8 @@ def get_single_model(_model_params, _sampled, _stat_file_path): model = get_model(deepcopy(_model_params)).to(DEVICE) if not model_params.get("resuming", False): model.compute_or_load_stat( - type_map=_model_params["type_map"], sampled=_sampled, - stat_file_path_dict=_stat_file_path, + stat_file_path=_stat_file_path, ) return model diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py new file mode 100644 index 0000000000..5247ce08ba --- /dev/null +++ b/deepmd/pt/utils/env_mat_stat.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + TYPE_CHECKING, + Dict, + Iterator, + List, +) + +import numpy as np +import torch + +from deepmd.common import ( + get_hash, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat_se_a, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat +from deepmd.utils.env_mat_stat import ( + StatItem, +) + +if TYPE_CHECKING: + from deepmd.pt.model.descriptor import ( + DescriptorBlock, + ) + + +class EnvMatStat(BaseEnvMatStat): + def compute_stat(self, env_mat: Dict[str, torch.Tensor]) -> Dict[str, StatItem]: + """Compute the statistics of the environment matrix for a single system. 
+ + Parameters + ---------- + env_mat : Dict[str, torch.Tensor] + The environment matrix. + + Returns + ------- + Dict[str, StatItem] + The statistics of the environment matrix. + """ + stats = {} + for kk, vv in env_mat.items(): + stats[kk] = StatItem( + number=vv.numel(), + sum=vv.sum().item(), + squared_sum=torch.square(vv).sum().item(), + ) + return stats + + +class EnvMatStatSeA(EnvMatStat): + """Environmental matrix statistics for the se_a environmental matrix. + + Parameters + ---------- + descriptor : DescriptorBlock + The descriptor of the model. + """ + + def __init__(self, descriptor: "DescriptorBlock"): + super().__init__() + self.descriptor = descriptor + + def iter( + self, data: List[Dict[str, torch.Tensor]] + ) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, torch.Tensor]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ + zero_mean = torch.zeros( + self.descriptor.get_ntypes(), + self.descriptor.get_nsel(), + 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + one_stddev = torch.ones( + self.descriptor.get_ntypes(), + self.descriptor.get_nsel(), + 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + for system in data: + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.descriptor.get_rcut(), + self.descriptor.get_sel(), + distinguish_types=self.descriptor.distinguish_types(), + box=box, + ) + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + zero_mean, + one_stddev, + self.descriptor.get_rcut(), + # TODO: export rcut_smth from DescriptorBlock + self.descriptor.rcut_smth, + ) + env_mat = env_mat.view( + coord.shape[0], coord.shape[1], self.descriptor.get_nsel(), 4 +
) + + if "real_natoms_vec" not in system: + end_indexes = torch.cumsum(natoms[0, 2:], 0) + start_indexes = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=env.DEVICE), + end_indexes[:-1], + ] + ) + for type_i in range(self.descriptor.get_ntypes()): + dd = env_mat[ + :, start_indexes[type_i] : end_indexes[type_i], :, : + ] # all descriptors for this element + env_mats = {} + env_mats[f"r_{type_i}"] = dd[:, :, :, :1] + env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] + yield self.compute_stat(env_mats) + else: + for frame_item in range(env_mat.shape[0]): + dd_ff = env_mat[frame_item] + atype_frame = atype[frame_item] + for type_i in range(self.descriptor.get_ntypes()): + type_idx = atype_frame == type_i + dd = dd_ff[type_idx] + dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 + env_mats = {} + env_mats[f"r_{type_i}"] = dd[:, :1] + env_mats[f"a_{type_i}"] = dd[:, 1:] + yield self.compute_stat(env_mats) + + def get_hash(self) -> str: + """Get the hash of the environment matrix. + + Returns + ------- + str + The hash of the environment matrix. 
+ """ + return get_hash( + { + "type": "se_a", + "ntypes": self.descriptor.get_ntypes(), + "rcut": round(self.descriptor.get_rcut(), 2), + "rcut_smth": round(self.descriptor.rcut_smth, 2), + "nsel": self.descriptor.get_nsel(), + "sel": self.descriptor.get_sel(), + "distinguish_types": self.descriptor.distinguish_types(), + } + ) + + def __call__(self): + avgs = self.get_avg() + stds = self.get_std() + + all_davg = [] + all_dstd = [] + for type_i in range(self.descriptor.get_ntypes()): + davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] + dstdunit = [ + [ + stds[f"r_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + ] + ] + davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) + dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) + all_davg.append(davg) + all_dstd.append(dstd) + mean = np.stack(all_davg) + stddev = np.stack(all_dstd) + return mean, stddev diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 76b2afe41b..051fddd14b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import os import numpy as np import torch @@ -77,32 +76,3 @@ def compute_output_bias(energy, natoms, rcond=None): sys_tynatom = torch.cat(natoms)[:, 2:].cpu() energy_coef, _, _, _ = np.linalg.lstsq(sys_tynatom, sys_ener, rcond) return energy_coef - - -def process_stat_path( - stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls -): - if stat_file_dict is None: - stat_file_dict = {} - if "descriptor" in model_params_dict: - default_stat_file_name_descrpt = descriptor_cls.get_stat_name( - len(model_params_dict["type_map"]), - model_params_dict["descriptor"]["type"], - **model_params_dict["descriptor"], - ) - stat_file_dict["descriptor"] = default_stat_file_name_descrpt - if "fitting_net" in model_params_dict: - default_stat_file_name_fitting = fitting_cls.get_stat_name( - len(model_params_dict["type_map"]), - 
model_params_dict["fitting_net"].get("type", "ener"), - **model_params_dict["fitting_net"], - ) - stat_file_dict["fitting_net"] = default_stat_file_name_fitting - stat_file_path = { - key: os.path.join(stat_file_dir, stat_file_dict[key]) for key in stat_file_dict - } - - has_stat_file_path_list = [ - os.path.exists(stat_file_path[key]) for key in stat_file_dict - ] - return stat_file_path, all(has_stat_file_path_list) diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py new file mode 100644 index 0000000000..2fa497b9b6 --- /dev/null +++ b/deepmd/utils/env_mat_stat.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) +from collections import ( + defaultdict, +) +from typing import ( + Dict, + Iterator, + List, + Optional, +) + +import numpy as np + +from deepmd.utils.path import ( + DPPath, +) + + +class StatItem: + """A class to store the statistics of the environment matrix. + + Parameters + ---------- + number : int + The total size of given array. + sum : float + The sum value of the matrix. + squared_sum : float + The sum squared value of the matrix. + """ + + def __init__(self, number: int = 0, sum: float = 0, squared_sum: float = 0) -> None: + self.number = number + self.sum = sum + self.squared_sum = squared_sum + + def __add__(self, other: "StatItem") -> "StatItem": + return StatItem( + number=self.number + other.number, + sum=self.sum + other.sum, + squared_sum=self.squared_sum + other.squared_sum, + ) + + def compute_avg(self, default: float = 0) -> float: + """Compute the average of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the average, by default 0. + + Returns + ------- + float + The average of the environment matrix. 
+ """ + if self.number == 0: + return default + return self.sum / self.number + + def compute_std(self, default: float = 1e-1, protection: float = 1e-2) -> float: + """Compute the standard deviation of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the standard deviation, by default 1e-1. + protection : float, optional + The protection value for the standard deviation, by default 1e-2. + + Returns + ------- + float + The standard deviation of the environment matrix. + """ + if self.number == 0: + return default + val = np.sqrt( + self.squared_sum / self.number + - np.multiply(self.sum / self.number, self.sum / self.number) + ) + if np.abs(val) < protection: + val = protection + return val + + +class EnvMatStat(ABC): + """A base class to store and calculate the statistics of the environment matrix.""" + + def __init__(self) -> None: + super().__init__() + self.stats = defaultdict(StatItem) + + def compute_stats(self, data: List[Dict[str, np.ndarray]]) -> None: + """Compute the statistics of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, np.ndarray]] + The environment matrix. + """ + if len(self.stats) > 0: + raise ValueError("The statistics has already been computed.") + for iter_stats in self.iter(data): + for kk in iter_stats: + self.stats[kk] += iter_stats[kk] + + @abstractmethod + def iter(self, data: List[Dict[str, np.ndarray]]) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, np.ndarray]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ + + def save_stats(self, path: DPPath) -> None: + """Save the statistics of the environment matrix. + + Parameters + ---------- + path : DPH5Path + The path to save the statistics of the environment matrix. 
+ """ + if len(self.stats) == 0: + raise ValueError("The statistics hasn't been computed.") + for kk, vv in self.stats.items(): + path.mkdir(parents=True, exist_ok=True) + (path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum])) + + def load_stats(self, path: DPPath) -> None: + """Load the statistics of the environment matrix. + + Parameters + ---------- + path : DPH5Path + The path to load the statistics of the environment matrix. + """ + if len(self.stats) > 0: + raise ValueError("The statistics has already been computed.") + for kk in path.glob("*"): + arr = kk.load_numpy() + self.stats[kk.name] = StatItem( + number=arr[0], + sum=arr[1], + squared_sum=arr[2], + ) + + def load_or_compute_stats( + self, data: List[Dict[str, np.ndarray]], path: Optional[DPPath] = None + ) -> None: + """Load the statistics of the environment matrix if it exists, otherwise compute and save it. + + Parameters + ---------- + path : DPH5Path + The path to load the statistics of the environment matrix. + data : List[Dict[str, np.ndarray]] + The environment matrix. + """ + if path is not None and path.is_dir(): + self.load_stats(path) + else: + self.compute_stats(data) + if path is not None: + self.save_stats(path) + + def get_avg(self, default: float = 0) -> Dict[str, float]: + """Get the average of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the average, by default 0. + + Returns + ------- + Dict[str, float] + The average of the environment matrix. + """ + return {kk: vv.compute_avg(default=default) for kk, vv in self.stats.items()} + + def get_std( + self, default: float = 1e-1, protection: float = 1e-2 + ) -> Dict[str, float]: + """Get the standard deviation of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the standard deviation, by default 1e-1. + protection : float, optional + The protection value for the standard deviation, by default 1e-2. 
+ + Returns + ------- + Dict[str, float] + The standard deviation of the environment matrix. + """ + return { + kk: vv.compute_std(default=default, protection=protection) + for kk, vv in self.stats.items() + } diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index a8e4bc329f..c9a7cd8554 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -29,9 +29,11 @@ class DPPath(ABC): ---------- path : str path + mode : str, optional + mode, by default "r" """ - def __new__(cls, path: str): + def __new__(cls, path: str, mode: str = "r"): if cls is DPPath: if os.path.isdir(path): return super().__new__(DPOSPath) @@ -62,6 +64,16 @@ def load_txt(self, **kwargs) -> np.ndarray: loaded NumPy array """ + @abstractmethod + def save_numpy(self, arr: np.ndarray) -> None: + """Save NumPy array. + + Parameters + ---------- + arr : np.ndarray + NumPy array + """ + @abstractmethod def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -122,6 +134,23 @@ def __eq__(self, other) -> bool: def __hash__(self): return hash(str(self)) + @property + @abstractmethod + def name(self) -> str: + """Name of the path.""" + + @abstractmethod + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: + """Make directory. + + Parameters + ---------- + parents : bool, optional + If true, any missing parents of this directory are created as well. + exist_ok : bool, optional + If true, no error will be raised if the target directory already exists. + """ + class DPOSPath(DPPath): """The OS path class to data system (DeepmdData) for real directories. 
@@ -130,10 +159,13 @@ class DPOSPath(DPPath): ---------- path : str path + mode : str, optional + mode, by default "r" """ - def __init__(self, path: str) -> None: + def __init__(self, path: str, mode: str = "r") -> None: super().__init__() + self.mode = mode if isinstance(path, Path): self.path = path else: @@ -159,6 +191,18 @@ def load_txt(self, **kwargs) -> np.ndarray: """ return np.loadtxt(str(self.path), **kwargs) + def save_numpy(self, arr: np.ndarray) -> None: + """Save NumPy array. + + Parameters + ---------- + arr : np.ndarray + NumPy array + """ + if self.mode == "r": + raise ValueError("Cannot save to read-only path") + np.save(str(self.path), arr) + def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -174,7 +218,7 @@ def glob(self, pattern: str) -> List["DPPath"]: """ # currently DPOSPath will only derivative DPOSPath # TODO: discuss if we want to mix DPOSPath and DPH5Path? - return [type(self)(p) for p in self.path.glob(pattern)] + return [type(self)(p, mode=self.mode) for p in self.path.glob(pattern)] def rglob(self, pattern: str) -> List["DPPath"]: """This is like calling :meth:`DPPath.glob()` with `**/` added in front @@ -190,7 +234,7 @@ def rglob(self, pattern: str) -> List["DPPath"]: List[DPPath] list of paths """ - return [type(self)(p) for p in self.path.rglob(pattern)] + return [type(self)(p, mode=self.mode) for p in self.path.rglob(pattern)] def is_file(self) -> bool: """Check if self is file.""" @@ -202,7 +246,7 @@ def is_dir(self) -> bool: def __truediv__(self, key: str) -> "DPPath": """Used for / operator.""" - return type(self)(self.path / key) + return type(self)(self.path / key, mode=self.mode) def __lt__(self, other: "DPOSPath") -> bool: """Whether this DPPath is less than other for sorting.""" @@ -212,6 +256,25 @@ def __str__(self) -> str: """Represent string.""" return str(self.path) + @property + def name(self) -> str: + """Name of the path.""" + return self.path.name + + def mkdir(self, parents: 
bool = False, exist_ok: bool = False) -> None: + """Make directory. + + Parameters + ---------- + parents : bool, optional + If true, any missing parents of this directory are created as well. + exist_ok : bool, optional + If true, no error will be raised if the target directory already exists. + """ + if self.mode == "r": + raise ValueError("Cannot mkdir to read-only path") + self.path.mkdir(parents=parents, exist_ok=exist_ok) + class DPH5Path(DPPath): """The path class to data system (DeepmdData) for HDF5 files. @@ -226,32 +289,37 @@ class DPH5Path(DPPath): ---------- path : str path + mode : str, optional + mode, by default "r" """ - def __init__(self, path: str) -> None: + def __init__(self, path: str, mode: str = "r") -> None: super().__init__() + self.mode = mode # we use "#" to split path # so we do not support file names containing #... s = path.split("#") self.root_path = s[0] - self.root = self._load_h5py(s[0]) + self.root = self._load_h5py(s[0], mode) # h5 path: default is the root path - self.name = s[1] if len(s) > 1 else "/" + self._name = s[1] if len(s) > 1 else "/" @classmethod @lru_cache(None) - def _load_h5py(cls, path: str) -> h5py.File: + def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File: """Load hdf5 file. Parameters ---------- path : str path to hdf5 file + mode : str, optional + mode, by default 'r' """ # this method has cache to avoid duplicated # loading from different DPH5Path # However the file will be never closed? - return h5py.File(path, "r") + return h5py.File(path, mode) def load_numpy(self) -> np.ndarray: """Load NumPy array. @@ -261,7 +329,7 @@ def load_numpy(self) -> np.ndarray: np.ndarray loaded NumPy array """ - return self.root[self.name][:] + return self.root[self._name][:] def load_txt(self, dtype: Optional[np.dtype] = None, **kwargs) -> np.ndarray: """Load NumPy array from text. 
@@ -276,6 +344,18 @@ def load_txt(self, dtype: Optional[np.dtype] = None, **kwargs) -> np.ndarray: arr = arr.astype(dtype) return arr + def save_numpy(self, arr: np.ndarray) -> None: + """Save NumPy array. + + Parameters + ---------- + arr : np.ndarray + NumPy array + """ + if self._name in self._keys: + del self.root[self._name] + self.root.create_dataset(self._name, data=arr) + def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -290,9 +370,9 @@ def glob(self, pattern: str) -> List["DPPath"]: list of paths """ # got paths starts with current path first, which is faster - subpaths = [ii for ii in self._keys if ii.startswith(self.name)] + subpaths = [ii for ii in self._keys if ii.startswith(self._name)] return [ - type(self)(f"{self.root_path}#{pp}") + type(self)(f"{self.root_path}#{pp}", mode=self.mode) for pp in globfilter(subpaths, self._connect_path(pattern)) ] @@ -327,32 +407,56 @@ def _file_keys(cls, file: h5py.File) -> List[str]: def is_file(self) -> bool: """Check if self is file.""" - if self.name not in self._keys: + if self._name not in self._keys: return False - return isinstance(self.root[self.name], h5py.Dataset) + return isinstance(self.root[self._name], h5py.Dataset) def is_dir(self) -> bool: """Check if self is directory.""" - if self.name not in self._keys: + if self._name not in self._keys: return False - return isinstance(self.root[self.name], h5py.Group) + return isinstance(self.root[self._name], h5py.Group) def __truediv__(self, key: str) -> "DPPath": """Used for / operator.""" - return type(self)(f"{self.root_path}#{self._connect_path(key)}") + return type(self)(f"{self.root_path}#{self._connect_path(key)}", mode=self.mode) def _connect_path(self, path: str) -> str: """Connect self with path.""" - if self.name.endswith("/"): - return f"{self.name}{path}" - return f"{self.name}/{path}" + if self._name.endswith("/"): + return f"{self._name}{path}" + return f"{self._name}/{path}" def __lt__(self, other: 
"DPH5Path") -> bool: """Whether this DPPath is less than other for sorting.""" if self.root_path == other.root_path: - return self.name < other.name + return self._name < other._name return self.root_path < other.root_path def __str__(self) -> str: """Returns path of self.""" - return f"{self.root_path}#{self.name}" + return f"{self.root_path}#{self._name}" + + @property + def name(self) -> str: + """Name of the path.""" + return self._name.split("/")[-1] + + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: + """Make directory. + + Parameters + ---------- + parents : bool, optional + If true, any missing parents of this directory are created as well. + exist_ok : bool, optional + If true, no error will be raised if the target directory already exists. + """ + if self._name in self._keys: + if not exist_ok: + raise FileExistsError(f"{self} already exists") + return + if parents: + self.root.require_group(self._name) + else: + self.root.create_group(self._name) diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index bc95575a5a..12af6ba866 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -2,16 +2,24 @@ import json import os import unittest +from abc import ( + ABC, + abstractmethod, +) from pathlib import ( Path, ) +import dpdata import numpy as np import torch from deepmd.pt.model.descriptor import ( DescrptSeA, ) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) from deepmd.pt.utils import ( env, ) @@ -26,6 +34,7 @@ expand_sys_str, ) from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeA_tf +from deepmd.tf.descriptor.se_atten import DescrptSeAtten as DescrptSeAtten_tf from deepmd.tf.fit.ener import ( EnerFitting, ) @@ -50,12 +59,29 @@ def compare(ut, base, given): ut.assertEqual(base, given) -class TestDataset(unittest.TestCase): +class DatasetTest(ABC): + @abstractmethod + def setup_data(self): + pass + + @abstractmethod + def setup_tf(self): + pass + + 
@abstractmethod + def setup_pt(self): + pass + + @abstractmethod + def tf_compute_input_stats(self): + pass + def setUp(self): with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: content = fin.read() config = json.loads(content) - data_file = [str(Path(__file__).parent / "water/data/data_0")] + data_file = [self.setup_data()] + config["training"]["training_data"]["systems"] = data_file config["training"]["validation_data"]["systems"] = data_file model_config = config["model"] @@ -97,13 +123,7 @@ def setUp(self): self.dp_sampled = dp_make(dp_dataset, self.data_stat_nbatch, False) self.dp_merged = dp_merge(self.dp_sampled) self.dp_mesh = self.dp_merged.pop("default_mesh") - self.dp_d = DescrptSeA_tf( - rcut=self.rcut, - rcut_smth=self.rcut_smth, - sel=self.sel, - neuron=self.filter_neuron, - axis_neuron=self.axis_neuron, - ) + self.dp_d = self.setup_tf() def test_stat_output(self): def my_merge(energy, natoms): @@ -147,16 +167,9 @@ def test_stat_input(self): """ def test_descriptor(self): - coord = self.dp_merged["coord"] - atype = self.dp_merged["type"] - natoms = self.dp_merged["natoms_vec"] - box = self.dp_merged["box"] - self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) + self.tf_compute_input_stats() - my_en = DescrptSeA( - self.rcut, self.rcut_smth, self.sel, self.filter_neuron, self.axis_neuron - ) - my_en = my_en.sea # get the block who has stat as private vars + my_en = self.setup_pt() sampled = self.my_sampled for sys in sampled: for key in [ @@ -170,20 +183,102 @@ def test_descriptor(self): if key in sys.keys(): sys[key] = sys[key].to(env.DEVICE) stat_dict = my_en.compute_input_stats(sampled) - my_en.init_desc_stat(**stat_dict) my_en.mean = my_en.mean my_en.stddev = my_en.stddev - self.assertTrue( - np.allclose( - self.dp_d.davg.reshape([-1]), my_en.mean.cpu().reshape([-1]), rtol=0.01 - ) + np.testing.assert_allclose( + self.dp_d.davg.reshape([-1]), + my_en.mean.cpu().reshape([-1]), + rtol=1e-14, + atol=1e-14, + ) 
+ np.testing.assert_allclose( + self.dp_d.dstd.reshape([-1]), + my_en.stddev.cpu().reshape([-1]), + rtol=1e-14, + atol=1e-14, ) - self.assertTrue( - np.allclose( - self.dp_d.dstd.reshape([-1]), - my_en.stddev.cpu().reshape([-1]), - rtol=0.01, - ) + + +class TestDatasetNoMixed(DatasetTest, unittest.TestCase): + def setup_data(self): + original_data = str(Path(__file__).parent / "water/data/data_0") + picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") + dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy( + picked_data + ) + self.mixed_type = False + return picked_data + + def setup_tf(self): + return DescrptSeA_tf( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + ) + + def setup_pt(self): + return DescrptSeA( + self.rcut, self.rcut_smth, self.sel, self.filter_neuron, self.axis_neuron + ).sea # get the block who has stat as private vars + + def tf_compute_input_stats(self): + coord = self.dp_merged["coord"] + atype = self.dp_merged["type"] + natoms = self.dp_merged["natoms_vec"] + box = self.dp_merged["box"] + self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) + + +class TestDatasetMixed(DatasetTest, unittest.TestCase): + def setup_data(self): + original_data = str(Path(__file__).parent / "water/data/data_0") + picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") + dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy_mixed( + picked_data + ) + self.mixed_type = True + return picked_data + + def setup_tf(self): + return DescrptSeAtten_tf( + ntypes=2, + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=sum(self.sel), + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + set_davg_zero=False, + ) + + def setup_pt(self): + return DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + 2, + self.filter_neuron, + self.axis_neuron, + set_davg_zero=False, + ).se_atten + + def 
tf_compute_input_stats(self): + coord = self.dp_merged["coord"] + atype = self.dp_merged["type"] + natoms = self.dp_merged["natoms_vec"] + box = self.dp_merged["box"] + real_natoms_vec = self.dp_merged["real_natoms_vec"] + + self.dp_d.compute_input_stats( + coord, + box, + atype, + natoms, + self.dp_mesh, + {}, + mixed_type=True, + real_natoms_vec=real_natoms_vec, )