Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: refact training code #3359

Merged
merged 47 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
3812866
Fix single-task training&data stat
iProzd Feb 28, 2024
08e18fe
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
ae27607
Fix EnergyFittingNetDirect
iProzd Feb 28, 2024
7f573ab
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
f9265d5
Add data_requirement for dataloader
iProzd Feb 28, 2024
f8d2980
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
c9eb767
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2024
00105c7
Update make_base_descriptor.py
iProzd Feb 28, 2024
5a9df83
Update typing
iProzd Feb 28, 2024
75da5b1
Update training.py
iProzd Feb 28, 2024
6c171c5
Fix uts
iProzd Feb 28, 2024
2e87e1d
Fix uts
iProzd Feb 28, 2024
eb8094d
Merge branch 'devel' into train_rf
iProzd Feb 28, 2024
2618d98
Support multi-task training
iProzd Feb 28, 2024
f1585b2
Take advice from QL scan
iProzd Feb 28, 2024
463f9fb
Support no validation
iProzd Feb 28, 2024
e8575af
Update se_r.py
iProzd Feb 28, 2024
66d03b8
omit data prob log
iProzd Feb 28, 2024
e9e0d95
omit seed log
iProzd Feb 28, 2024
90be50e
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
ab35653
Add fparam and aparam
iProzd Feb 29, 2024
64d6079
Add type hint for `Callable`
iProzd Feb 29, 2024
6020a2b
Fix nopbc
iProzd Feb 29, 2024
5db7883
Add DataRequirementItem
iProzd Feb 29, 2024
c03a5ba
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
cce52da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
18cbf9e
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
cdcfcb2
Fix neighbor-stat for multitask (#31)
iProzd Feb 29, 2024
a7d44d1
Revert "Fix neighbor-stat for multitask (#31)"
iProzd Feb 29, 2024
fdca653
Move label requirement to loss func
iProzd Feb 29, 2024
525ce93
resolve conversations
iProzd Feb 29, 2024
46ee16c
set label_requirement abstractmethod
iProzd Feb 29, 2024
9d18dc4
make label_requirement dynamic
iProzd Feb 29, 2024
ad7227d
update docs
iProzd Feb 29, 2024
35598d2
replace lazy with functools.lru_cache
iProzd Feb 29, 2024
c0a0cfc
Update training.py
iProzd Feb 29, 2024
d50e2a2
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
66edca5
Update deepmd/pt/train/training.py
wanghan-iapcm Feb 29, 2024
d5a1549
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
c51f865
Merge branch 'devel' into train_rf
iProzd Feb 29, 2024
e17546a
Update test_multitask.py
iProzd Feb 29, 2024
1debf4f
Fix h5py files in multitask DDP
iProzd Feb 29, 2024
db31edc
FIx h5py file read block
iProzd Feb 29, 2024
60dda49
Merge branch 'devel' into train_rf
iProzd Mar 1, 2024
3dfc31e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
615446f
Update hybrid.py
iProzd Mar 1, 2024
e26c118
Update hybrid.py
iProzd Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Union,
)

from deepmd.common import (
Expand Down Expand Up @@ -84,8 +86,19 @@
"""
pass

@abstractmethod
def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
pass

Check warning on line 96 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L96

Added line #L96 was not covered by tests

def compute_input_stats(
self, merged: List[dict], path: Optional[DPPath] = None
self,
merged: Union[Callable[[], List[dict]], List[dict]],
iProzd marked this conversation as resolved.
Show resolved Hide resolved
path: Optional[DPPath] = None,
):
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
raise NotImplementedError

Check warning on line 252 in deepmd/dpmodel/descriptor/se_e2_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_e2_a.py#L252

Added line #L252 was not covered by tests

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@
"""
return False

def share_params(self, base_class, shared_level, resume=False):
"""
Share the parameters of self to the base_class with shared_level during multitask training.
If not start from checkpoint (resume is False),
some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes.
"""
raise NotImplementedError

Check warning on line 212 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L212

Added line #L212 was not covered by tests

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
Expand Down
58 changes: 13 additions & 45 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.utils.argcheck import (
normalize,
)
Expand Down Expand Up @@ -97,7 +94,6 @@ def get_trainer(
multi_task=multi_task,
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
shared_links = None
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])
Expand All @@ -109,26 +105,11 @@ def prepare_trainer_input_single(
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single["validation_data"]
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
validation_systems = validation_dataset_params["systems"]

# noise params
noise_settings = None
if loss_dict_single.get("type", "ener") == "denoise":
noise_settings = {
"noise_type": loss_dict_single.pop("noise_type", "uniform"),
"noise": loss_dict_single.pop("noise", 1.0),
"noise_mode": loss_dict_single.pop("noise_mode", "fix_num"),
"mask_num": loss_dict_single.pop("mask_num", 8),
"mask_prob": loss_dict_single.pop("mask_prob", 0.15),
"same_mask": loss_dict_single.pop("same_mask", False),
"mask_coord": loss_dict_single.pop("mask_coord", False),
"mask_type": loss_dict_single.pop("mask_type", False),
"max_fail_num": loss_dict_single.pop("max_fail_num", 10),
"mask_type_idx": len(model_params_single["type_map"]) - 1,
}
# noise_settings = None

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
Expand All @@ -143,59 +124,47 @@ def prepare_trainer_input_single(
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
validation_data_single = DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
validation_data_single = (
DpLoaderSet(
validation_systems,
validation_dataset_params["batch_size"],
model_params_single,
)
if validation_systems
else None
)
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
sampled_single = None
else:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10)
sampled_single = make_stat_input(
train_data_single.systems,
train_data_single.dataloaders,
data_stat_nbatch,
)
if noise_settings is not None:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single,
)
return (
train_data_single,
validation_data_single,
sampled_single,
stat_file_path_single,
)

