From f352a67da58c0c4e9cf1ff4deaed23ae585a49ce Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 15:34:16 -0500 Subject: [PATCH 01/40] checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 144 ++++++++++++++++++----------- deepmd/pt/utils/env_mat_stat.py | 35 +++++++ deepmd/utils/env_mat_stat.py | 84 +++++++++++++++++ 3 files changed, 209 insertions(+), 54 deletions(-) create mode 100644 deepmd/pt/utils/env_mat_stat.py create mode 100644 deepmd/utils/env_mat_stat.py diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index c722c2dc02..aa66fc2039 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( ClassVar, + Dict, + Iterator, List, Optional, ) @@ -21,6 +23,10 @@ PRECISION_DICT, RESERVED_PRECISON_DICT, ) +from deepmd.pt.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat +from deepmd.utils.env_mat_stat import ( + StatItem, +) try: from typing import ( @@ -382,62 +388,92 @@ def __getitem__(self, key): else: raise KeyError(key) - def compute_input_stats(self, merged): - """Update mean and stddev for descriptor elements.""" - sumr = [] - suma = [] - sumn = [] - sumr2 = [] - suma2 = [] - for system in merged: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], + class EnvMatStat(BaseEnvMatStat): + """A class to calculate the statistics of the environment matrix.""" + + def __init__(self, descriptor: "DescrptBlockSeA"): + self.descriptor = descriptor + self.ntypes = descriptor.get_ntypes() + self.rcut = descriptor.get_rcut() + self.rcut_smth = descriptor.rcut_smth + self.nsel = descriptor.get_nsel() + + def iter( + self, data: List[Dict[str, torch.Tensor]] + ) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, torch.Tensor]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ + zero_mean = torch.zeros( + self.ntypes, + self.nsel * 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.get_sel(), - distinguish_types=self.distinguish_types(), - box=box, + one_stddev = torch.ones( + self.ntypes, + self.nsel * 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - self.mean, - self.stddev, - self.rcut, - self.rcut_smth, - ) - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), self.ndescrpt, natoms - ) - sumr.append(sysr) - suma.append(sysa) - sumn.append(sysn) - sumr2.append(sysr2) - suma2.append(sysa2) - sumr = np.sum(sumr, axis=0) - suma = np.sum(suma, axis=0) - sumn = np.sum(sumn, axis=0) - sumr2 = np.sum(sumr2, axis=0) - suma2 = np.sum(suma2, axis=0) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } + for system in data: + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.descriptor.get_rcut(), + self.descriptor.get_sel(), + distinguish_types=self.descriptor.distinguish_types(), + box=box, + ) + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + zero_mean, + one_stddev, + self.descriptor.get_rcut(), + self.descriptor.rcut_smth, + ) + env_mat = env_mat.view(-1, self.nsel, 4) + env_mats = {} + end_indexes = torch.cumsum(natoms[0, 2:], 0) + start_indexes = torch.cat( + [ + torch.zeros([], dtype=torch.int32, device=env.DEVICE), + end_indexes[:-1], + ] + ) + for type_i in range(self.ntypes): + dd = env_mat[ + :, start_indexes[type_i] : end_indexes[type_i], : + ] # all descriptors for this element + env_mats[f"r_{type_i}"] = dd[:, :1] + env_mats[f"a_{type_i}"] = dd[:, 1:] + yield self.compute_stat(env_mats) + + def compute_input_stats(self, merged: list[dict]): + """Update mean and stddev for descriptor elements.""" + stats = self.EnvMatStat(self).compute_stats(merged) def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): all_davg = [] diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py new file mode 100644 index 0000000000..7386adf35f --- /dev/null +++ b/deepmd/pt/utils/env_mat_stat.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, +) + +import torch + +from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat +from deepmd.utils.env_mat_stat import ( + StatItem, +) + + +class EnvMatStat(BaseEnvMatStat): + def compute_stat(self, env_mat: Dict[str, torch.Tensor]) -> Dict[str, StatItem]: + """Compute the statistics of the environment matrix for a single system. + + Parameters + ---------- + env_mat : torch.Tensor + The environment matrix. + + Returns + ------- + Dict[str, StatItem] + The statistics of the environment matrix. + """ + stats = {} + for kk, vv in env_mat.items(): + stats[kk] = StatItem( + number=vv.numel(), + mean=vv.mean().item(), + squared_mean=torch.square(vv).mean().item(), + ) + return stats diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py new file mode 100644 index 0000000000..2be563e47c --- /dev/null +++ b/deepmd/utils/env_mat_stat.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) +from collections import ( + defaultdict, +) +from typing import ( + Dict, + Iterator, + List, +) + +import numpy as np + + +class StatItem: + """A class to store the statistics of the environment matrix. + + Parameters + ---------- + number : int + The total size of given array. + mean : float + The mean value of the matrix. + squared_mean : float + The mean squared value of the matrix. + """ + + def __init__( + self, number: int = 0, mean: float = 0, squared_mean: float = 0 + ) -> None: + self.number = number + self.mean = mean + self.squared_mean = squared_mean + + def __add__(self, other: "StatItem") -> "StatItem": + self_frac = self.number / (self.number + other.number) + other_frac = 1 - self_frac + return StatItem( + number=self.number + other.number, + mean=self_frac * self.mean + other_frac * other.mean, + squared_mean=self_frac * self.squared_mean + + other_frac * other.squared_mean, + ) + + +class EnvMatStat(ABC): + """A base class to store and calculate the statistics of the environment matrix.""" + + def compute_stats(self, data: List[Dict[str, np.ndarray]]) -> Dict[str, StatItem]: + """Compute the statistics of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, np.ndarray]] + The environment matrix. + + Returns + ------- + Dict[str, StatItem] + The statistics of the environment matrix. + """ + stats = defaultdict(StatItem) + for iter_stats in self.iter(data): + for kk in iter_stats: + stats[kk] += iter_stats[kk] + return stats + + @abstractmethod + def iter(self, data: List[Dict[str, np.ndarray]]) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, np.ndarray]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ From 1c3e2bb3cba5103213faf6387842189fb92923aa Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 16:40:36 -0500 Subject: [PATCH 02/40] record sum Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/env_mat_stat.py | 4 ++-- deepmd/utils/env_mat_stat.py | 23 +++++++++-------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 7386adf35f..366217ef54 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -29,7 +29,7 @@ def compute_stat(self, env_mat: Dict[str, torch.Tensor]) -> Dict[str, StatItem]: for kk, vv in env_mat.items(): stats[kk] = StatItem( number=vv.numel(), - mean=vv.mean().item(), - squared_mean=torch.square(vv).mean().item(), + sum=vv.sum().item(), + squared_sum=torch.square(vv).sum().item(), ) return stats diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 2be563e47c..7564d857f8 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -22,27 +22,22 @@ class StatItem: ---------- number : int The total size of given array. - mean : float - The mean value of the matrix. - squared_mean : float - The mean squared value of the matrix. + sum : float + The sum value of the matrix. + squared_sum : float + The sum squared value of the matrix. """ - def __init__( - self, number: int = 0, mean: float = 0, squared_mean: float = 0 - ) -> None: + def __init__(self, number: int = 0, sum: float = 0, squared_sum: float = 0) -> None: self.number = number - self.mean = mean - self.squared_mean = squared_mean + self.sum = sum + self.squared_sum = squared_sum def __add__(self, other: "StatItem") -> "StatItem": - self_frac = self.number / (self.number + other.number) - other_frac = 1 - self_frac return StatItem( number=self.number + other.number, - mean=self_frac * self.mean + other_frac * other.mean, - squared_mean=self_frac * self.squared_mean - + other_frac * other.squared_mean, + sum=self.sum + other.sum, + squared_sum=self.squared_sum + other.squared_sum, ) From d6bf4abcfc9e3413269bad1bec0cb699f5f73573 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 16:49:47 -0500 Subject: [PATCH 03/40] compute_std Signed-off-by: Jinzhe Zeng --- deepmd/utils/env_mat_stat.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 7564d857f8..70f47dfcb5 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -40,6 +40,29 @@ def __add__(self, other: "StatItem") -> "StatItem": squared_sum=self.squared_sum + other.squared_sum, ) + def compute_std(self, default: float = 1e-1) -> float: + """Compute the standard deviation of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the standard deviation, by default 1e-1. + + Returns + ------- + float + The standard deviation of the environment matrix. + """ + if self.number == 0: + return default + val = np.sqrt( + self.squared_sum / self.number + - np.multiply(self.sum / self.number, self.sum / self.number) + ) + if np.abs(val) < 1e-2: + val = 1e-2 + return val + class EnvMatStat(ABC): """A base class to store and calculate the statistics of the environment matrix.""" From 4f029b0735ad645b7c58c7ae1a70a3cf32a4a008 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 16:57:37 -0500 Subject: [PATCH 04/40] protection Signed-off-by: Jinzhe Zeng --- deepmd/utils/env_mat_stat.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 70f47dfcb5..f7dbd237ae 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -40,13 +40,15 @@ def __add__(self, other: "StatItem") -> "StatItem": squared_sum=self.squared_sum + other.squared_sum, ) - def compute_std(self, default: float = 1e-1) -> float: + def compute_std(self, default: float = 1e-1, protection: float = 1e-2) -> float: """Compute the standard deviation of the environment matrix. Parameters ---------- default : float, optional The default value of the standard deviation, by default 1e-1. + protection : float, optional + The protection value for the standard deviation, by default 1e-2. Returns ------- @@ -59,8 +61,8 @@ def compute_std(self, default: float = 1e-1) -> float: self.squared_sum / self.number - np.multiply(self.sum / self.number, self.sum / self.number) ) - if np.abs(val) < 1e-2: - val = 1e-2 + if np.abs(val) < protection: + val = protection return val From afd6d6a49193ecaf63f4a21ad279b020c9e5caac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 17:20:44 -0500 Subject: [PATCH 05/40] save stat Signed-off-by: Jinzhe Zeng --- deepmd/utils/env_mat_stat.py | 73 +++++++++++++++++++++++++++++++----- deepmd/utils/path.py | 46 +++++++++++++++++++++++ 2 files changed, 110 insertions(+), 9 deletions(-) diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index f7dbd237ae..da7a336da8 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -10,10 +10,15 @@ Dict, Iterator, List, + Optional, ) import numpy as np +from deepmd.utils.path import ( + DPPath, +) + class StatItem: """A class to store the statistics of the environment matrix. @@ -69,24 +74,23 @@ def compute_std(self, default: float = 1e-1, protection: float = 1e-2) -> float: class EnvMatStat(ABC): """A base class to store and calculate the statistics of the environment matrix.""" - def compute_stats(self, data: List[Dict[str, np.ndarray]]) -> Dict[str, StatItem]: + def __init__(self) -> None: + super().__init__() + self.stats = defaultdict(StatItem) + + def compute_stats(self, data: List[Dict[str, np.ndarray]]) -> None: """Compute the statistics of the environment matrix. Parameters ---------- data : List[Dict[str, np.ndarray]] The environment matrix. - - Returns - ------- - Dict[str, StatItem] - The statistics of the environment matrix. """ - stats = defaultdict(StatItem) + if len(self.stats) > 0: + raise ValueError("The statistics has already been computed.") for iter_stats in self.iter(data): for kk in iter_stats: - stats[kk] += iter_stats[kk] - return stats + self.stats[kk] += iter_stats[kk] @abstractmethod def iter(self, data: List[Dict[str, np.ndarray]]) -> Iterator[Dict[str, StatItem]]: @@ -102,3 +106,54 @@ def iter(self, data: List[Dict[str, np.ndarray]]) -> Iterator[Dict[str, StatItem Dict[str, StatItem] The statistics of the environment matrix. """ + + def save_stats(self, path: DPPath) -> None: + """Save the statistics of the environment matrix. + + Parameters + ---------- + path : DPH5Path + The path to save the statistics of the environment matrix. + """ + if len(self.stats) == 0: + raise ValueError("The statistics hasn't been computed.") + for kk, vv in self.stats.items(): + (path / kk / "number").save(vv.number) + (path / kk / "sum").save(vv.sum) + (path / kk / "squared_sum").save(vv.squared_sum) + + def load_stats(self, path: DPPath) -> None: + """Load the statistics of the environment matrix. + + Parameters + ---------- + path : DPH5Path + The path to load the statistics of the environment matrix. + """ + if len(self.stats) > 0: + raise ValueError("The statistics has already been computed.") + for kk in path.glob("*"): + self.stats[kk.name] = StatItem( + number=(kk / "number").load_numpy().item(), + sum=(kk / "sum").load_numpy().item(), + squared_sum=(kk / "squared_sum").load_numpy().item(), + ) + + def load_or_compute_stats( + self, data: List[Dict[str, np.ndarray]], path: Optional[DPPath] = None + ) -> None: + """Load the statistics of the environment matrix if it exists, otherwise compute and save it. + + Parameters + ---------- + path : DPH5Path + The path to load the statistics of the environment matrix. + data : List[Dict[str, np.ndarray]] + The environment matrix. + """ + if path is not None and path.is_dir(): + self.load_stats(path) + else: + self.compute_stats(data) + if path is not None: + self.save_stats(path) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index a8e4bc329f..4579795cf4 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -62,6 +62,16 @@ def load_txt(self, **kwargs) -> np.ndarray: loaded NumPy array """ + @abstractmethod + def save_numpy(self, arr: np.ndarray) -> None: + """Save NumPy array. + + Parameters + ---------- + arr : np.ndarray + NumPy array + """ + @abstractmethod def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -122,6 +132,11 @@ def __eq__(self, other) -> bool: def __hash__(self): return hash(str(self)) + @property + @abstractmethod + def name(self) -> str: + """Name of the path.""" + class DPOSPath(DPPath): """The OS path class to data system (DeepmdData) for real directories. @@ -159,6 +174,16 @@ def load_txt(self, **kwargs) -> np.ndarray: """ return np.loadtxt(str(self.path), **kwargs) + def save_numpy(self, arr: np.ndarray) -> None: + """Save NumPy array. + + Parameters + ---------- + arr : np.ndarray + NumPy array + """ + np.save(str(self.path), arr) + def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -212,6 +237,11 @@ def __str__(self) -> str: """Represent string.""" return str(self.path) + @property + def name(self) -> str: + """Name of the path.""" + return self.path.name + class DPH5Path(DPPath): """The path class to data system (DeepmdData) for HDF5 files. @@ -276,6 +306,18 @@ def load_txt(self, dtype: Optional[np.dtype] = None, **kwargs) -> np.ndarray: arr = arr.astype(dtype) return arr + def save_numpy(self, arr: np.ndarray) -> None: + """Save NumPy array. + + Parameters + ---------- + arr : np.ndarray + NumPy array + """ + if self.name in self._keys: + del self.root[self.name] + self.root.create_dataset(self.name, data=arr) + def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -356,3 +398,7 @@ def __lt__(self, other: "DPH5Path") -> bool: def __str__(self) -> str: """Returns path of self.""" return f"{self.root_path}#{self.name}" + + def name(self) -> str: + """Name of the path.""" + return self.name.split("/")[-1] From 2f67aac871101667f12b7da33905f0fb7a4943cf Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 17:38:46 -0500 Subject: [PATCH 06/40] get std Signed-off-by: Jinzhe Zeng --- deepmd/common.py | 14 ++++++++++++++ deepmd/utils/env_mat_stat.py | 22 ++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/deepmd/common.py b/deepmd/common.py index 05d02234b4..691cc262df 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -5,6 +5,9 @@ import platform import shutil import warnings +from hashlib import ( + sha1, +) from pathlib import ( Path, ) @@ -299,3 +302,14 @@ def symlink_prefix_files(old_prefix: str, new_prefix: str): os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff) else: shutil.copyfile(ori_ff, new_ff) + + +def get_hash(obj) -> str: + """Get hash of object. + + Parameters + ---------- + obj + object to hash + """ + return sha1(json.dumps(obj).encode("utf-8")).hexdigest() diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index da7a336da8..e308d66c98 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -157,3 +157,25 @@ def load_or_compute_stats( self.compute_stats(data) if path is not None: self.save_stats(path) + + def get_std( + self, default: float = 1e-1, protection: float = 1e-2 + ) -> Dict[str, float]: + """Get the standard deviation of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the standard deviation, by default 1e-1. + protection : float, optional + The protection value for the standard deviation, by default 1e-2. + + Returns + ------- + Dict[str, float] + The standard deviation of the environment matrix. + """ + return { + kk: vv.compute_std(default=default, protection=protection) + for kk, vv in self.stats.items() + } From 89d413ccbf6244591f564da8c4f8b2449ebd55f3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 17:42:27 -0500 Subject: [PATCH 07/40] compute avg Signed-off-by: Jinzhe Zeng --- deepmd/utils/env_mat_stat.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index e308d66c98..36603b6d97 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -45,6 +45,23 @@ def __add__(self, other: "StatItem") -> "StatItem": squared_sum=self.squared_sum + other.squared_sum, ) + def compute_avg(self, default: float = 0) -> float: + """Compute the average of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the average, by default 0. + + Returns + ------- + float + The average of the environment matrix. + """ + if self.number == 0: + return default + return self.sum / self.number + def compute_std(self, default: float = 1e-1, protection: float = 1e-2) -> float: """Compute the standard deviation of the environment matrix. @@ -158,6 +175,21 @@ def load_or_compute_stats( if path is not None: self.save_stats(path) + def get_avg(self, default: float = 0) -> Dict[str, float]: + """Get the average of the environment matrix. + + Parameters + ---------- + default : float, optional + The default value of the average, by default 0. + + Returns + ------- + Dict[str, float] + The average of the environment matrix. + """ + return {kk: vv.compute_avg(default=default) for kk, vv in self.stats.items()} + def get_std( self, default: float = 1e-1, protection: float = 1e-2 ) -> Dict[str, float]: From 8b6e24ee0e3c910619b48c34730aeda75fad7930 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 17:51:25 -0500 Subject: [PATCH 08/40] sea looks good Signed-off-by: Jinzhe Zeng --- .../descriptor/make_base_descriptor.py | 12 +++--- deepmd/pt/model/descriptor/se_a.py | 39 ++++++++++++------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 29d3ad6d92..4de4a7f139 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -9,6 +9,10 @@ Optional, ) +from deepmd.utils.path import ( + DPPath, +) + def make_base_descriptor( t_tensor, @@ -69,14 +73,12 @@ def distinguish_types(self) -> bool: """ pass - def compute_input_stats(self, merged): + def compute_input_stats( + self, merged: list[dict], path: Optional[DPPath] = None + ): """Update mean and stddev for descriptor elements.""" raise NotImplementedError - def init_desc_stat(self, **kwargs): - """Initialize the model bias by the statistics.""" - raise NotImplementedError - @abstractmethod def fwd( self, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 68ffc4c41c..ba07512c89 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -11,10 +11,12 @@ import numpy as np import torch +from deepmd.common import ( + get_hash, +) from deepmd.pt.model.descriptor import ( Descriptor, DescriptorBlock, - compute_std, prod_env_mat_se_a, ) from deepmd.pt.utils import ( @@ -28,6 +30,9 @@ from deepmd.utils.env_mat_stat import ( StatItem, ) +from deepmd.utils.path import ( + DPPath, +) try: from typing import ( @@ -476,32 +481,38 @@ def iter( env_mats[f"a_{type_i}"] = dd[:, 1:] yield self.compute_stat(env_mats) - def compute_input_stats(self, merged: list[dict]): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - stats = self.EnvMatStat(self).compute_stats(merged) + if path is not None: + path = path / get_hash( + { + "ntypes": self.ntypes, + "rcut": round(self.rcut, 2), + "rcut_smth": round(self.rcut_smth, 2), + "sel": self.nsel, + } + ) + env_mat_stat = self.EnvMatStat(self) + env_mat_stat.load_or_compute_stats(merged, path) + avgs = env_mat_stat.get_avg() + stds = env_mat_stat.get_std() - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): all_davg = [] all_dstd = [] for type_i in range(self.ntypes): - davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] + davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] dstdunit = [ [ - compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + stds[f"r_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], ] ] davg = np.tile(davgunit, [self.nnei, 1]) dstd = np.tile(dstdunit, [self.nnei, 1]) all_davg.append(davg) all_dstd.append(dstd) - self.sumr = sumr - self.suma = suma - self.sumn = sumn - self.sumr2 = sumr2 - self.suma2 = suma2 if not self.set_davg_zero: mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) From 4736ae57ae932021fc897882bc84220eb25589ec Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 17:55:38 -0500 Subject: [PATCH 09/40] rm init_desc_stat Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/se_e2_a.py | 4 ---- deepmd/pt/model/descriptor/descriptor.py | 6 ------ deepmd/pt/model/descriptor/dpa1.py | 6 ------ deepmd/pt/model/descriptor/dpa2.py | 11 ----------- deepmd/pt/model/descriptor/gaussian_lcc.py | 3 --- deepmd/pt/model/descriptor/hybrid.py | 21 --------------------- deepmd/pt/model/descriptor/repformers.py | 8 -------- deepmd/pt/model/descriptor/se_a.py | 6 ------ deepmd/pt/model/descriptor/se_atten.py | 8 -------- source/tests/pt/test_stat.py | 1 - 10 files changed, 74 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 26258f4ac7..02619a550f 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -235,10 +235,6 @@ def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" raise NotImplementedError - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): - """Initialize the model bias by the statistics.""" - raise NotImplementedError - def cal_g( self, ss, diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 63dbe0eb19..75978ae3de 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -129,7 +129,6 @@ def compute_or_load_stat( else: # load the statistics results assert stat_file_path is not None, "No stat file to load!" result_dict = self.load_stats(type_map, stat_file_path) - self.init_desc_stat(**result_dict) def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): """ @@ -310,10 +309,6 @@ def compute_input_stats(self, merged): """Update mean and stddev for DescriptorBlock elements.""" raise NotImplementedError - def init_desc_stat(self, **kwargs): - """Initialize mean and stddev by the statistics.""" - raise NotImplementedError - def share_params(self, base_class, shared_level, resume=False): assert ( self.__class__ == base_class.__class__ @@ -343,7 +338,6 @@ def share_params(self, base_class, shared_level, resume=False): "sumr2": sumr2_base + sumr2, "suma2": suma2_base + suma2, } - base_class.init_desc_stat(**stat_dict) self.mean = base_class.mean self.stddev = base_class.stddev # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 76cff174af..2bf177966e 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -122,12 +122,6 @@ def dim_emb(self): def compute_input_stats(self, merged): return self.se_atten.compute_input_stats(merged) - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2) - @classmethod def get_stat_name( cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 6cefaf6f38..eeae098fd3 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -302,17 +302,7 @@ def compute_input_stats(self, merged): sumn.append(tmp_stat_dict["sumn"]) sumr2.append(tmp_stat_dict["sumr2"]) suma2.append(tmp_stat_dict["suma2"]) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) for ii, descrpt in enumerate([self.repinit, self.repformers]): stat_dict_ii = { @@ -322,7 +312,6 @@ def init_desc_stat( "sumr2": sumr2[ii], "suma2": suma2[ii], } - descrpt.init_desc_stat(**stat_dict_ii) @classmethod def get_stat_name( diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py index 0972b90279..4bf9c61814 100644 --- a/deepmd/pt/model/descriptor/gaussian_lcc.py +++ b/deepmd/pt/model/descriptor/gaussian_lcc.py @@ -158,9 +158,6 @@ def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" return [], [], [], [], [] - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): - pass - def forward( self, extended_coord, diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 511ac5e79b..c3b8610a59 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -167,27 +167,6 @@ def compute_input_stats(self, merged): sumn.append(tmp_stat_dict["sumn"]) sumr2.append(tmp_stat_dict["sumr2"]) suma2.append(tmp_stat_dict["suma2"]) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - for ii, descrpt in enumerate(self.descriptor_list): - stat_dict_ii = { - "sumr": sumr[ii], - "suma": suma[ii], - "sumn": sumn[ii], - "sumr2": sumr2[ii], - "suma2": suma2[ii], - } - descrpt.init_desc_stat(**stat_dict_ii) def forward( self, diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index de2b5f3565..3da60b89bd 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -329,15 +329,7 @@ def compute_input_stats(self, merged): sumn = np.sum(sumn, axis=0) sumr2 = np.sum(sumr2, axis=0) suma2 = np.sum(suma2, axis=0) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): all_davg = [] all_dstd = [] for type_i in range(self.ntypes): diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index ba07512c89..2a5572449a 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -126,12 +126,6 @@ def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" return self.sea.compute_input_stats(merged) - def init_desc_stat( - self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs - ): - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2) - @classmethod def get_stat_name( cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index e1c9942d92..a729dfcc28 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -245,15 +245,7 @@ def compute_input_stats(self, merged): sumn = np.sum(sumn, axis=0) sumr2 = np.sum(sumr2, axis=0) suma2 = np.sum(suma2, axis=0) - return { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2, **kwargs): all_davg = [] all_dstd = [] for type_i in range(self.ntypes): diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index bc95575a5a..f7e557e371 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -170,7 +170,6 @@ def test_descriptor(self): if key in sys.keys(): sys[key] = sys[key].to(env.DEVICE) stat_dict = my_en.compute_input_stats(sampled) - my_en.init_desc_stat(**stat_dict) my_en.mean = my_en.mean my_en.stddev = my_en.stddev self.assertTrue( From 6cb1275034878fabba5da69fb851569a26a54f0e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 17:58:07 -0500 Subject: [PATCH 10/40] rm get_stat_name Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/descriptor.py | 13 --------- deepmd/pt/model/descriptor/dpa1.py | 13 --------- deepmd/pt/model/descriptor/dpa2.py | 35 ------------------------ deepmd/pt/model/descriptor/se_a.py | 13 --------- deepmd/pt/model/task/ener.py | 20 -------------- deepmd/pt/model/task/fitting.py | 13 --------- deepmd/pt/utils/stat.py | 26 +----------------- 7 files changed, 1 insertion(+), 132 deletions(-) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 75978ae3de..05dcedb4f6 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -64,19 +64,6 @@ class SomeDescript(Descriptor): """ return Descriptor.__plugins.register(key) - @classmethod - def get_stat_name(cls, ntypes, type_name, **kwargs): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. - """ - if cls is not Descriptor: - raise NotImplementedError("get_stat_name is not implemented!") - descrpt_type = type_name - return Descriptor.__plugins.plugins[descrpt_type].get_stat_name( - ntypes, type_name, **kwargs - ) - @classmethod def get_data_process_key(cls, config): """ diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 2bf177966e..3f82bb38cf 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -122,19 +122,6 @@ def dim_emb(self): def compute_input_stats(self, merged): return self.se_atten.compute_input_stats(merged) - @classmethod - def get_stat_name( - cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs - ): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. - """ - descrpt_type = type_name - assert descrpt_type in ["dpa1", "se_atten"] - assert all(x is not None for x in [rcut, rcut_smth, sel]) - return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" - @classmethod def get_data_process_key(cls, config): """ diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index eeae098fd3..5fa9caa244 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -313,41 +313,6 @@ def compute_input_stats(self, merged): "suma2": suma2[ii], } - @classmethod - def get_stat_name( - cls, - ntypes, - type_name, - repinit_rcut=None, - repinit_rcut_smth=None, - repinit_nsel=None, - repformer_rcut=None, - repformer_rcut_smth=None, - repformer_nsel=None, - **kwargs, - ): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. - """ - descrpt_type = type_name - assert descrpt_type in ["dpa2"] - assert all( - x is not None - for x in [ - repinit_rcut, - repinit_rcut_smth, - repinit_nsel, - repformer_rcut, - repformer_rcut_smth, - repformer_nsel, - ] - ) - return ( - f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}" - f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz" - ) - @classmethod def get_data_process_key(cls, config): """ diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 2a5572449a..2b2d4c89c7 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -126,19 +126,6 @@ def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" return self.sea.compute_input_stats(merged) - @classmethod - def get_stat_name( - cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs - ): - """ - Get the name for the statistic file of the descriptor. - Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name. - """ - descrpt_type = type_name - assert descrpt_type in ["se_e2_a"] - assert all(x is not None for x in [rcut, rcut_smth, sel]) - return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz" - @classmethod def get_data_process_key(cls, config): """ diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index f1dad4c58d..4ab6ea79a6 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -223,16 +223,6 @@ def __init__( **kwargs, ) - @classmethod - def get_stat_name(cls, ntypes, type_name="ener", **kwargs): - """ - Get the name for the statistic file of the fitting. - Usually use the combination of fitting net name and ntypes as the statistic file name. - """ - fitting_type = type_name - assert fitting_type in ["ener"] - return f"stat_file_fitting_ener_ntypes{ntypes}.npz" - @Fitting.register("direct_force") @Fitting.register("direct_force_ener") @@ -325,16 +315,6 @@ def serialize(self) -> dict: def deserialize(cls) -> "EnergyFittingNetDirect": raise NotImplementedError - @classmethod - def get_stat_name(cls, ntypes, type_name="ener", **kwargs): - """ - Get the name for the statistic file of the fitting. - Usually use the combination of fitting net name and ntypes as the statistic file name. - """ - fitting_type = type_name - assert fitting_type in ["direct_force", "direct_force_ener"] - return f"stat_file_fitting_direct_ntypes{ntypes}.npz" - def forward( self, inputs: torch.Tensor, diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index b2d8c875ce..83e33bf5f7 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -124,19 +124,6 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - @classmethod - def get_stat_name(cls, ntypes, type_name="ener", **kwargs): - """ - Get the name for the statistic file of the fitting. - Usually use the combination of fitting net name and ntypes as the statistic file name. - """ - if cls is not Fitting: - raise NotImplementedError("get_stat_name is not implemented!") - fitting_type = type_name - return Fitting.__plugins.plugins[fitting_type].get_stat_name( - ntypes, type_name, **kwargs - ) - @property def data_stat_key(self): """ diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 76b2afe41b..c31270eb9f 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import os import numpy as np import torch @@ -82,27 +81,4 @@ def compute_output_bias(energy, natoms, rcond=None): def process_stat_path( stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls ): - if stat_file_dict is None: - stat_file_dict = {} - if "descriptor" in model_params_dict: - default_stat_file_name_descrpt = descriptor_cls.get_stat_name( - len(model_params_dict["type_map"]), - model_params_dict["descriptor"]["type"], - **model_params_dict["descriptor"], - ) - stat_file_dict["descriptor"] = default_stat_file_name_descrpt - if "fitting_net" in model_params_dict: - default_stat_file_name_fitting = fitting_cls.get_stat_name( - len(model_params_dict["type_map"]), - model_params_dict["fitting_net"].get("type", "ener"), - **model_params_dict["fitting_net"], - ) - stat_file_dict["fitting_net"] = default_stat_file_name_fitting - stat_file_path = { - key: os.path.join(stat_file_dir, stat_file_dict[key]) for key in stat_file_dict - } - - has_stat_file_path_list = [ - os.path.exists(stat_file_path[key]) for key in stat_file_dict - ] - return stat_file_path, all(has_stat_file_path_list) + raise NotImplementedError("to rewrite") From b99d3302ddcb2ef82d9a026863afa52c59a9bbd3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:02:53 -0500 Subject: [PATCH 11/40] rewrite compute_input_stats Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/se_e2_a.py | 6 +++++- deepmd/pt/model/descriptor/descriptor.py | 17 +++++------------ deepmd/pt/model/descriptor/dpa1.py | 7 +++++-- deepmd/pt/model/descriptor/dpa2.py | 5 ++++- deepmd/pt/model/descriptor/gaussian_lcc.py | 11 +++++++++-- deepmd/pt/model/descriptor/hybrid.py | 5 ++++- deepmd/pt/model/descriptor/repformers.py | 5 ++++- deepmd/pt/model/descriptor/se_a.py | 4 ++-- deepmd/pt/model/descriptor/se_atten.py | 5 ++++- 9 files changed, 42 insertions(+), 23 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 02619a550f..6ecbea9e70 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import numpy as np +from deepmd.utils.path import ( + DPPath, +) + try: from deepmd._version import version as __version__ except ImportError: @@ -231,7 +235,7 @@ def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" raise NotImplementedError diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 05dcedb4f6..a65e73d247 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -25,6 +25,9 @@ from deepmd.pt.utils.utils import ( to_torch_tensor, ) +from deepmd.utils.path import ( + DPPath, +) from .base_descriptor import ( BaseDescriptor, @@ -105,17 +108,7 @@ def compute_or_load_stat( stat_file_path The path to the statistics files. """ - # TODO support hybrid descriptor - descrpt_stat_key = self.data_stat_key - if sampled is not None: # compute the statistics results - tmp_dict = self.compute_input_stats(sampled) - result_dict = {key: tmp_dict[key] for key in descrpt_stat_key} - result_dict["type_map"] = type_map - if stat_file_path is not None: - self.save_stats(result_dict, stat_file_path) - else: # load the statistics results - assert stat_file_path is not None, "No stat file to load!" - result_dict = self.load_stats(type_map, stat_file_path) + raise NotImplementedError("to rewrite") def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): """ @@ -292,7 +285,7 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension.""" pass - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for DescriptorBlock elements.""" raise NotImplementedError diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 3f82bb38cf..f115d5a95f 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -12,6 +12,9 @@ from deepmd.pt.model.network.network import ( TypeEmbedNet, ) +from deepmd.utils.path import ( + DPPath, +) from .se_atten import ( DescrptBlockSeAtten, @@ -119,8 +122,8 @@ def dim_out(self): def dim_emb(self): return self.get_dim_emb() - def compute_input_stats(self, merged): - return self.se_atten.compute_input_stats(merged) + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + return self.se_atten.compute_input_stats(merged, path) @classmethod def get_data_process_key(cls, config): diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 5fa9caa244..67eeaf31cc 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -18,6 +18,9 @@ build_multiple_neighbor_list, get_multiple_nlist_key, ) +from deepmd.utils.path import ( + DPPath, +) from .repformers import ( DescrptBlockRepformers, @@ -286,7 +289,7 @@ def dim_emb(self): """Returns the embedding dimension g2.""" return self.get_dim_emb() - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] for ii, descrpt in enumerate([self.repinit, self.repformers]): merged_tmp = [ diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py index 4bf9c61814..8243d32ac9 100644 --- a/deepmd/pt/model/descriptor/gaussian_lcc.py +++ b/deepmd/pt/model/descriptor/gaussian_lcc.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + import torch import torch.nn as nn @@ -13,6 +17,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.utils.path import ( + DPPath, +) class DescrptGaussianLcc(Descriptor): @@ -154,9 +161,9 @@ def dim_emb(self): """Returns the output dimension of pair representation.""" return self.pair_embed_dim - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - return [], [], [], [], [] + pass def forward( self, diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index c3b8610a59..2d4e0e7172 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -13,6 +13,9 @@ Identity, Linear, ) +from deepmd.utils.path import ( + DPPath, +) @DescriptorBlock.register("hybrid") @@ -150,7 +153,7 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] for ii, descrpt in enumerate(self.descriptor_list): diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 3da60b89bd..6f59f72d4e 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -26,6 +26,9 @@ from deepmd.pt.utils.utils import ( get_activation_fn, ) +from deepmd.utils.path import ( + DPPath, +) from .repformer_layer import ( RepformerLayer, @@ -268,7 +271,7 @@ def forward( return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" ndescrpt = self.nnei * 4 sumr = [] diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 2b2d4c89c7..43425c5f59 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -122,9 +122,9 @@ def dim_out(self): """Returns the output dimension of this descriptor.""" return self.sea.dim_out - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - return self.sea.compute_input_stats(merged) + return self.sea.compute_input_stats(merged, path) @classmethod def get_data_process_key(cls, config): diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index a729dfcc28..7c9d9c1eb0 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -24,6 +24,9 @@ from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, ) +from deepmd.utils.path import ( + DPPath, +) @DescriptorBlock.register("se_atten") @@ -185,7 +188,7 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.get_dim_emb() - def compute_input_stats(self, merged): + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" sumr = [] suma = [] From da1e72d69629dc0187df939526abdd30d13bdb6f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:04:49 -0500 Subject: [PATCH 12/40] hybrid Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/hybrid.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 2d4e0e7172..c0748b9b76 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -155,7 +155,6 @@ def share_params(self, base_class, shared_level, resume=False): def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] for ii, descrpt in enumerate(self.descriptor_list): merged_tmp = [ { @@ -164,12 +163,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) } for item in merged ] - tmp_stat_dict = descrpt.compute_input_stats(merged_tmp) - sumr.append(tmp_stat_dict["sumr"]) - suma.append(tmp_stat_dict["suma"]) - sumn.append(tmp_stat_dict["sumn"]) - sumr2.append(tmp_stat_dict["sumr2"]) - suma2.append(tmp_stat_dict["suma2"]) + descrpt.compute_input_stats(merged_tmp, path) def forward( self, From 959583a5a148a3e35a4e0f97fdc62eb7d64fd33e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:11:49 -0500 Subject: [PATCH 13/40] fix hash Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 43425c5f59..5b4f3386db 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -470,7 +470,9 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) "ntypes": self.ntypes, "rcut": round(self.rcut, 2), "rcut_smth": round(self.rcut_smth, 2), - "sel": self.nsel, + "nsel": self.nsel, + "sel": self.get_sel(), + "distinguish_types": self.distinguish_types(), } ) env_mat_stat = self.EnvMatStat(self) From 524366d41d9b3e39fe7dd180dc55477f32e5f009 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:29:54 -0500 Subject: [PATCH 14/40] se atten Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_atten.py | 192 ++++++++++++++++--------- 1 file changed, 125 insertions(+), 67 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 7c9d9c1eb0..6e41513d78 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Dict, + Iterator, List, Optional, ) @@ -7,9 +9,11 @@ import numpy as np import torch +from deepmd.common import ( + get_hash, +) from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, - compute_std, ) from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, @@ -21,9 +25,15 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env_mat_stat import ( + BaseEnvMatStat, +) from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) from deepmd.utils.path import ( DPPath, ) @@ -188,88 +198,136 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.get_dim_emb() - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): - """Update mean and stddev for descriptor elements.""" - sumr = [] - suma = [] - sumn = [] - sumr2 = [] - suma2 = [] - mixed_type = "real_natoms_vec" in merged[0] - for system in merged: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.get_sel(), - distinguish_types=self.distinguish_types(), - box=box, + class EnvMatStat(BaseEnvMatStat): + """A class to calculate the statistics of the environment matrix.""" + + def __init__(self, descriptor: "DescrptBlockSeAtten"): + self.descriptor = descriptor + self.ntypes = descriptor.get_ntypes() + self.rcut = descriptor.get_rcut() + self.rcut_smth = descriptor.rcut_smth + self.nsel = descriptor.get_nsel() + + def iter( + self, data: List[Dict[str, torch.Tensor]] + ) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, torch.Tensor]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ + zero_mean = torch.zeros( + self.ntypes, + self.nsel * 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - self.mean, - self.stddev, - self.rcut, - self.rcut_smth, + one_stddev = torch.ones( + self.ntypes, + self.nsel * 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, ) - if not mixed_type: - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), self.ndescrpt, natoms + for system in data: + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], ) - else: - real_natoms_vec = system["real_natoms_vec"] - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), - self.ndescrpt, - real_natoms_vec, - mixed_type=mixed_type, - real_atype=atype.detach().cpu().numpy(), + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.rcut, + self.descriptor.get_sel(), + distinguish_types=self.descriptor.distinguish_types(), + box=box, ) - sumr.append(sysr) - suma.append(sysa) - sumn.append(sysn) - sumr2.append(sysr2) - suma2.append(sysa2) - sumr = np.sum(sumr, axis=0) - suma = np.sum(suma, axis=0) - sumn = np.sum(sumn, axis=0) - sumr2 = np.sum(sumr2, axis=0) - suma2 = np.sum(suma2, axis=0) + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + zero_mean, + one_stddev, + self.descriptor.get_rcut(), + self.descriptor.rcut_smth, + ) + env_mat = env_mat.view(-1, self.nsel, 4) + env_mats = {} + + if "real_natoms_vec" in system: + end_indexes = torch.cumsum(natoms[0, 2:], 0) + start_indexes = torch.cat( + [ + torch.zeros([], dtype=torch.int32, device=env.DEVICE), + end_indexes[:-1], + ] + ) + for type_i in range(self.ntypes): + dd = env_mat[ + :, start_indexes[type_i] : end_indexes[type_i], : + ] # all descriptors for this element + env_mats[f"r_{type_i}"] = dd[:, :1] + env_mats[f"a_{type_i}"] = dd[:, 1:] + yield self.compute_stat(env_mats) + else: + for frame_item in range(env_mat.shape[0]): + dd_ff = env_mat[frame_item] + atype_frame = atype[frame_item] + for type_i in range(self.ntypes): + type_idx = atype_frame == type_i + dd = dd_ff[type_idx] + dd = np.reshape(dd, [-1, 4]) # typen_atoms * nnei, 4 + env_mats[f"r_{type_i}"] = dd[:, :1] + env_mats[f"a_{type_i}"] = dd[:, 1:] + yield self.compute_stat(env_mats) + + def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + if path is not None: + path = path / get_hash( + { + "ntypes": self.ntypes, + "rcut": round(self.rcut, 2), + "rcut_smth": round(self.rcut_smth, 2), + "nsel": self.nsel, + "sel": self.get_sel(), + "distinguish_types": self.distinguish_types(), + } + ) + env_mat_stat = self.EnvMatStat(self) + env_mat_stat.load_or_compute_stats(merged, path) + avgs = env_mat_stat.get_avg() + stds = env_mat_stat.get_std() all_davg = [] all_dstd = [] for type_i in range(self.ntypes): - davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] + davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] dstdunit = [ [ - compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + stds[f"r_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], ] ] davg = np.tile(davgunit, [self.nnei, 1]) dstd = np.tile(dstdunit, [self.nnei, 1]) all_davg.append(davg) all_dstd.append(dstd) - self.sumr = sumr - self.suma = suma - self.sumn = sumn - self.sumr2 = sumr2 - self.suma2 = suma2 if not self.set_davg_zero: mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) From 87f0d858f785bfdfc3719d297360064167def862 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:36:39 -0500 Subject: [PATCH 15/40] to make it work Signed-off-by: Jinzhe Zeng --- deepmd/pt/entrypoints/main.py | 24 +++---------------- .../pt/model/atomic_model/dp_atomic_model.py | 4 ++-- deepmd/pt/model/descriptor/descriptor.py | 1 - deepmd/pt/utils/stat.py | 3 ++- 4 files changed, 7 insertions(+), 25 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 29ef8761ff..3c8b052570 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -46,12 +46,6 @@ from deepmd.pt.infer import ( inference, ) -from deepmd.pt.model.descriptor import ( - Descriptor, -) -from deepmd.pt.model.task import ( - Fitting, -) from deepmd.pt.train import ( training, ) @@ -69,7 +63,6 @@ ) from deepmd.pt.utils.stat import ( make_stat_input, - process_stat_path, ) from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter @@ -134,19 +127,8 @@ def prepare_trainer_input_single( # noise_settings = None # stat files - hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid" - if not hybrid_descrpt: - stat_file_path_single, has_stat_file_path = process_stat_path( - data_dict_single.get("stat_file", None), - data_dict_single.get("stat_file_dir", f"stat_files{suffix}"), - model_params_single, - Descriptor, - Fitting, - ) - else: ### TODO hybrid descriptor not implemented - raise NotImplementedError( - "data stat for hybrid descriptor is not implemented!" - ) + # TODO: rewrite + stat_file_path_single = {"descriptor": "", "fitting": ""} # validation and training data validation_data_single = DpLoaderSet( @@ -156,7 +138,7 @@ def prepare_trainer_input_single( type_split=type_split, noise_settings=noise_settings, ) - if ckpt or finetune_model or has_stat_file_path: + if ckpt or finetune_model: train_data_single = DpLoaderSet( training_systems, training_dataset_params["batch_size"], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 17b70e4701..37f956abaf 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -190,8 +190,8 @@ def compute_or_load_stat( stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"]) else: stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"][0]) - if not os.path.exists(stat_file_dir): - os.mkdir(stat_file_dir) + # if not os.path.exists(stat_file_dir): + # os.mkdir(stat_file_dir) self.descriptor.compute_or_load_stat( type_map, sampled, stat_file_path_dict["descriptor"] ) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index a65e73d247..ac1e58350f 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -108,7 +108,6 @@ def compute_or_load_stat( stat_file_path The path to the statistics files. """ - raise NotImplementedError("to rewrite") def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): """ diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index c31270eb9f..56875d37fa 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -81,4 +81,5 @@ def compute_output_bias(energy, natoms, rcond=None): def process_stat_path( stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls ): - raise NotImplementedError("to rewrite") + # TODO: to rewrite + return From 2e815eacb4d13c5acdf367280289032008e2521b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:38:21 -0500 Subject: [PATCH 16/40] compute_or_load_stat Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/descriptor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index ac1e58350f..8011a22e1b 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -108,6 +108,9 @@ def compute_or_load_stat( stat_file_path The path to the statistics files. """ + # TODO + assert sampled is not None + tmp_dict = self.compute_input_stats(sampled, None) def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): """ From ed34d594aa0dcc864c589a4a1bb4204179dc92c8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:39:23 -0500 Subject: [PATCH 17/40] init Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 1 + deepmd/pt/model/descriptor/se_atten.py | 1 + 2 files changed, 2 insertions(+) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 5b4f3386db..2f62582ff4 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -383,6 +383,7 @@ class EnvMatStat(BaseEnvMatStat): """A class to calculate the statistics of the environment matrix.""" def __init__(self, descriptor: "DescrptBlockSeA"): + super().__init__() self.descriptor = descriptor self.ntypes = descriptor.get_ntypes() self.rcut = descriptor.get_rcut() diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 6e41513d78..181c707034 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -202,6 +202,7 @@ class EnvMatStat(BaseEnvMatStat): """A class to calculate the statistics of the environment matrix.""" def __init__(self, descriptor: "DescrptBlockSeAtten"): + super().__init__() self.descriptor = descriptor self.ntypes = descriptor.get_ntypes() self.rcut = descriptor.get_rcut() From d04d16ca41ef96d21ab682278198c5bbde715a15 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:41:34 -0500 Subject: [PATCH 18/40] fix shape Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 6 ++++-- deepmd/pt/model/descriptor/se_atten.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 2f62582ff4..f90af09d21 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -407,13 +407,15 @@ def iter( """ zero_mean = torch.zeros( self.ntypes, - self.nsel * 4, + self.nsel, + 4, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) one_stddev = torch.ones( self.ntypes, - self.nsel * 4, + self.nsel, + 4, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 181c707034..c5b0cedd71 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -226,13 +226,15 @@ def iter( """ zero_mean = torch.zeros( self.ntypes, - self.nsel * 4, + self.nsel, + 4, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) one_stddev = torch.ones( self.ntypes, - self.nsel * 4, + self.nsel, + 4, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) From 2580f8e7e65c510fcbbf5bfd15f63009fafb97ab Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:43:22 -0500 Subject: [PATCH 19/40] fix concat Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 2 +- deepmd/pt/model/descriptor/se_atten.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index f90af09d21..a9692277d3 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -453,7 +453,7 @@ def iter( end_indexes = torch.cumsum(natoms[0, 2:], 0) start_indexes = torch.cat( [ - torch.zeros([], dtype=torch.int32, device=env.DEVICE), + torch.zeros(1, dtype=torch.int32, device=env.DEVICE), end_indexes[:-1], ] ) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index c5b0cedd71..e6b3e1462e 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -274,7 +274,7 @@ def iter( end_indexes = torch.cumsum(natoms[0, 2:], 0) start_indexes = torch.cat( [ - torch.zeros([], dtype=torch.int32, device=env.DEVICE), + torch.zeros(1, dtype=torch.int32, device=env.DEVICE), end_indexes[:-1], ] ) From 2771b1e72d8ae16a30ee994ee569f089afdfe3db Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:48:40 -0500 Subject: [PATCH 20/40] make it work Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/atomic_model/dp_atomic_model.py | 16 ++-------------- deepmd/pt/model/task/fitting.py | 1 - 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 37f956abaf..f627162880 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy import logging -import os import sys from typing import ( Dict, @@ -185,20 +184,9 @@ def compute_or_load_stat( if sampled is not None: # move data to device for data_sys in sampled: dict_to_device(data_sys) - if stat_file_path_dict is not None: - if not isinstance(stat_file_path_dict["descriptor"], list): - stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"]) - else: - stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"][0]) - # if not os.path.exists(stat_file_dir): - # os.mkdir(stat_file_dir) - self.descriptor.compute_or_load_stat( - type_map, sampled, stat_file_path_dict["descriptor"] - ) + self.descriptor.compute_or_load_stat(type_map, sampled, None) if self.fitting_net is not None: - self.fitting_net.compute_or_load_stat( - type_map, sampled, stat_file_path_dict["fitting_net"] - ) + self.fitting_net.compute_or_load_stat(type_map, sampled, None) @torch.jit.export def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 83e33bf5f7..537cb700db 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -158,7 +158,6 @@ def compute_or_load_stat( tmp_dict = self.compute_output_stats(sampled) result_dict = {key: tmp_dict[key] for key in fitting_stat_key} result_dict["type_map"] = type_map - self.save_stats(result_dict, stat_file_path) else: # load the statistics results assert stat_file_path is not None, "No stat file to load!" result_dict = self.load_stats(type_map, stat_file_path) From 3eb557789469bee9c4c81e8ae4d7cb178dac5ecc Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 18:50:28 -0500 Subject: [PATCH 21/40] rm save_stats and load_stats Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/descriptor.py | 58 ----------------------- deepmd/pt/model/task/fitting.py | 59 ++---------------------- 2 files changed, 4 insertions(+), 113 deletions(-) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 8011a22e1b..1427236523 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -112,64 +112,6 @@ def compute_or_load_stat( assert sampled is not None tmp_dict = self.compute_input_stats(sampled, None) - def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]): - """ - Save the statistics results to `stat_file_path`. - - Parameters - ---------- - result_dict - The dictionary of statistics results. - stat_file_path - The path to the statistics file(s). - """ - if not isinstance(stat_file_path, list): - log.info(f"Saving stat file to {stat_file_path}") - np.savez_compressed(stat_file_path, **result_dict) - else: # TODO hybrid descriptor not implemented - raise NotImplementedError( - "save_stats for hybrid descriptor is not implemented!" - ) - - def load_stats(self, type_map, stat_file_path: Union[str, List[str]]): - """ - Load the statistics results to `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - stat_file_path - The path to the statistics file(s). - - Returns - ------- - result_dict - The dictionary of statistics results. - """ - descrpt_stat_key = self.data_stat_key - target_type_map = type_map - if not isinstance(stat_file_path, list): - log.info(f"Loading stat file from {stat_file_path}") - stats = np.load(stat_file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [i for i in target_type_map if i not in stat_type_map] - assert not missing_type, ( - f"These type are not in stat file {stat_file_path}: {missing_type}! " - f"Please change the stat file path!" - ) - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats[descrpt_stat_key[0]].size: # not empty - result_dict = {key: stats[key][idx_map] for key in descrpt_stat_key} - else: - result_dict = {key: [] for key in descrpt_stat_key} - else: # TODO hybrid descriptor not implemented - raise NotImplementedError( - "load_stats for hybrid descriptor is not implemented!" - ) - return result_dict - def __new__(cls, *args, **kwargs): if cls is Descriptor: try: diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 537cb700db..863d278d17 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -154,63 +154,12 @@ def compute_or_load_stat( The path to the statistics files. """ fitting_stat_key = self.data_stat_key - if sampled is not None: - tmp_dict = self.compute_output_stats(sampled) - result_dict = {key: tmp_dict[key] for key in fitting_stat_key} - result_dict["type_map"] = type_map - else: # load the statistics results - assert stat_file_path is not None, "No stat file to load!" - result_dict = self.load_stats(type_map, stat_file_path) + assert sampled is not None + tmp_dict = self.compute_output_stats(sampled) + result_dict = {key: tmp_dict[key] for key in fitting_stat_key} + result_dict["type_map"] = type_map self.init_fitting_stat(**result_dict) - def save_stats(self, result_dict, stat_file_path: str): - """ - Save the statistics results to `stat_file_path`. - - Parameters - ---------- - result_dict - The dictionary of statistics results. - stat_file_path - The path to the statistics file(s). - """ - log.info(f"Saving stat file to {stat_file_path}") - np.savez_compressed(stat_file_path, **result_dict) - - def load_stats(self, type_map, stat_file_path: str): - """ - Load the statistics results to `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - stat_file_path - The path to the statistics file(s). - - Returns - ------- - result_dict - The dictionary of statistics results. - """ - fitting_stat_key = self.data_stat_key - target_type_map = type_map - log.info(f"Loading stat file from {stat_file_path}") - stats = np.load(stat_file_path) - stat_type_map = list(stats["type_map"]) - missing_type = [i for i in target_type_map if i not in stat_type_map] - assert not missing_type, ( - f"These type are not in stat file {stat_file_path}: {missing_type}! " - f"Please change the stat file path!" - ) - idx_map = [stat_type_map.index(i) for i in target_type_map] - if stats[fitting_stat_key[0]].size: # not empty - result_dict = {key: stats[key][idx_map] for key in fitting_stat_key} - else: - result_dict = {key: [] for key in fitting_stat_key} - return result_dict - def change_energy_bias( self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10 ): From 3ad54840c5e9ae04d56e4c7f6f660b60f337603f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 19:07:47 -0500 Subject: [PATCH 22/40] assert_allclose Signed-off-by: Jinzhe Zeng --- source/tests/pt/test_stat.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index f7e557e371..6af17a694d 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -172,17 +172,13 @@ def test_descriptor(self): stat_dict = my_en.compute_input_stats(sampled) my_en.mean = my_en.mean my_en.stddev = my_en.stddev - self.assertTrue( - np.allclose( - self.dp_d.davg.reshape([-1]), my_en.mean.cpu().reshape([-1]), rtol=0.01 - ) + np.testing.assert_allclose( + self.dp_d.davg.reshape([-1]), my_en.mean.cpu().reshape([-1]), rtol=0.01 ) - self.assertTrue( - np.allclose( - self.dp_d.dstd.reshape([-1]), - my_en.stddev.cpu().reshape([-1]), - rtol=0.01, - ) + np.testing.assert_allclose( + self.dp_d.dstd.reshape([-1]), + my_en.stddev.cpu().reshape([-1]), + rtol=0.01, ) From 2b9bbd84aaf9c4fc693b39c8aa94447dbe0ea58d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 19:30:18 -0500 Subject: [PATCH 23/40] fix shape Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 9 +++++---- deepmd/pt/model/descriptor/se_atten.py | 8 ++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index a9692277d3..63d9d768e0 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -448,7 +448,7 @@ def iter( self.descriptor.get_rcut(), self.descriptor.rcut_smth, ) - env_mat = env_mat.view(-1, self.nsel, 4) + env_mat = env_mat.view(coord.shape[0], coord.shape[1], self.nsel, 4) env_mats = {} end_indexes = torch.cumsum(natoms[0, 2:], 0) start_indexes = torch.cat( @@ -459,10 +459,10 @@ def iter( ) for type_i in range(self.ntypes): dd = env_mat[ - :, start_indexes[type_i] : end_indexes[type_i], : + :, start_indexes[type_i] : end_indexes[type_i], :, : ] # all descriptors for this element - env_mats[f"r_{type_i}"] = dd[:, :1] - env_mats[f"a_{type_i}"] = dd[:, 1:] + env_mats[f"r_{type_i}"] = dd[:, :, :, :1] + env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] yield self.compute_stat(env_mats) def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): @@ -495,6 +495,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) stds[f"a_{type_i}"], ] ] + print(type_i, davgunit, dstdunit, avgs, stds) davg = np.tile(davgunit, [self.nnei, 1]) dstd = np.tile(dstdunit, [self.nnei, 1]) all_davg.append(davg) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index e6b3e1462e..1da52761f9 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -267,7 +267,7 @@ def iter( self.descriptor.get_rcut(), self.descriptor.rcut_smth, ) - env_mat = env_mat.view(-1, self.nsel, 4) + env_mat = env_mat.view(coord.shape[0], coord.shape[1], self.nsel, 4) env_mats = {} if "real_natoms_vec" in system: @@ -280,10 +280,10 @@ def iter( ) for type_i in range(self.ntypes): dd = env_mat[ - :, start_indexes[type_i] : end_indexes[type_i], : + :, start_indexes[type_i] : end_indexes[type_i], :, : ] # all descriptors for this element - env_mats[f"r_{type_i}"] = dd[:, :1] - env_mats[f"a_{type_i}"] = dd[:, 1:] + env_mats[f"r_{type_i}"] = dd[:, :, :, :1] + env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] yield self.compute_stat(env_mats) else: for frame_item in range(env_mat.shape[0]): From 7d40b9ff81ee270651e06d1696d2480de34e7c95 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 20:03:16 -0500 Subject: [PATCH 24/40] add env mat type Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 1 + deepmd/pt/model/descriptor/se_atten.py | 1 + 2 files changed, 2 insertions(+) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 63d9d768e0..aa6bfc1722 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -470,6 +470,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) if path is not None: path = path / get_hash( { + "type": "se_a", "ntypes": self.ntypes, "rcut": round(self.rcut, 2), "rcut_smth": round(self.rcut_smth, 2), diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 1da52761f9..2a29563fc9 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -302,6 +302,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) if path is not None: path = path / get_hash( { + "type": "se_a", "ntypes": self.ntypes, "rcut": round(self.rcut, 2), "rcut_smth": round(self.rcut_smth, 2), From a55e21f643f2a11396550b6b76fe0d70c189d255 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 20:04:06 -0500 Subject: [PATCH 25/40] remove print Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index aa6bfc1722..3c8b22ee55 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -496,7 +496,6 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) stds[f"a_{type_i}"], ] ] - print(type_i, davgunit, dstdunit, avgs, stds) davg = np.tile(davgunit, [self.nnei, 1]) dstd = np.tile(dstdunit, [self.nnei, 1]) all_davg.append(davg) From c34622d7e1855aa7f10ce2f976aeb909964c3139 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 20:16:22 -0500 Subject: [PATCH 26/40] merge methods Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 98 +------------------- deepmd/pt/model/descriptor/se_atten.py | 111 +--------------------- deepmd/pt/utils/env_mat_stat.py | 122 +++++++++++++++++++++++++ 3 files changed, 127 insertions(+), 204 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 3c8b22ee55..121da97964 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( ClassVar, - Dict, - Iterator, List, Optional, Tuple, @@ -26,9 +24,8 @@ PRECISION_DICT, RESERVED_PRECISON_DICT, ) -from deepmd.pt.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat -from deepmd.utils.env_mat_stat import ( - StatItem, +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, ) from deepmd.utils.path import ( DPPath, @@ -49,9 +46,6 @@ from deepmd.pt.model.network.network import ( TypeFilter, ) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, -) @Descriptor.register("se_e2_a") @@ -379,92 +373,6 @@ def __getitem__(self, key): else: raise KeyError(key) - class EnvMatStat(BaseEnvMatStat): - """A class to calculate the statistics of the environment matrix.""" - - def __init__(self, descriptor: "DescrptBlockSeA"): - super().__init__() - self.descriptor = descriptor - self.ntypes = descriptor.get_ntypes() - self.rcut = descriptor.get_rcut() - self.rcut_smth = descriptor.rcut_smth - self.nsel = descriptor.get_nsel() - - def iter( - self, data: List[Dict[str, torch.Tensor]] - ) -> Iterator[Dict[str, StatItem]]: - """Get the iterator of the environment matrix. - - Parameters - ---------- - data : List[Dict[str, torch.Tensor]] - The environment matrix. - - Yields - ------ - Dict[str, StatItem] - The statistics of the environment matrix. - """ - zero_mean = torch.zeros( - self.ntypes, - self.nsel, - 4, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - one_stddev = torch.ones( - self.ntypes, - self.nsel, - 4, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - for system in data: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.descriptor.get_rcut(), - self.descriptor.get_sel(), - distinguish_types=self.descriptor.distinguish_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - zero_mean, - one_stddev, - self.descriptor.get_rcut(), - self.descriptor.rcut_smth, - ) - env_mat = env_mat.view(coord.shape[0], coord.shape[1], self.nsel, 4) - env_mats = {} - end_indexes = torch.cumsum(natoms[0, 2:], 0) - start_indexes = torch.cat( - [ - torch.zeros(1, dtype=torch.int32, device=env.DEVICE), - end_indexes[:-1], - ] - ) - for type_i in range(self.ntypes): - dd = env_mat[ - :, start_indexes[type_i] : end_indexes[type_i], :, : - ] # all descriptors for this element - env_mats[f"r_{type_i}"] = dd[:, :, :, :1] - env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] - yield self.compute_stat(env_mats) - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" if path is not None: @@ -479,7 +387,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) "distinguish_types": self.distinguish_types(), } ) - env_mat_stat = self.EnvMatStat(self) + env_mat_stat = EnvMatStatSeA(self) env_mat_stat.load_or_compute_stats(merged, path) avgs = env_mat_stat.get_avg() stds = env_mat_stat.get_std() diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 2a29563fc9..925aa6a42c 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Dict, - Iterator, List, Optional, ) @@ -26,13 +24,7 @@ env, ) from deepmd.pt.utils.env_mat_stat import ( - BaseEnvMatStat, -) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, -) -from deepmd.utils.env_mat_stat import ( - StatItem, + EnvMatStatSeA, ) from deepmd.utils.path import ( DPPath, @@ -198,105 +190,6 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.get_dim_emb() - class EnvMatStat(BaseEnvMatStat): - """A class to calculate the statistics of the environment matrix.""" - - def __init__(self, descriptor: "DescrptBlockSeAtten"): - super().__init__() - self.descriptor = descriptor - self.ntypes = descriptor.get_ntypes() - self.rcut = descriptor.get_rcut() - self.rcut_smth = descriptor.rcut_smth - self.nsel = descriptor.get_nsel() - - def iter( - self, data: List[Dict[str, torch.Tensor]] - ) -> Iterator[Dict[str, StatItem]]: - """Get the iterator of the environment matrix. - - Parameters - ---------- - data : List[Dict[str, torch.Tensor]] - The environment matrix. - - Yields - ------ - Dict[str, StatItem] - The statistics of the environment matrix. - """ - zero_mean = torch.zeros( - self.ntypes, - self.nsel, - 4, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - one_stddev = torch.ones( - self.ntypes, - self.nsel, - 4, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - for system in data: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.rcut, - self.descriptor.get_sel(), - distinguish_types=self.descriptor.distinguish_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - zero_mean, - one_stddev, - self.descriptor.get_rcut(), - self.descriptor.rcut_smth, - ) - env_mat = env_mat.view(coord.shape[0], coord.shape[1], self.nsel, 4) - env_mats = {} - - if "real_natoms_vec" in system: - end_indexes = torch.cumsum(natoms[0, 2:], 0) - start_indexes = torch.cat( - [ - torch.zeros(1, dtype=torch.int32, device=env.DEVICE), - end_indexes[:-1], - ] - ) - for type_i in range(self.ntypes): - dd = env_mat[ - :, start_indexes[type_i] : end_indexes[type_i], :, : - ] # all descriptors for this element - env_mats[f"r_{type_i}"] = dd[:, :, :, :1] - env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] - yield self.compute_stat(env_mats) - else: - for frame_item in range(env_mat.shape[0]): - dd_ff = env_mat[frame_item] - atype_frame = atype[frame_item] - for type_i in range(self.ntypes): - type_idx = atype_frame == type_i - dd = dd_ff[type_idx] - dd = np.reshape(dd, [-1, 4]) # typen_atoms * nnei, 4 - env_mats[f"r_{type_i}"] = dd[:, :1] - env_mats[f"a_{type_i}"] = dd[:, 1:] - yield self.compute_stat(env_mats) - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" if path is not None: @@ -311,7 +204,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) "distinguish_types": self.distinguish_types(), } ) - env_mat_stat = self.EnvMatStat(self) + env_mat_stat = EnvMatStatSeA(self) env_mat_stat.load_or_compute_stats(merged, path) avgs = env_mat_stat.get_avg() stds = env_mat_stat.get_std() diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 366217ef54..af0b3bb318 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -1,15 +1,32 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + TYPE_CHECKING, Dict, + Iterator, + List, ) import torch +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat_se_a, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat from deepmd.utils.env_mat_stat import ( StatItem, ) +if TYPE_CHECKING: + from deepmd.pt.model.descriptor import ( + DescriptorBlock, + ) + class EnvMatStat(BaseEnvMatStat): def compute_stat(self, env_mat: Dict[str, torch.Tensor]) -> Dict[str, StatItem]: @@ -33,3 +50,108 @@ def compute_stat(self, env_mat: Dict[str, torch.Tensor]) -> Dict[str, StatItem]: squared_sum=torch.square(vv).sum().item(), ) return stats + + +class EnvMatStatSeA(EnvMatStat): + """Environmental matrix statistics for the se_a environemntal matrix. + + Parameters + ---------- + descriptor : DescriptorBlock + The descriptor of the model. + """ + + def __init__(self, descriptor: "DescriptorBlock"): + super().__init__() + self.descriptor = descriptor + + def iter( + self, data: List[Dict[str, torch.Tensor]] + ) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, torch.Tensor]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ + zero_mean = torch.zeros( + self.descriptor.get_ntypes(), + self.descriptor.get_nsel(), + 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + one_stddev = torch.ones( + self.descriptor.get_ntypes(), + self.descriptor.get_nsel(), + 4, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + for system in data: + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.descriptor.get_rcut(), + self.descriptor.get_sel(), + distinguish_types=self.descriptor.distinguish_types(), + box=box, + ) + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + zero_mean, + one_stddev, + self.descriptor.get_rcut(), + # TODO: export rcut_smth from DescriptorBlock + self.descriptor.rcut_smth, + ) + env_mat = env_mat.view( + coord.shape[0], coord.shape[1], self.descriptor.get_nsel(), 4 + ) + env_mats = {} + + if "real_natoms_vec" in system: + end_indexes = torch.cumsum(natoms[0, 2:], 0) + start_indexes = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=env.DEVICE), + end_indexes[:-1], + ] + ) + for type_i in range(self.descriptor.get_ntypes()): + dd = env_mat[ + :, start_indexes[type_i] : end_indexes[type_i], :, : + ] # all descriptors for this element + env_mats[f"r_{type_i}"] = dd[:, :, :, :1] + env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] + yield self.compute_stat(env_mats) + else: + for frame_item in range(env_mat.shape[0]): + dd_ff = env_mat[frame_item] + atype_frame = atype[frame_item] + for type_i in range(self.descriptor.get_ntypes()): + type_idx = atype_frame == type_i + dd = dd_ff[type_idx] + dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 + env_mats[f"r_{type_i}"] = dd[:, :1] + env_mats[f"a_{type_i}"] = dd[:, 1:] + yield self.compute_stat(env_mats) From af0711c2eb7655e9161c4613f914e9ec6e393707 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 20:25:55 -0500 Subject: [PATCH 27/40] merge Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 42 ++++------------------ deepmd/pt/model/descriptor/se_atten.py | 39 ++------------------- deepmd/pt/utils/env_mat_stat.py | 48 ++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 72 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 121da97964..2c7ca9ab75 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -9,9 +9,6 @@ import numpy as np import torch -from deepmd.common import ( - get_hash, -) from deepmd.pt.model.descriptor import ( Descriptor, DescriptorBlock, @@ -375,43 +372,16 @@ def __getitem__(self, key): def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - if path is not None: - path = path / get_hash( - { - "type": "se_a", - "ntypes": self.ntypes, - "rcut": round(self.rcut, 2), - "rcut_smth": round(self.rcut_smth, 2), - "nsel": self.nsel, - "sel": self.get_sel(), - "distinguish_types": self.distinguish_types(), - } - ) env_mat_stat = EnvMatStatSeA(self) + if path is not None: + path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) - avgs = env_mat_stat.get_avg() - stds = env_mat_stat.get_std() - - all_davg = [] - all_dstd = [] - for type_i in range(self.ntypes): - davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - ] - ] - davg = np.tile(davgunit, [self.nnei, 1]) - dstd = np.tile(dstdunit, [self.nnei, 1]) - all_davg.append(davg) - all_dstd.append(dstd) + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) if not self.set_davg_zero: - mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - stddev = np.stack(all_dstd) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) def forward( diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 925aa6a42c..89e63497a1 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -7,9 +7,6 @@ import numpy as np import torch -from deepmd.common import ( - get_hash, -) from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, ) @@ -192,43 +189,13 @@ def dim_emb(self): def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - if path is not None: - path = path / get_hash( - { - "type": "se_a", - "ntypes": self.ntypes, - "rcut": round(self.rcut, 2), - "rcut_smth": round(self.rcut_smth, 2), - "nsel": self.nsel, - "sel": self.get_sel(), - "distinguish_types": self.distinguish_types(), - } - ) env_mat_stat = EnvMatStatSeA(self) + if path is not None: + path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) - avgs = env_mat_stat.get_avg() - stds = env_mat_stat.get_std() - - all_davg = [] - all_dstd = [] - for type_i in range(self.ntypes): - davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - ] - ] - davg = np.tile(davgunit, [self.nnei, 1]) - dstd = np.tile(dstdunit, [self.nnei, 1]) - all_davg.append(davg) - all_dstd.append(dstd) + mean, stddev = env_mat_stat() if not self.set_davg_zero: - mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - stddev = np.stack(all_dstd) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) def forward( diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index af0b3bb318..eea701d02d 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -6,8 +6,12 @@ List, ) +import numpy as np import torch +from deepmd.common import ( + get_hash, +) from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, ) @@ -155,3 +159,47 @@ def iter( env_mats[f"r_{type_i}"] = dd[:, :1] env_mats[f"a_{type_i}"] = dd[:, 1:] yield self.compute_stat(env_mats) + + def get_hash(self) -> str: + """Get the hash of the environment matrix. + + Returns + ------- + str + The hash of the environment matrix. + """ + return get_hash( + { + "type": "se_a", + "ntypes": self.descriptor.get_ntypes(), + "rcut": round(self.descriptor.get_rcut(), 2), + "rcut_smth": round(self.descriptor.rcut_smth, 2), + "nsel": self.descriptor.get_nsel(), + "sel": self.descriptor.get_sel(), + "distinguish_types": self.descriptor.distinguish_types(), + } + ) + + def __call__(self): + avgs = self.get_avg() + stds = self.get_std() + + all_davg = [] + all_dstd = [] + for type_i in range(self.descriptor.get_ntypes()): + davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] + dstdunit = [ + [ + stds[f"r_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + ] + ] + davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) + dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) + all_davg.append(davg) + all_dstd.append(dstd) + mean = np.stack(all_davg) + stddev = np.stack(all_dstd) + return mean, stddev From 37e9b28d26c35f8516178a70c198ac9269a53eac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 20:28:44 -0500 Subject: [PATCH 28/40] clean Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/dpa2.py | 18 +---- deepmd/pt/model/descriptor/repformers.py | 97 ++---------------------- 2 files changed, 8 insertions(+), 107 deletions(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 67eeaf31cc..7c9d9b868a 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -290,7 +290,6 @@ def dim_emb(self): return self.get_dim_emb() def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): - sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] for ii, descrpt in enumerate([self.repinit, self.repformers]): merged_tmp = [ { @@ -299,22 +298,7 @@ def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None) } for item in merged ] - tmp_stat_dict = descrpt.compute_input_stats(merged_tmp) - sumr.append(tmp_stat_dict["sumr"]) - suma.append(tmp_stat_dict["suma"]) - sumn.append(tmp_stat_dict["sumn"]) - sumr2.append(tmp_stat_dict["sumr2"]) - suma2.append(tmp_stat_dict["suma2"]) - - assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2]) - for ii, descrpt in enumerate([self.repinit, self.repformers]): - stat_dict_ii = { - "sumr": sumr[ii], - "suma": suma[ii], - "sumn": sumn[ii], - "sumr2": sumr2[ii], - "suma2": suma2[ii], - } + descrpt.compute_input_stats(merged_tmp) @classmethod def get_data_process_key(cls, config): diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 6f59f72d4e..8980ba3868 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -4,12 +4,10 @@ Optional, ) -import numpy as np import torch from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, - compute_std, ) from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, @@ -20,8 +18,8 @@ from deepmd.pt.utils import ( env, ) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, ) from deepmd.pt.utils.utils import ( get_activation_fn, @@ -33,9 +31,6 @@ from .repformer_layer import ( RepformerLayer, ) -from .se_atten import ( - analyze_descrpt, -) mydtype = env.GLOBAL_PT_FLOAT_PRECISION mydev = env.DEVICE @@ -273,89 +268,11 @@ def forward( def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - ndescrpt = self.nnei * 4 - sumr = [] - suma = [] - sumn = [] - sumr2 = [] - suma2 = [] - mixed_type = "real_natoms_vec" in merged[0] - for system in merged: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.get_sel(), - distinguish_types=self.distinguish_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat_se_a( - extended_coord, - nlist, - atype, - self.mean, - self.stddev, - self.rcut, - self.rcut_smth, - ) - if not mixed_type: - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), ndescrpt, natoms - ) - else: - real_natoms_vec = system["real_natoms_vec"] - sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( - env_mat.detach().cpu().numpy(), - ndescrpt, - real_natoms_vec, - mixed_type=mixed_type, - real_atype=atype.detach().cpu().numpy(), - ) - sumr.append(sysr) - suma.append(sysa) - sumn.append(sysn) - sumr2.append(sysr2) - suma2.append(sysa2) - sumr = np.sum(sumr, axis=0) - suma = np.sum(suma, axis=0) - sumn = np.sum(sumn, axis=0) - sumr2 = np.sum(sumr2, axis=0) - suma2 = np.sum(suma2, axis=0) - - all_davg = [] - all_dstd = [] - for type_i in range(self.ntypes): - davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] - dstdunit = [ - [ - compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), - ] - ] - davg = np.tile(davgunit, [self.nnei, 1]) - dstd = np.tile(dstdunit, [self.nnei, 1]) - all_davg.append(davg) - all_dstd.append(dstd) - self.sumr = sumr - self.suma = suma - self.sumn = sumn - self.sumr2 = sumr2 - self.suma2 = suma2 + env_mat_stat = EnvMatStatSeA(self) + if path is not None: + path = path / env_mat_stat.get_hash() + env_mat_stat.load_or_compute_stats(merged, path) + mean, stddev = env_mat_stat() if not self.set_davg_zero: - mean = np.stack(all_davg) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - stddev = np.stack(all_dstd) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) From 2c63335a0926714c4f1bcb21d757e6f4ce8c4ec5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 20:53:58 -0500 Subject: [PATCH 29/40] clean Signed-off-by: Jinzhe Zeng --- deepmd/pt/entrypoints/main.py | 10 +++++-- .../pt/model/atomic_model/dp_atomic_model.py | 25 ++++++++-------- deepmd/pt/model/descriptor/descriptor.py | 26 ----------------- deepmd/pt/model/model/model.py | 16 ++++++---- deepmd/pt/model/task/ener.py | 17 +++++++---- deepmd/pt/model/task/fitting.py | 29 ------------------- deepmd/pt/train/training.py | 3 +- 7 files changed, 43 insertions(+), 83 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 3c8b052570..d5b4a7fa68 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -12,6 +12,7 @@ Union, ) +import h5py import torch import torch.distributed as dist import torch.version @@ -127,8 +128,13 @@ def prepare_trainer_input_single( # noise_settings = None # stat files - # TODO: rewrite - stat_file_path_single = {"descriptor": "", "fitting": ""} + stat_file_path_single = data_dict_single.get("stat_file", None) + if ( + stat_file_path_single is not None + and not Path(stat_file_path_single).is_file() + ): + with h5py.File(stat_file_path_single, "w") as f: + pass # validation and training data validation_data_single = DpLoaderSet( diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index f627162880..9ca07922cb 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -6,7 +6,6 @@ Dict, List, Optional, - Union, ) import torch @@ -23,6 +22,9 @@ from deepmd.pt.utils.utils import ( dict_to_device, ) +from deepmd.utils.path import ( + DPPath, +) from .base_atomic_model import ( BaseAtomicModel, @@ -159,9 +161,8 @@ def forward_atomic( def compute_or_load_stat( self, - type_map: Optional[List[str]] = None, - sampled=None, - stat_file_path_dict: Optional[Dict[str, Union[str, List[str]]]] = None, + sampled, + stat_file_path: Optional[DPPath] = None, ): """ Compute or load the statistics parameters of the model, @@ -173,20 +174,18 @@ def compute_or_load_stat( Parameters ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. sampled The sampled data frames from different data systems. - stat_file_path_dict + stat_file_path The dictionary of paths to the statistics files. """ - if sampled is not None: # move data to device - for data_sys in sampled: - dict_to_device(data_sys) - self.descriptor.compute_or_load_stat(type_map, sampled, None) + for data_sys in sampled: + dict_to_device(data_sys) + if sampled is None: + sampled = [] + self.descriptor.compute_input_stats(sampled, stat_file_path) if self.fitting_net is not None: - self.fitting_net.compute_or_load_stat(type_map, sampled, None) + self.fitting_net.compute_output_stats(sampled, stat_file_path) @torch.jit.export def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 1427236523..2383e6a050 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -10,7 +10,6 @@ Optional, Set, Tuple, - Union, ) import numpy as np @@ -87,31 +86,6 @@ def data_stat_key(self): """ raise NotImplementedError("data_stat_key is not implemented!") - def compute_or_load_stat( - self, - type_map: List[str], - sampled=None, - stat_file_path: Optional[Union[str, List[str]]] = None, - ): - """ - Compute or load the statistics parameters of the descriptor. - Calculate and save the mean and standard deviation of the descriptor to `stat_file_path` - if `sampled` is not None, otherwise load them from `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - sampled - The sampled data frames from different data systems. - stat_file_path - The path to the statistics files. - """ - # TODO - assert sampled is not None - tmp_dict = self.compute_input_stats(sampled, None) - def __new__(cls, *args, **kwargs): if cls is Descriptor: try: diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index 51c5fcf123..d98d25d539 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -1,6 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + import torch +from deepmd.utils.path import ( + DPPath, +) + class BaseModel(torch.nn.Module): def __init__(self): @@ -9,9 +17,8 @@ def __init__(self): def compute_or_load_stat( self, - type_map=None, - sampled=None, - stat_file_path=None, + sampled, + stat_file_path: Optional[DPPath] = None, ): """ Compute or load the statistics parameters of the model, @@ -23,9 +30,6 @@ def compute_or_load_stat( Parameters ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. sampled The sampled data frames from different data systems. stat_file_path diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 4ab6ea79a6..bde5dc69c0 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -30,6 +30,9 @@ from deepmd.pt.utils.stat import ( compute_output_bias, ) +from deepmd.utils.path import ( + DPPath, +) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE @@ -150,17 +153,21 @@ def data_stat_key(self): """ return ["bias_atom_e"] - def compute_output_stats(self, merged): + def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None): energy = [item["energy"] for item in merged] mixed_type = "real_natoms_vec" in merged[0] if mixed_type: input_natoms = [item["real_natoms_vec"] for item in merged] else: input_natoms = [item["natoms"] for item in merged] - bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond) - return {"bias_atom_e": bias_atom_e} - - def init_fitting_stat(self, bias_atom_e=None, **kwargs): + if stat_file_path is not None: + stat_file_path = stat_file_path / "bias_atom_e" + if stat_file_path is not None and stat_file_path.is_file(): + bias_atom_e = stat_file_path.load_numpy() + else: + bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond) + if stat_file_path is not None: + stat_file_path.save_numpy(bias_atom_e) assert all(x is not None for x in [bias_atom_e]) self.bias_atom_e.copy_( torch.tensor(bias_atom_e, device=env.DEVICE).view( diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 863d278d17..2ec8c0f70d 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -8,7 +8,6 @@ Callable, List, Optional, - Union, ) import numpy as np @@ -132,34 +131,6 @@ def data_stat_key(self): """ raise NotImplementedError("data_stat_key is not implemented!") - def compute_or_load_stat( - self, - type_map: List[str], - sampled=None, - stat_file_path: Optional[Union[str, List[str]]] = None, - ): - """ - Compute or load the statistics parameters of the fitting net. - Calculate and save the output bias to `stat_file_path` - if `sampled` is not None, otherwise load them from `stat_file_path`. - - Parameters - ---------- - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - sampled - The sampled data frames from different data systems. - stat_file_path - The path to the statistics files. - """ - fitting_stat_key = self.data_stat_key - assert sampled is not None - tmp_dict = self.compute_output_stats(sampled) - result_dict = {key: tmp_dict[key] for key in fitting_stat_key} - result_dict["type_map"] = type_map - self.init_fitting_stat(**result_dict) - def change_energy_bias( self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10 ): diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index b2cac5a5eb..8537be6e12 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -185,9 +185,8 @@ def get_single_model(_model_params, _sampled, _stat_file_path): model = get_model(deepcopy(_model_params)).to(DEVICE) if not model_params.get("resuming", False): model.compute_or_load_stat( - type_map=_model_params["type_map"], sampled=_sampled, - stat_file_path_dict=_stat_file_path, + stat_file_path=_stat_file_path, ) return model From d343c326b6c03ebc7fb092939ca7878da0251906 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 21:23:51 -0500 Subject: [PATCH 30/40] fix load stats Signed-off-by: Jinzhe Zeng --- deepmd/pt/entrypoints/main.py | 18 ++++-- deepmd/utils/env_mat_stat.py | 12 ++-- deepmd/utils/path.py | 112 ++++++++++++++++++++++++++-------- 3 files changed, 103 insertions(+), 39 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index d5b4a7fa68..b260000d87 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -65,6 +65,9 @@ from deepmd.pt.utils.stat import ( make_stat_input, ) +from deepmd.utils.path import ( + DPPath, +) from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter log = logging.getLogger(__name__) @@ -129,12 +132,15 @@ def prepare_trainer_input_single( # stat files stat_file_path_single = data_dict_single.get("stat_file", None) - if ( - stat_file_path_single is not None - and not Path(stat_file_path_single).is_file() - ): - with h5py.File(stat_file_path_single, "w") as f: - pass + if stat_file_path_single is not None: + if Path(stat_file_path_single).is_dir(): + raise ValueError( + f"stat_file should be a file, not a directory: {stat_file_path_single}" + ) + if not Path(stat_file_path_single).is_file(): + with h5py.File(stat_file_path_single, "w") as f: + pass + stat_file_path_single = DPPath(stat_file_path_single, "a") # validation and training data validation_data_single = DpLoaderSet( diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 36603b6d97..2fa497b9b6 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -135,9 +135,8 @@ def save_stats(self, path: DPPath) -> None: if len(self.stats) == 0: raise ValueError("The statistics hasn't been computed.") for kk, vv in self.stats.items(): - (path / kk / "number").save(vv.number) - (path / kk / "sum").save(vv.sum) - (path / kk / "squared_sum").save(vv.squared_sum) + path.mkdir(parents=True, exist_ok=True) + (path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum])) def load_stats(self, path: DPPath) -> None: """Load the statistics of the environment matrix. @@ -150,10 +149,11 @@ def load_stats(self, path: DPPath) -> None: if len(self.stats) > 0: raise ValueError("The statistics has already been computed.") for kk in path.glob("*"): + arr = kk.load_numpy() self.stats[kk.name] = StatItem( - number=(kk / "number").load_numpy().item(), - sum=(kk / "sum").load_numpy().item(), - squared_sum=(kk / "squared_sum").load_numpy().item(), + number=arr[0], + sum=arr[1], + squared_sum=arr[2], ) def load_or_compute_stats( diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 4579795cf4..c9a7cd8554 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -29,9 +29,11 @@ class DPPath(ABC): ---------- path : str path + mode : str, optional + mode, by default "r" """ - def __new__(cls, path: str): + def __new__(cls, path: str, mode: str = "r"): if cls is DPPath: if os.path.isdir(path): return super().__new__(DPOSPath) @@ -137,6 +139,18 @@ def __hash__(self): def name(self) -> str: """Name of the path.""" + @abstractmethod + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: + """Make directory. + + Parameters + ---------- + parents : bool, optional + If true, any missing parents of this directory are created as well. + exist_ok : bool, optional + If true, no error will be raised if the target directory already exists. + """ + class DPOSPath(DPPath): """The OS path class to data system (DeepmdData) for real directories. @@ -145,10 +159,13 @@ class DPOSPath(DPPath): ---------- path : str path + mode : str, optional + mode, by default "r" """ - def __init__(self, path: str) -> None: + def __init__(self, path: str, mode: str = "r") -> None: super().__init__() + self.mode = mode if isinstance(path, Path): self.path = path else: @@ -182,6 +199,8 @@ def save_numpy(self, arr: np.ndarray) -> None: arr : np.ndarray NumPy array """ + if self.mode == "r": + raise ValueError("Cannot save to read-only path") np.save(str(self.path), arr) def glob(self, pattern: str) -> List["DPPath"]: @@ -199,7 +218,7 @@ def glob(self, pattern: str) -> List["DPPath"]: """ # currently DPOSPath will only derivative DPOSPath # TODO: discuss if we want to mix DPOSPath and DPH5Path? - return [type(self)(p) for p in self.path.glob(pattern)] + return [type(self)(p, mode=self.mode) for p in self.path.glob(pattern)] def rglob(self, pattern: str) -> List["DPPath"]: """This is like calling :meth:`DPPath.glob()` with `**/` added in front @@ -215,7 +234,7 @@ def rglob(self, pattern: str) -> List["DPPath"]: List[DPPath] list of paths """ - return [type(self)(p) for p in self.path.rglob(pattern)] + return [type(self)(p, mode=self.mode) for p in self.path.rglob(pattern)] def is_file(self) -> bool: """Check if self is file.""" @@ -227,7 +246,7 @@ def is_dir(self) -> bool: def __truediv__(self, key: str) -> "DPPath": """Used for / operator.""" - return type(self)(self.path / key) + return type(self)(self.path / key, mode=self.mode) def __lt__(self, other: "DPOSPath") -> bool: """Whether this DPPath is less than other for sorting.""" @@ -242,6 +261,20 @@ def name(self) -> str: """Name of the path.""" return self.path.name + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: + """Make directory. + + Parameters + ---------- + parents : bool, optional + If true, any missing parents of this directory are created as well. + exist_ok : bool, optional + If true, no error will be raised if the target directory already exists. + """ + if self.mode == "r": + raise ValueError("Cannot mkdir to read-only path") + self.path.mkdir(parents=parents, exist_ok=exist_ok) + class DPH5Path(DPPath): """The path class to data system (DeepmdData) for HDF5 files. @@ -256,32 +289,37 @@ class DPH5Path(DPPath): ---------- path : str path + mode : str, optional + mode, by default "r" """ - def __init__(self, path: str) -> None: + def __init__(self, path: str, mode: str = "r") -> None: super().__init__() + self.mode = mode # we use "#" to split path # so we do not support file names containing #... s = path.split("#") self.root_path = s[0] - self.root = self._load_h5py(s[0]) + self.root = self._load_h5py(s[0], mode) # h5 path: default is the root path - self.name = s[1] if len(s) > 1 else "/" + self._name = s[1] if len(s) > 1 else "/" @classmethod @lru_cache(None) - def _load_h5py(cls, path: str) -> h5py.File: + def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File: """Load hdf5 file. Parameters ---------- path : str path to hdf5 file + mode : str, optional + mode, by default 'r' """ # this method has cache to avoid duplicated # loading from different DPH5Path # However the file will be never closed? - return h5py.File(path, "r") + return h5py.File(path, mode) def load_numpy(self) -> np.ndarray: """Load NumPy array. @@ -291,7 +329,7 @@ def load_numpy(self) -> np.ndarray: np.ndarray loaded NumPy array """ - return self.root[self.name][:] + return self.root[self._name][:] def load_txt(self, dtype: Optional[np.dtype] = None, **kwargs) -> np.ndarray: """Load NumPy array from text. @@ -314,9 +352,9 @@ def save_numpy(self, arr: np.ndarray) -> None: arr : np.ndarray NumPy array """ - if self.name in self._keys: - del self.root[self.name] - self.root.create_dataset(self.name, data=arr) + if self._name in self._keys: + del self.root[self._name] + self.root.create_dataset(self._name, data=arr) def glob(self, pattern: str) -> List["DPPath"]: """Search path using the glob pattern. @@ -332,9 +370,9 @@ def glob(self, pattern: str) -> List["DPPath"]: list of paths """ # got paths starts with current path first, which is faster - subpaths = [ii for ii in self._keys if ii.startswith(self.name)] + subpaths = [ii for ii in self._keys if ii.startswith(self._name)] return [ - type(self)(f"{self.root_path}#{pp}") + type(self)(f"{self.root_path}#{pp}", mode=self.mode) for pp in globfilter(subpaths, self._connect_path(pattern)) ] @@ -369,36 +407,56 @@ def _file_keys(cls, file: h5py.File) -> List[str]: def is_file(self) -> bool: """Check if self is file.""" - if self.name not in self._keys: + if self._name not in self._keys: return False - return isinstance(self.root[self.name], h5py.Dataset) + return isinstance(self.root[self._name], h5py.Dataset) def is_dir(self) -> bool: """Check if self is directory.""" - if self.name not in self._keys: + if self._name not in self._keys: return False - return isinstance(self.root[self.name], h5py.Group) + return isinstance(self.root[self._name], h5py.Group) def __truediv__(self, key: str) -> "DPPath": """Used for / operator.""" - return type(self)(f"{self.root_path}#{self._connect_path(key)}") + return type(self)(f"{self.root_path}#{self._connect_path(key)}", mode=self.mode) def _connect_path(self, path: str) -> str: """Connect self with path.""" - if self.name.endswith("/"): - return f"{self.name}{path}" - return f"{self.name}/{path}" + if self._name.endswith("/"): + return f"{self._name}{path}" + return f"{self._name}/{path}" def __lt__(self, other: "DPH5Path") -> bool: """Whether this DPPath is less than other for sorting.""" if self.root_path == other.root_path: - return self.name < other.name + return self._name < other._name return self.root_path < other.root_path def __str__(self) -> str: """Returns path of self.""" - return f"{self.root_path}#{self.name}" + return f"{self.root_path}#{self._name}" + @property def name(self) -> str: """Name of the path.""" - return self.name.split("/")[-1] + return self._name.split("/")[-1] + + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: + """Make directory. + + Parameters + ---------- + parents : bool, optional + If true, any missing parents of this directory are created as well. + exist_ok : bool, optional + If true, no error will be raised if the target directory already exists. + """ + if self._name in self._keys: + if not exist_ok: + raise FileExistsError(f"{self} already exists") + return + if parents: + self.root.require_group(self._name) + else: + self.root.create_group(self._name) From 6d8955a8012f723a2688c1e113bd9c2e2b9376a8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 21:26:55 -0500 Subject: [PATCH 31/40] rm process_stat_path Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/stat.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 56875d37fa..051fddd14b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -76,10 +76,3 @@ def compute_output_bias(energy, natoms, rcond=None): sys_tynatom = torch.cat(natoms)[:, 2:].cpu() energy_coef, _, _, _ = np.linalg.lstsq(sys_tynatom, sys_ener, rcond) return energy_coef - - -def process_stat_path( - stat_file_dict, stat_file_dir, model_params_dict, descriptor_cls, fitting_cls -): - # TODO: to rewrite - return From 582451e08af08cc540f3d1116d1b9636c180858c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 21:38:13 -0500 Subject: [PATCH 32/40] fix py38 compatibility Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/make_base_descriptor.py | 2 +- deepmd/dpmodel/descriptor/se_e2_a.py | 2 +- deepmd/pt/model/descriptor/descriptor.py | 2 +- deepmd/pt/model/descriptor/dpa1.py | 2 +- deepmd/pt/model/descriptor/dpa2.py | 2 +- deepmd/pt/model/descriptor/gaussian_lcc.py | 3 ++- deepmd/pt/model/descriptor/hybrid.py | 2 +- deepmd/pt/model/descriptor/repformers.py | 2 +- deepmd/pt/model/descriptor/se_a.py | 4 ++-- deepmd/pt/model/descriptor/se_atten.py | 2 +- 10 files changed, 12 insertions(+), 11 deletions(-) diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 4de4a7f139..b7a8bfebcf 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -74,7 +74,7 @@ def distinguish_types(self) -> bool: pass def compute_input_stats( - self, merged: list[dict], path: Optional[DPPath] = None + self, merged: List[dict], path: Optional[DPPath] = None ): """Update mean and stddev for descriptor elements.""" raise NotImplementedError diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 6ecbea9e70..f692b44356 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -235,7 +235,7 @@ def get_ntypes(self) -> int: """Returns the number of element types.""" return self.ntypes - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" raise NotImplementedError diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 2383e6a050..501693ddc2 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -203,7 +203,7 @@ def get_dim_emb(self) -> int: """Returns the embedding dimension.""" pass - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for DescriptorBlock elements.""" raise NotImplementedError diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index f115d5a95f..6bdb5c2cb3 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -122,7 +122,7 @@ def dim_out(self): def dim_emb(self): return self.get_dim_emb() - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): return self.se_atten.compute_input_stats(merged, path) @classmethod diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 7c9d9b868a..0122dcacb8 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -289,7 +289,7 @@ def dim_emb(self): """Returns the embedding dimension g2.""" return self.get_dim_emb() - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): for ii, descrpt in enumerate([self.repinit, self.repformers]): merged_tmp = [ { diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py index 8243d32ac9..72c9f27b2a 100644 --- a/deepmd/pt/model/descriptor/gaussian_lcc.py +++ b/deepmd/pt/model/descriptor/gaussian_lcc.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + List, Optional, ) @@ -161,7 +162,7 @@ def dim_emb(self): """Returns the output dimension of pair representation.""" return self.pair_embed_dim - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" pass diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index c0748b9b76..a9fad0d9e8 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -153,7 +153,7 @@ def share_params(self, base_class, shared_level, resume=False): else: raise NotImplementedError - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" for ii, descrpt in enumerate(self.descriptor_list): merged_tmp = [ diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 8980ba3868..4a0bc6bdfb 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -266,7 +266,7 @@ def forward( return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" env_mat_stat = EnvMatStatSeA(self) if path is not None: diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 2c7ca9ab75..9bf3c20284 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -113,7 +113,7 @@ def dim_out(self): """Returns the output dimension of this descriptor.""" return self.sea.dim_out - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" return self.sea.compute_input_stats(merged, path) @@ -370,7 +370,7 @@ def __getitem__(self, key): else: raise KeyError(key) - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" env_mat_stat = EnvMatStatSeA(self) if path is not None: diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 89e63497a1..5a12f6b9ad 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -187,7 +187,7 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.get_dim_emb() - def compute_input_stats(self, merged: list[dict], path: Optional[DPPath] = None): + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" env_mat_stat = EnvMatStatSeA(self) if path is not None: From 1c1c1a5c0277599ea945f1116ea7e7e29c0be465 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 16 Feb 2024 22:00:44 -0500 Subject: [PATCH 33/40] fix typo Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/env_mat_stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index eea701d02d..4e37f40b2c 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -133,7 +133,7 @@ def iter( ) env_mats = {} - if "real_natoms_vec" in system: + if "real_natoms_vec" not in system: end_indexes = torch.cumsum(natoms[0, 2:], 0) start_indexes = torch.cat( [ From b47014bb32473d72a55c38bb63625e5debc5381a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 17 Feb 2024 17:39:04 -0500 Subject: [PATCH 34/40] rm unused compute_std Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/__init__.py | 2 -- deepmd/pt/model/descriptor/descriptor.py | 10 ---------- 2 files changed, 12 deletions(-) diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 4252e34905..1c2e943369 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -2,7 +2,6 @@ from .descriptor import ( Descriptor, DescriptorBlock, - compute_std, make_default_type_embedding, ) from .dpa1 import ( @@ -32,7 +31,6 @@ __all__ = [ "Descriptor", "DescriptorBlock", - "compute_std", "make_default_type_embedding", "DescrptBlockSeA", "DescrptBlockSeAtten", diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 501693ddc2..fb732248e3 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -305,16 +305,6 @@ def build_type_exclude_mask( return mask -def compute_std(sumv2, sumv, sumn, rcut_r): - """Compute standard deviation.""" - if sumn == 0: - return 1.0 / rcut_r - val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn)) - if np.abs(val) < 1e-2: - val = 1e-2 - return val - - def make_default_type_embedding( ntypes, ): From ef5a92e0c1949718f3ff4777dc7bfa25d23eadfe Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 17 Feb 2024 18:01:27 -0500 Subject: [PATCH 35/40] bugfix Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/env_mat_stat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 4e37f40b2c..5247ce08ba 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -131,7 +131,6 @@ def iter( env_mat = env_mat.view( coord.shape[0], coord.shape[1], self.descriptor.get_nsel(), 4 ) - env_mats = {} if "real_natoms_vec" not in system: end_indexes = torch.cumsum(natoms[0, 2:], 0) @@ -145,6 +144,7 @@ def iter( dd = env_mat[ :, start_indexes[type_i] : end_indexes[type_i], :, : ] # all descriptors for this element + env_mats = {} env_mats[f"r_{type_i}"] = dd[:, :, :, :1] env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] yield self.compute_stat(env_mats) @@ -156,6 +156,7 @@ def iter( type_idx = atype_frame == type_i dd = dd_ff[type_idx] dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 + env_mats = {} env_mats[f"r_{type_i}"] = dd[:, :1] env_mats[f"a_{type_i}"] = dd[:, 1:] yield self.compute_stat(env_mats) From a8aef1821dfd28eab3b457ea167ac62488aea6f5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 17 Feb 2024 20:09:05 -0500 Subject: [PATCH 36/40] make test work Signed-off-by: Jinzhe Zeng --- source/tests/pt/test_stat.py | 140 ++++++++++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 20 deletions(-) diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py index 6af17a694d..12af6ba866 100644 --- a/source/tests/pt/test_stat.py +++ b/source/tests/pt/test_stat.py @@ -2,16 +2,24 @@ import json import os import unittest +from abc import ( + ABC, + abstractmethod, +) from pathlib import ( Path, ) +import dpdata import numpy as np import torch from deepmd.pt.model.descriptor import ( DescrptSeA, ) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) from deepmd.pt.utils import ( env, ) @@ -26,6 +34,7 @@ expand_sys_str, ) from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeA_tf +from deepmd.tf.descriptor.se_atten import DescrptSeAtten as DescrptSeAtten_tf from deepmd.tf.fit.ener import ( EnerFitting, ) @@ -50,12 +59,29 @@ def compare(ut, base, given): ut.assertEqual(base, given) -class TestDataset(unittest.TestCase): +class DatasetTest(ABC): + @abstractmethod + def setup_data(self): + pass + + @abstractmethod + def setup_tf(self): + pass + + @abstractmethod + def setup_pt(self): + pass + + @abstractmethod + def tf_compute_input_stats(self): + pass + def setUp(self): with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: content = fin.read() config = json.loads(content) - data_file = [str(Path(__file__).parent / "water/data/data_0")] + data_file = [self.setup_data()] + config["training"]["training_data"]["systems"] = data_file config["training"]["validation_data"]["systems"] = data_file model_config = config["model"] @@ -97,13 +123,7 @@ def setUp(self): self.dp_sampled = dp_make(dp_dataset, self.data_stat_nbatch, False) self.dp_merged = dp_merge(self.dp_sampled) self.dp_mesh = self.dp_merged.pop("default_mesh") - self.dp_d = DescrptSeA_tf( - rcut=self.rcut, - rcut_smth=self.rcut_smth, - sel=self.sel, - neuron=self.filter_neuron, - axis_neuron=self.axis_neuron, - ) + self.dp_d = self.setup_tf() def test_stat_output(self): def my_merge(energy, natoms): @@ -147,16 +167,9 @@ def test_stat_input(self): """ def test_descriptor(self): - coord = self.dp_merged["coord"] - atype = self.dp_merged["type"] - natoms = self.dp_merged["natoms_vec"] - box = self.dp_merged["box"] - self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) + self.tf_compute_input_stats() - my_en = DescrptSeA( - self.rcut, self.rcut_smth, self.sel, self.filter_neuron, self.axis_neuron - ) - my_en = my_en.sea # get the block who has stat as private vars + my_en = self.setup_pt() sampled = self.my_sampled for sys in sampled: for key in [ @@ -173,12 +186,99 @@ def test_descriptor(self): my_en.mean = my_en.mean my_en.stddev = my_en.stddev np.testing.assert_allclose( - self.dp_d.davg.reshape([-1]), my_en.mean.cpu().reshape([-1]), rtol=0.01 + self.dp_d.davg.reshape([-1]), + my_en.mean.cpu().reshape([-1]), + rtol=1e-14, + atol=1e-14, ) np.testing.assert_allclose( self.dp_d.dstd.reshape([-1]), my_en.stddev.cpu().reshape([-1]), - rtol=0.01, + rtol=1e-14, + atol=1e-14, + ) + + +class TestDatasetNoMixed(DatasetTest, unittest.TestCase): + def setup_data(self): + original_data = str(Path(__file__).parent / "water/data/data_0") + picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") + dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy( + picked_data + ) + self.mixed_type = False + return picked_data + + def setup_tf(self): + return DescrptSeA_tf( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + ) + + def setup_pt(self): + return DescrptSeA( + self.rcut, self.rcut_smth, self.sel, self.filter_neuron, self.axis_neuron + ).sea # get the block who has stat as private vars + + def tf_compute_input_stats(self): + coord = self.dp_merged["coord"] + atype = self.dp_merged["type"] + natoms = self.dp_merged["natoms_vec"] + box = self.dp_merged["box"] + self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) + + +class TestDatasetMixed(DatasetTest, unittest.TestCase): + def setup_data(self): + original_data = str(Path(__file__).parent / "water/data/data_0") + picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") + dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy_mixed( + picked_data + ) + self.mixed_type = True + return picked_data + + def setup_tf(self): + return DescrptSeAtten_tf( + ntypes=2, + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=sum(self.sel), + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + set_davg_zero=False, + ) + + def setup_pt(self): + return DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + 2, + self.filter_neuron, + self.axis_neuron, + set_davg_zero=False, + ).se_atten + + def tf_compute_input_stats(self): + coord = self.dp_merged["coord"] + atype = self.dp_merged["type"] + natoms = self.dp_merged["natoms_vec"] + box = self.dp_merged["box"] + real_natoms_vec = self.dp_merged["real_natoms_vec"] + + self.dp_d.compute_input_stats( + coord, + box, + atype, + natoms, + self.dp_mesh, + {}, + mixed_type=True, + real_natoms_vec=real_natoms_vec, ) From 6a83465c7447a5aa159bbb892cc4b2b08e4dc843 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 02:51:50 -0500 Subject: [PATCH 37/40] add type_map to stat_file_path Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/atomic_model/dp_atomic_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 9ca07922cb..aafd2831b3 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -179,6 +179,10 @@ def compute_or_load_stat( stat_file_path The dictionary of paths to the statistics files. """ + if stat_file_path is not None and self.type_map is not None: + # descriptors and fitting net with different type_map + # should not share the same parameters + stat_file_path /= " ".join(self.type_map) for data_sys in sampled: dict_to_device(data_sys) if sampled is None: From ba688a9892f4c9bb9786e8ec4bc5a90be6cd649a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 03:21:18 -0500 Subject: [PATCH 38/40] update share_params Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/descriptor.py | 41 ++++++++++++------------ deepmd/pt/model/descriptor/repformers.py | 13 ++++++++ deepmd/pt/model/descriptor/se_a.py | 13 ++++++++ deepmd/pt/model/descriptor/se_atten.py | 13 ++++++++ 4 files changed, 59 insertions(+), 21 deletions(-) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 8a2cb5096f..6044d46c6b 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -15,9 +15,18 @@ from deepmd.pt.model.network.network import ( TypeEmbedNet, ) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeA, +) from deepmd.pt.utils.plugin import ( Plugin, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) from deepmd.utils.path import ( DPPath, ) @@ -175,6 +184,10 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for DescriptorBlock elements.""" raise NotImplementedError + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + raise NotImplementedError + def share_params(self, base_class, shared_level, resume=False): assert ( self.__class__ == base_class.__class__ @@ -183,27 +196,13 @@ def share_params(self, base_class, shared_level, resume=False): # link buffers if hasattr(self, "mean") and not resume: # in case of change params during resume - sumr_base, suma_base, sumn_base, sumr2_base, suma2_base = ( - base_class.sumr, - base_class.suma, - base_class.sumn, - base_class.sumr2, - base_class.suma2, - ) - sumr, suma, sumn, sumr2, suma2 = ( - self.sumr, - self.suma, - self.sumn, - self.sumr2, - self.suma2, - ) - stat_dict = { - "sumr": sumr_base + sumr, - "suma": suma_base + suma, - "sumn": sumn_base + sumn, - "sumr2": sumr2_base + sumr2, - "suma2": suma2_base + suma2, - } + base_env = EnvMatStatSeA(base_class) + for kk in base_class.get_stats(): + base_env.stats[kk] += self.get_stats()[kk] + mean, stddev = base_env() + if not base_class.set_davg_zero: + base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) self.mean = base_class.mean self.stddev = base_class.stddev # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index d89574a5eb..c88d066c43 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -24,6 +24,9 @@ from deepmd.pt.utils.utils import ( get_activation_fn, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) from deepmd.utils.path import ( DPPath, ) @@ -147,6 +150,7 @@ def __init__( stddev = torch.ones(sshape, dtype=mydtype, device=mydev) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + self.stats = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -272,7 +276,16 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) if path is not None: path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index ac742bde5e..8b317ec480 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -24,6 +24,9 @@ from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSeA, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) from deepmd.utils.path import ( DPPath, ) @@ -313,6 +316,7 @@ def __init__( resnet_dt=self.resnet_dt, ) self.filter_layers = filter_layers + self.stats = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -380,6 +384,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) if path is not None: path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) @@ -388,6 +393,14 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + def forward( self, nlist: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 5fa01f56fc..70b15ee22f 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -23,6 +23,9 @@ from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSeA, ) +from deepmd.utils.env_mat_stat import ( + StatItem, +) from deepmd.utils.path import ( DPPath, ) @@ -137,6 +140,7 @@ def __init__( ) filter_layers.append(one) self.filter_layers = torch.nn.ModuleList(filter_layers) + self.stats = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -193,11 +197,20 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) if path is not None: path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() if not self.set_davg_zero: self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + def forward( self, nlist: torch.Tensor, From 7c9a66e544a371cc33dae7a03e4e7a7ebf66d950 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 04:30:08 -0500 Subject: [PATCH 39/40] base_env starts from base_class.stats Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/descriptor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 6044d46c6b..c8cdf1aaf1 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -197,6 +197,7 @@ def share_params(self, base_class, shared_level, resume=False): if hasattr(self, "mean") and not resume: # in case of change params during resume base_env = EnvMatStatSeA(base_class) + base_env.stats = base_class.stats for kk in base_class.get_stats(): base_env.stats[kk] += self.get_stats()[kk] mean, stddev = base_env() From 2ad6990b8cf7584a03767e803a2be0fd9347d354 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 04:35:02 -0500 Subject: [PATCH 40/40] fix py38 compatibility Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/descriptor.py | 3 ++- deepmd/pt/model/descriptor/repformers.py | 3 ++- deepmd/pt/model/descriptor/se_a.py | 3 ++- deepmd/pt/model/descriptor/se_atten.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index c8cdf1aaf1..16659e444d 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -6,6 +6,7 @@ ) from typing import ( Callable, + Dict, List, Optional, ) @@ -184,7 +185,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) """Update mean and stddev for DescriptorBlock elements.""" raise NotImplementedError - def get_stats(self) -> dict[str, StatItem]: + def get_stats(self) -> Dict[str, StatItem]: """Get the statistics of the descriptor.""" raise NotImplementedError diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index c88d066c43..76051a52ed 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Dict, List, Optional, ) @@ -282,7 +283,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) - def get_stats(self) -> dict[str, StatItem]: + def get_stats(self) -> Dict[str, StatItem]: """Get the statistics of the descriptor.""" if self.stats is None: raise RuntimeError( diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 8b317ec480..33cc3ee9e2 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( ClassVar, + Dict, List, Optional, Tuple, @@ -393,7 +394,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) - def get_stats(self) -> dict[str, StatItem]: + def get_stats(self) -> Dict[str, StatItem]: """Get the statistics of the descriptor.""" if self.stats is None: raise RuntimeError( diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 70b15ee22f..410a2039aa 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Dict, List, Optional, ) @@ -203,7 +204,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) - def get_stats(self) -> dict[str, StatItem]: + def get_stats(self) -> Dict[str, StatItem]: """Get the statistics of the descriptor.""" if self.stats is None: raise RuntimeError(