Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: refactor data stat #3285

Merged
merged 42 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f352a67
checkpoint
njzjz Feb 16, 2024
b2b48b2
Merge branch 'devel' into env-mat-stat
njzjz Feb 16, 2024
1c3e2bb
record sum
njzjz Feb 16, 2024
d6bf4ab
compute_std
njzjz Feb 16, 2024
4f029b0
protection
njzjz Feb 16, 2024
afd6d6a
save stat
njzjz Feb 16, 2024
2f67aac
get std
njzjz Feb 16, 2024
89d413c
compute avg
njzjz Feb 16, 2024
8b6e24e
sea looks good
njzjz Feb 16, 2024
4736ae5
rm init_desc_stat
njzjz Feb 16, 2024
6cb1275
rm get_stat_name
njzjz Feb 16, 2024
b99d330
rewrite compute_input_stats
njzjz Feb 16, 2024
da1e72d
hybrid
njzjz Feb 16, 2024
959583a
fix hash
njzjz Feb 16, 2024
524366d
se atten
njzjz Feb 16, 2024
87f0d85
to make it work
njzjz Feb 16, 2024
2e815ea
compute_or_load_stat
njzjz Feb 16, 2024
ed34d59
init
njzjz Feb 16, 2024
d04d16c
fix shape
njzjz Feb 16, 2024
2580f8e
fix concat
njzjz Feb 16, 2024
2771b1e
make it work
njzjz Feb 16, 2024
3eb5577
rm save_stats and load_stats
njzjz Feb 16, 2024
3ad5484
assert_allclose
njzjz Feb 17, 2024
2b9bbd8
fix shape
njzjz Feb 17, 2024
7d40b9f
add env mat type
njzjz Feb 17, 2024
a55e21f
remove print
njzjz Feb 17, 2024
c34622d
merge methods
njzjz Feb 17, 2024
af0711c
merge
njzjz Feb 17, 2024
37e9b28
clean
njzjz Feb 17, 2024
2c63335
clean
njzjz Feb 17, 2024
d343c32
fix load stats
njzjz Feb 17, 2024
6d8955a
rm process_stat_path
njzjz Feb 17, 2024
582451e
fix py38 compatibility
njzjz Feb 17, 2024
1c1c1a5
fix typo
njzjz Feb 17, 2024
b47014b
rm unused compute_std
njzjz Feb 17, 2024
ef5a92e
bugfix
njzjz Feb 17, 2024
367f472
Merge branch 'devel' into env-mat-stat
njzjz Feb 17, 2024
a8aef18
make test work
njzjz Feb 18, 2024
6a83465
add type_map to stat_file_path
njzjz Feb 18, 2024
ba688a9
update share_params
njzjz Feb 18, 2024
7c9a66e
base_env starts from base_class.stats
njzjz Feb 18, 2024
2ad6990
fix py38 compatibility
njzjz Feb 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import platform
import shutil
import warnings
from hashlib import (
sha1,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -299,3 +302,14 @@
os.symlink(os.path.relpath(ori_ff, os.path.dirname(new_ff)), new_ff)
else:
shutil.copyfile(ori_ff, new_ff)


def get_hash(obj) -> str:
"""Get hash of object.

Parameters
----------
obj
object to hash
"""
return sha1(json.dumps(obj).encode("utf-8")).hexdigest()

Check warning on line 315 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L315

Added line #L315 was not covered by tests
12 changes: 7 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
Optional,
)

from deepmd.utils.path import (
DPPath,
)