if not multi_task:
(
train_data,
validation_data,
sampled,
stat_file_path,
) = prepare_trainer_input_single(
config["model"], config["training"], config["loss"]
)
else:
train_data, validation_data, sampled, stat_file_path = {}, {}, {}, {}
train_data, validation_data, stat_file_path = {}, {}, {}
for model_key in config["model"]["model_dict"]:
(
train_data[model_key],
validation_data[model_key],
sampled[model_key],
stat_file_path[model_key],
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
Expand All @@ -207,7 +176,6 @@ def prepare_trainer_input_single(
trainer = training.Trainer(
config,
train_data,
sampled=sampled,
stat_file_path=stat_file_path,
validation_data=validation_data,
init_model=init_model,
Expand Down
107 changes: 106 additions & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch
import torch.nn.functional as F

Expand All @@ -11,6 +15,9 @@
from deepmd.pt.utils.env import (
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class EnergyStdLoss(TaskLoss):
Expand All @@ -23,16 +30,57 @@
limit_pref_f=0.0,
start_pref_v=0.0,
limit_pref_v=0.0,
start_pref_ae: float = 0.0,
limit_pref_ae: float = 0.0,
start_pref_pf: float = 0.0,
limit_pref_pf: float = 0.0,
use_l1_all: bool = False,
inference=False,
**kwargs,
):
"""Construct a layer to compute loss on energy, force and virial."""
r"""Construct a layer to compute loss on energy, force and virial.

Parameters
----------
starter_learning_rate : float
The learning rate at the start of the training.
start_pref_e : float
The prefactor of energy loss at the start of the training.
limit_pref_e : float
The prefactor of energy loss at the end of the training.
start_pref_f : float
The prefactor of force loss at the start of the training.
limit_pref_f : float
The prefactor of force loss at the end of the training.
start_pref_v : float
The prefactor of virial loss at the start of the training.
limit_pref_v : float
The prefactor of virial loss at the end of the training.
start_pref_ae : float
The prefactor of atomic energy loss at the start of the training.
limit_pref_ae : float
The prefactor of atomic energy loss at the end of the training.
start_pref_pf : float
The prefactor of atomic prefactor force loss at the start of the training.
limit_pref_pf : float
The prefactor of atomic prefactor force loss at the end of the training.
use_l1_all : bool
Whether to use L1 loss, if False (default), it will use L2 loss.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference

# TODO need support for atomic energy and atomic pref
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
self.start_pref_f = start_pref_f
Expand Down Expand Up @@ -153,3 +201,60 @@
if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
iProzd marked this conversation as resolved.
Show resolved Hide resolved
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_e:
label_requirement.append(
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
)
)
if self.has_f:
label_requirement.append(
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_v:
label_requirement.append(

Check warning on line 230 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L230

Added line #L230 was not covered by tests
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
)
)
if self.has_ae:
label_requirement.append(

Check warning on line 240 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L240

Added line #L240 was not covered by tests
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_pf:
label_requirement.append(

Check warning on line 250 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L250

Added line #L250 was not covered by tests
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
)
)
return label_requirement
20 changes: 19 additions & 1 deletion deepmd/pt/loss/loss.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
List,
)

import torch

from deepmd.utils.data import (
DataRequirementItem,
)


class TaskLoss(torch.nn.Module):
class TaskLoss(torch.nn.Module, ABC):
def __init__(self, **kwargs):
"""Construct loss."""
super().__init__()

def forward(self, model_pred, label, natoms, learning_rate):
"""Return loss ."""
raise NotImplementedError

@property
@abstractmethod
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
pass

Check warning on line 30 in deepmd/pt/loss/loss.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/loss.py#L30

Added line #L30 was not covered by tests
Loading