Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add DOSnet training in PT #3486

Merged
merged 31 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f1a3a0d
feat: add dos training
anyangml Mar 18, 2024
d4a3965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
3e1f3c6
fix: precommit
anyangml Mar 18, 2024
c4769ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
429f03f
feat: add dos stat
anyangml Mar 19, 2024
91d2a8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
04c7477
fix: training test
anyangml Mar 19, 2024
4d95548
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
d71a117
Merge branch 'devel' into feat/dos-train
anyangml Mar 19, 2024
bc54b68
fix: precommit
anyangml Mar 19, 2024
850ea1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
004cdb2
fix: UTs
anyangml Mar 19, 2024
a116235
fix: UTs
anyangml Mar 19, 2024
dbd2d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
2ff0ae6
fix: stat
anyangml Mar 20, 2024
b7b69fd
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
366f6b4
fix: stat
anyangml Mar 20, 2024
915141c
fix: dp test
anyangml Mar 20, 2024
ed65e19
fix: test examples
anyangml Mar 20, 2024
a73d392
fix UTs
anyangml Mar 20, 2024
bf8fac2
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
3b5be19
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
1630800
fix: add to test_examples
Mar 20, 2024
be076af
fix: update loss
Mar 20, 2024
36d1674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
c156b02
fix: add numb_dos to jit model
anyangml Mar 20, 2024
148c196
chore: refactor
anyangml Mar 20, 2024
3ad66fb
fix: UTs
anyangml Mar 20, 2024
b751921
Merge branch 'devel' into feat/dos-train
anyangml Mar 22, 2024
2e14755
fix: loss
anyangml Mar 24, 2024
3c5e6f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@

def get_numb_dos(self) -> int:
"""Get the number of DOS."""
return 0
return self.dp.model["Default"].get_numb_dos()

Check warning on line 197 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L197

Added line #L197 was not covered by tests

def get_has_efield(self):
"""Check if the model has efield."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .denoise import (
DenoiseLoss,
)
from .dos import (
DOSLoss,
)
from .ener import (
EnergyStdLoss,
)
Expand All @@ -21,4 +24,5 @@
"EnergySpinLoss",
"TensorLoss",
"TaskLoss",
"DOSLoss",
]
256 changes: 256 additions & 0 deletions deepmd/pt/loss/dos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class DOSLoss(TaskLoss):
def __init__(
self,
starter_learning_rate: float,
numb_dos: int,
start_pref_dos: float = 1.00,
limit_pref_dos: float = 1.00,
start_pref_cdf: float = 1000,
limit_pref_cdf: float = 1.00,
start_pref_ados: float = 0.0,
limit_pref_ados: float = 0.0,
start_pref_acdf: float = 0.0,
limit_pref_acdf: float = 0.0,
inference=False,
**kwargs,
):
r"""Construct a loss for local and global tensors.

Parameters
----------
tensor_name : str
The name of the tensor in the model predictions to compute the loss.
tensor_size : int
The size (dimension) of the tensor.
label_name : str
The name of the tensor in the labels to compute the loss.
pref_atomic : float
The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
pref : float
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.numb_dos = numb_dos
self.inference = inference

self.start_pref_dos = start_pref_dos
self.limit_pref_dos = limit_pref_dos
self.start_pref_cdf = start_pref_cdf
self.limit_pref_cdf = limit_pref_cdf

self.start_pref_ados = start_pref_ados
self.limit_pref_ados = limit_pref_ados
self.start_pref_acdf = start_pref_acdf
self.limit_pref_acdf = limit_pref_acdf

assert (
self.start_pref_dos >= 0.0
and self.limit_pref_dos >= 0.0
and self.start_pref_cdf >= 0.0
and self.limit_pref_cdf >= 0.0
and self.start_pref_ados >= 0.0
and self.limit_pref_ados >= 0.0
and self.start_pref_acdf >= 0.0
and self.limit_pref_acdf >= 0.0
), "Can not assign negative weight to `pref` and `pref_atomic`"

self.has_dos = (start_pref_dos != 0.0 and limit_pref_dos != 0.0) or inference
self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference
self.has_ados = (start_pref_ados != 0.0 and limit_pref_ados != 0.0) or inference
self.has_acdf = (start_pref_acdf != 0.0 and limit_pref_acdf != 0.0) or inference

assert (
self.has_dos or self.has_cdf or self.has_ados or self.has_acdf
), AssertionError("Can not assian zero weight both to `pref` and `pref_atomic`")

def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.

Parameters
----------
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)

