Skip to content

Commit

Permalink
pt: refactor data stat (#3285)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 18, 2024
1 parent 9053caf commit c55ceac
Show file tree
Hide file tree
Showing 23 changed files with 879 additions and 825 deletions.
14 changes: 14 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import platform
import shutil
import warnings
from hashlib import (
sha1,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -299,3 +302,14 @@ def symlink_prefix_files(old_prefix: str, new_prefix: str):
os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
else:
shutil.copyfile(ori_ff, new_ff)


def get_hash(obj) -> str:
"""Get hash of object.
Parameters
----------
obj
object to hash
"""
return sha1(json.dumps(obj).encode("utf-8")).hexdigest()
12 changes: 7 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
Optional,
)

from deepmd.utils.path import (
DPPath,
)


def make_base_descriptor(
t_tensor,
Expand Down Expand Up @@ -69,14 +73,12 @@ def distinguish_types(self) -> bool:
"""
pass

def compute_input_stats(self, merged):
def compute_input_stats(
self, merged: List[dict], path: Optional[DPPath] = None
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def init_desc_stat(self, **kwargs):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

@abstractmethod
def fwd(
self,
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.utils.path import (
DPPath,
)

try:
from deepmd._version import version as __version__
except ImportError:
Expand Down Expand Up @@ -229,14 +233,10 @@ def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes

def compute_input_stats(self, merged):
def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

def cal_g(
self,
ss,
Expand Down
36 changes: 15 additions & 21 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Union,
)

import h5py
import torch
import torch.distributed as dist
import torch.version
Expand Down Expand Up @@ -46,12 +47,6 @@
from deepmd.pt.infer import (
inference,
)
from deepmd.pt.model.descriptor import (
Descriptor,
)
from deepmd.pt.model.task import (
Fitting,
)
from deepmd.pt.train import (
training,
)
Expand All @@ -69,7 +64,9 @@
)
from deepmd.pt.utils.stat import (
make_stat_input,
process_stat_path,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

Expand Down Expand Up @@ -134,19 +131,16 @@ def prepare_trainer_input_single(
# noise_settings = None

# stat files
hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
if not hybrid_descrpt:
stat_file_path_single, has_stat_file_path = process_stat_path(
data_dict_single.get("stat_file", None),
data_dict_single.get("stat_file_dir", f"stat_files{suffix}"),
model_params_single,
Descriptor,
Fitting,
)
else: ### TODO hybrid descriptor not implemented
raise NotImplementedError(
"data stat for hybrid descriptor is not implemented!"
)
stat_file_path_single = data_dict_single.get("stat_file", None)
if stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(
f"stat_file should be a file, not a directory: {stat_file_path_single}"
)
if not Path(stat_file_path_single).is_file():
with h5py.File(stat_file_path_single, "w") as f:
pass
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
validation_data_single = DpLoaderSet(
Expand All @@ -156,7 +150,7 @@ def prepare_trainer_input_single(
type_split=type_split,
noise_settings=noise_settings,
)
if ckpt or finetune_model or has_stat_file_path:
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
Expand Down
41 changes: 16 additions & 25 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import os
import sys
from typing import (
Dict,
List,
Optional,
Union,
)

import torch
Expand All @@ -24,6 +22,9 @@
from deepmd.pt.utils.utils import (
dict_to_device,
)
from deepmd.utils.path import (
DPPath,
)

from .base_atomic_model import (
BaseAtomicModel,
Expand Down Expand Up @@ -160,9 +161,8 @@ def forward_atomic(

def compute_or_load_stat(
self,
type_map: Optional[List[str]] = None,
sampled=None,
stat_file_path_dict: Optional[Dict[str, Union[str, List[str]]]] = None,
sampled,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
Expand All @@ -174,31 +174,22 @@ def compute_or_load_stat(
Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path_dict
stat_file_path
The dictionary of paths to the statistics files.
"""
if sampled is not None: # move data to device
for data_sys in sampled:
dict_to_device(data_sys)
if stat_file_path_dict is not None:
if not isinstance(stat_file_path_dict["descriptor"], list):
stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"])
else:
stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"][0])
if not os.path.exists(stat_file_dir):
os.mkdir(stat_file_dir)
self.descriptor.compute_or_load_stat(
type_map, sampled, stat_file_path_dict["descriptor"]
)
if stat_file_path is not None and self.type_map is not None:
# descriptors and fitting net with different type_map
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)
for data_sys in sampled:
dict_to_device(data_sys)
if sampled is None:
sampled = []
self.descriptor.compute_input_stats(sampled, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_or_load_stat(
type_map, sampled, stat_file_path_dict["fitting_net"]
)
self.fitting_net.compute_output_stats(sampled, stat_file_path)

@torch.jit.export
def get_dim_fparam(self) -> int:
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .descriptor import (
Descriptor,
DescriptorBlock,
compute_std,
make_default_type_embedding,
)
from .dpa1 import (
Expand Down Expand Up @@ -32,7 +31,6 @@
__all__ = [
"Descriptor",
"DescriptorBlock",
"compute_std",
"make_default_type_embedding",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
Expand Down
Loading

0 comments on commit c55ceac

Please sign in to comment.