def make_base_descriptor(
t_tensor,
Expand Down Expand Up @@ -69,14 +73,12 @@ def distinguish_types(self) -> bool:
"""
pass

def compute_input_stats(self, merged):
def compute_input_stats(
self, merged: List[dict], path: Optional[DPPath] = None
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def init_desc_stat(self, **kwargs):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

@abstractmethod
def fwd(
self,
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.utils.path import (
DPPath,
)

try:
from deepmd._version import version as __version__
except ImportError:
Expand Down Expand Up @@ -229,14 +233,10 @@ def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes

def compute_input_stats(self, merged):
def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
"""Initialize the model bias by the statistics."""
raise NotImplementedError

def cal_g(
self,
ss,
Expand Down
36 changes: 15 additions & 21 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Union,
)

import h5py
import torch
import torch.distributed as dist
import torch.version
Expand Down Expand Up @@ -46,12 +47,6 @@
from deepmd.pt.infer import (
inference,
)
from deepmd.pt.model.descriptor import (
Descriptor,
)
from deepmd.pt.model.task import (
Fitting,
)
from deepmd.pt.train import (
training,
)
Expand All @@ -69,7 +64,9 @@
)
from deepmd.pt.utils.stat import (
make_stat_input,
process_stat_path,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

Expand Down Expand Up @@ -134,19 +131,16 @@
# noise_settings = None

# stat files
hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
if not hybrid_descrpt:
stat_file_path_single, has_stat_file_path = process_stat_path(
data_dict_single.get("stat_file", None),
data_dict_single.get("stat_file_dir", f"stat_files{suffix}"),
model_params_single,
Descriptor,
Fitting,
)
else: ### TODO hybrid descriptor not implemented
raise NotImplementedError(
"data stat for hybrid descriptor is not implemented!"
)
stat_file_path_single = data_dict_single.get("stat_file", None)
if stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(

Check warning on line 137 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L136-L137

Added lines #L136 - L137 were not covered by tests
f"stat_file should be a file, not a directory: {stat_file_path_single}"
)
if not Path(stat_file_path_single).is_file():
with h5py.File(stat_file_path_single, "w") as f:
pass
stat_file_path_single = DPPath(stat_file_path_single, "a")

Check warning on line 143 in deepmd/pt/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L140-L143

Added lines #L140 - L143 were not covered by tests

# validation and training data
validation_data_single = DpLoaderSet(
Expand All @@ -156,7 +150,7 @@
type_split=type_split,
noise_settings=noise_settings,
)
if ckpt or finetune_model or has_stat_file_path:
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
Expand Down
37 changes: 12 additions & 25 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import os
import sys
from typing import (
Dict,
List,
Optional,
Union,
)

import torch
Expand All @@ -24,6 +22,9 @@
from deepmd.pt.utils.utils import (
dict_to_device,
)
from deepmd.utils.path import (
DPPath,
)

from .base_atomic_model import (
BaseAtomicModel,
Expand Down Expand Up @@ -160,9 +161,8 @@

def compute_or_load_stat(
self,
type_map: Optional[List[str]] = None,
sampled=None,
stat_file_path_dict: Optional[Dict[str, Union[str, List[str]]]] = None,
sampled,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
Expand All @@ -174,31 +174,18 @@

Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path_dict
stat_file_path
The dictionary of paths to the statistics files.
"""
if sampled is not None: # move data to device
for data_sys in sampled:
dict_to_device(data_sys)
if stat_file_path_dict is not None:
if not isinstance(stat_file_path_dict["descriptor"], list):
stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"])
else:
stat_file_dir = os.path.dirname(stat_file_path_dict["descriptor"][0])
if not os.path.exists(stat_file_dir):
os.mkdir(stat_file_dir)
self.descriptor.compute_or_load_stat(
type_map, sampled, stat_file_path_dict["descriptor"]
)
for data_sys in sampled:
dict_to_device(data_sys)
if sampled is None:
sampled = []

Check warning on line 185 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L185

Added line #L185 was not covered by tests
self.descriptor.compute_input_stats(sampled, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_or_load_stat(
type_map, sampled, stat_file_path_dict["fitting_net"]
)
self.fitting_net.compute_output_stats(sampled, stat_file_path)

@torch.jit.export
def get_dim_fparam(self) -> int:
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .descriptor import (
Descriptor,
DescriptorBlock,
compute_std,
make_default_type_embedding,
)
from .dpa1 import (
Expand Down Expand Up @@ -32,7 +31,6 @@
__all__ = [
"Descriptor",
"DescriptorBlock",
"compute_std",
"make_default_type_embedding",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
Expand Down
Loading