coef = learning_rate / self.starter_learning_rate
pref_dos = (
Fixed Show fixed Hide fixed
self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
)
pref_cdf = (
Fixed Show fixed Hide fixed
self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
)
pref_ados = (
self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
)
pref_acdf = (
self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_ados and "atom_dos" in model_pred and "atom_dos" in label:
find_local = label.get("find_atom_dos", 0.0)
pref_ados = pref_ados * find_local
local_tensor_pred_dos = model_pred["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
local_tensor_label_dos = label["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
diff = (local_tensor_pred_dos - local_tensor_label_dos).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_dos_loss"] = self.display_if_exist(
l2_local_loss_dos.detach(), find_local
)
loss += pref_ados * l2_local_loss_dos
rmse_local_dos = l2_local_loss_dos.sqrt()
more_loss["rmse_local_dos"] = self.display_if_exist(
rmse_local_dos.detach(), find_local
)
if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label:
find_local = label.get("find_atom_dos", 0.0)
pref_acdf = pref_acdf * find_local
local_tensor_pred_cdf = torch.cusum(

Check warning on line 157 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L155-L157

Added lines #L155 - L157 were not covered by tests
model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
local_tensor_label_cdf = torch.cusum(

Check warning on line 160 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L160

Added line #L160 was not covered by tests
label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(

Check warning on line 163 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L163

Added line #L163 was not covered by tests
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_cdf_loss"] = self.display_if_exist(

Check warning on line 170 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L166-L170

Added lines #L166 - L170 were not covered by tests
l2_local_loss_cdf.detach(), find_local
)
loss += pref_acdf * l2_local_loss_cdf
rmse_local_cdf = l2_local_loss_cdf.sqrt()
more_loss["rmse_local_cdf"] = self.display_if_exist(

Check warning on line 175 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L173-L175

Added lines #L173 - L175 were not covered by tests
rmse_local_cdf.detach(), find_local
)
if self.has_dos and "dos" in model_pred and "dos" in label:
find_global = label.get("find_dos", 0.0)
pref_dos = pref_dos * find_global
global_tensor_pred_dos = model_pred["dos"].reshape([-1, self.numb_dos])
global_tensor_label_dos = label["dos"].reshape([-1, self.numb_dos])
diff = global_tensor_pred_dos - global_tensor_label_dos
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_dos = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_dos = torch.mean(torch.square(diff))

Check warning on line 192 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L191-L192

Added lines #L191 - L192 were not covered by tests
if not self.inference:
more_loss["l2_global_dos_loss"] = self.display_if_exist(
l2_global_loss_dos.detach(), find_global
)
loss += pref_dos * l2_global_loss_dos
rmse_global_dos = l2_global_loss_dos.sqrt() / atom_num
more_loss["rmse_global_dos"] = self.display_if_exist(
rmse_global_dos.detach(), find_global
)
if self.has_cdf and "dos" in model_pred and "dos" in label:
find_global = label.get("find_dos", 0.0)
pref_cdf = pref_cdf * find_global
global_tensor_pred_cdf = torch.cusum(

Check warning on line 205 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L203-L205

Added lines #L203 - L205 were not covered by tests
model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1
)
global_tensor_label_cdf = torch.cusum(

Check warning on line 208 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L208

Added line #L208 was not covered by tests
label["dos"].reshape([-1, self.numb_dos]), dim=-1
)
diff = global_tensor_pred_cdf - global_tensor_label_cdf
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_cdf = torch.mean(

Check warning on line 214 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L211-L214

Added lines #L211 - L214 were not covered by tests
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())

Check warning on line 217 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L217

Added line #L217 was not covered by tests
else:
atom_num = natoms
l2_global_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_cdf_loss"] = self.display_if_exist(

Check warning on line 222 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L219-L222

Added lines #L219 - L222 were not covered by tests
l2_global_loss_cdf.detach(), find_global
)
loss += pref_cdf * l2_global_loss_cdf
rmse_global_dos = l2_global_loss_cdf.sqrt() / atom_num
more_loss["rmse_global_cdf"] = self.display_if_exist(

Check warning on line 227 in deepmd/pt/loss/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/dos.py#L225-L227

Added lines #L225 - L227 were not covered by tests
rmse_global_dos.detach(), find_global
)
return model_pred, loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_ados or self.has_acdf:
label_requirement.append(
DataRequirementItem(
"atom_dos",
ndof=self.numb_dos,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_dos or self.has_cdf:
label_requirement.append(
DataRequirementItem(
"dos",
ndof=self.numb_dos,
atomic=False,
must=False,
high_prec=False,
)
)
return label_requirement
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def get_numb_dos(self) -> int:
"""Get the number of DOS for DOSFittingNet."""
return self.get_fitting_net().dim_out

@torch.jit.export
def forward_lower(
self,
Expand Down
66 changes: 66 additions & 0 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel import (
Expand All @@ -28,6 +30,13 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -96,6 +105,63 @@
]
)

def compute_output_stats(
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. dos bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_dos"

Check warning on line 130 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L130

Added line #L130 was not covered by tests
if stat_file_path is not None and stat_file_path.is_file():
bias_dos = stat_file_path.load_numpy()

Check warning on line 132 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L132

Added line #L132 was not covered by tests
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged

Check warning on line 138 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L138

Added line #L138 was not covered by tests
for sys in range(len(sampled)):
nframs = sampled[sys]["atype"].shape[0]

if "atom_dos" in sampled[sys]:
bias_dos = compute_stats_from_atomic(
sampled[sys]["atom_dos"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
sys_type_count = np.zeros(

Check warning on line 148 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L148

Added line #L148 was not covered by tests
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(

Check warning on line 153 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L151-L153

Added lines #L151 - L153 were not covered by tests
force=True
)
sys_bias_redu = sampled[sys]["dos"].numpy(force=True)

Check warning on line 156 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L156

Added line #L156 was not covered by tests

bias_dos = compute_stats_from_redu(

Check warning on line 158 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L158

Added line #L158 was not covered by tests
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
if stat_file_path is not None:
stat_file_path.save_numpy(bias_dos)

Check warning on line 162 in deepmd/pt/model/task/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/dos.py#L162

Added line #L162 was not covered by tests
self.bias_dos = torch.tensor(bias_dos, device=env.DEVICE)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
Expand Down
Loading