
Feat: Add DOSnet training in PT #3486

Merged · 31 commits · Mar 25, 2024
Note: the diff below shows the changes from 2 of the 31 commits.

Commits (31)
f1a3a0d
feat: add dos training
anyangml Mar 18, 2024
d4a3965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
3e1f3c6
fix: precommit
anyangml Mar 18, 2024
c4769ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
429f03f
feat: add dos stat
anyangml Mar 19, 2024
91d2a8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
04c7477
fix: training test
anyangml Mar 19, 2024
4d95548
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
d71a117
Merge branch 'devel' into feat/dos-train
anyangml Mar 19, 2024
bc54b68
fix: precommit
anyangml Mar 19, 2024
850ea1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
004cdb2
fix: UTs
anyangml Mar 19, 2024
a116235
fix: UTs
anyangml Mar 19, 2024
dbd2d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
2ff0ae6
fix: stat
anyangml Mar 20, 2024
b7b69fd
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
366f6b4
fix: stat
anyangml Mar 20, 2024
915141c
fix: dp test
anyangml Mar 20, 2024
ed65e19
fix: test examples
anyangml Mar 20, 2024
a73d392
fix UTs
anyangml Mar 20, 2024
bf8fac2
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
3b5be19
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
1630800
fix: add to test_examples
Mar 20, 2024
be076af
fix: update loss
Mar 20, 2024
36d1674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
c156b02
fix: add numb_dos to jit model
anyangml Mar 20, 2024
148c196
chore: refactor
anyangml Mar 20, 2024
3ad66fb
fix: UTs
anyangml Mar 20, 2024
b751921
Merge branch 'devel' into feat/dos-train
anyangml Mar 22, 2024
2e14755
fix: loss
anyangml Mar 24, 2024
3c5e6f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
@@ -2,6 +2,9 @@
from .denoise import (
DenoiseLoss,
)
from .dos import (
DOSLoss,
)
from .ener import (
EnergyStdLoss,
)
@@ -21,4 +24,5 @@
"EnergySpinLoss",
"TensorLoss",
"TaskLoss",
"DOSLoss",
]
228 changes: 228 additions & 0 deletions deepmd/pt/loss/dos.py
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class DOSLoss(TaskLoss):
def __init__(
self,
starter_learning_rate: float,
numb_dos: int,
start_pref_dos: float = 1.00,
limit_pref_dos: float = 1.00,
start_pref_cdf: float = 1000,
limit_pref_cdf: float = 1.00,
start_pref_ados: float = 0.0,
limit_pref_ados: float = 0.0,
start_pref_acdf: float = 0.0,
limit_pref_acdf: float = 0.0,
inference=False,
**kwargs,
):
r"""Construct a loss for local and global tensors.

Parameters
----------
tensor_name : str
The name of the tensor in the model predictions to compute the loss.
tensor_size : int
The size (dimension) of the tensor.
label_name : str
The name of the tensor in the labels to compute the loss.
pref_atomic : float
The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
pref : float
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.numb_dos = numb_dos
self.inference = inference

self.start_pref_dos = start_pref_dos
self.limit_pref_dos = limit_pref_dos
self.start_pref_cdf = start_pref_cdf
self.limit_pref_cdf = limit_pref_cdf

self.start_pref_ados = start_pref_ados
self.limit_pref_ados = limit_pref_ados
self.start_pref_acdf = start_pref_acdf
self.limit_pref_acdf = limit_pref_acdf

assert (
self.start_pref_dos >= 0.0
and self.limit_pref_dos >= 0.0
and self.start_pref_cdf >= 0.0
and self.limit_pref_cdf >= 0.0
and self.start_pref_ados >= 0.0
and self.limit_pref_ados >= 0.0
and self.start_pref_acdf >= 0.0
and self.limit_pref_acdf >= 0.0
), "Can not assign negative weight to `pref` and `pref_atomic`"

self.has_dos = (start_pref_dos != 0.0 and limit_pref_dos != 0.0) or inference
self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference
self.has_ados = (start_pref_ados != 0.0 and limit_pref_ados != 0.0) or inference
self.has_acdf = (start_pref_acdf != 0.0 and limit_pref_acdf != 0.0) or inference
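# Note: each loss term above is active only when both its start and limit
# prefactors are nonzero; in inference mode every available term is reported
# regardless of the prefactors.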

assert (
self.has_dos or self.has_cdf or self.has_ados or self.has_acdf
), "At least one of the DOS loss prefactors must be nonzero"

def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.

Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
coef = learning_rate / self.starter_learning_rate
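# The prefactors are interpolated linearly in the current learning rate:
# pref = limit_pref + (start_pref - limit_pref) * lr / start_lr,
# so each weight moves from its start value toward its limit value as the
# learning rate decays.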
pref_dos = (
self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
)
pref_cdf = (
self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
)
pref_ados = (
self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
)
pref_acdf = (
self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_ados and "dos" in model_pred and "atom_dos" in label:
local_tensor_pred_dos = model_pred["dos"].reshape(
[-1, natoms, self.numb_dos]
)
local_tensor_label_dos = label["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
diff = (local_tensor_pred_dos - local_tensor_label_dos).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_dos_loss"] = l2_local_loss_dos.detach()
loss += pref_ados * l2_local_loss_dos
rmse_local_dos = l2_local_loss_dos.sqrt()
more_loss["rmse_local_dos"] = rmse_local_dos.detach()
if self.has_acdf and "dos" in model_pred and "atom_dos" in label:
local_tensor_pred_cdf = torch.cumsum(
model_pred["dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
local_tensor_label_cdf = torch.cumsum(
label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_cdf_loss"] = l2_local_loss_cdf.detach()
loss += pref_acdf * l2_local_loss_cdf
rmse_local_cdf = l2_local_loss_cdf.sqrt()
more_loss["rmse_local_cdf"] = rmse_local_cdf.detach()
if self.has_dos and "global_dos" in model_pred and "dos" in label:
global_tensor_pred_dos = model_pred["global_dos"].reshape(
[-1, self.numb_dos]
)
global_tensor_label_dos = label["dos"].reshape([-1, self.numb_dos])
diff = global_tensor_pred_dos - global_tensor_label_dos
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_dos = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_dos_loss"] = l2_global_loss_dos.detach()
loss += pref_dos * l2_global_loss_dos
rmse_global_dos = l2_global_loss_dos.sqrt() / atom_num
more_loss["rmse_global_dos"] = rmse_global_dos.detach()
if self.has_cdf and "global_dos" in model_pred and "dos" in label:
global_tensor_pred_cdf = torch.cumsum(
model_pred["global_dos"].reshape([-1, self.numb_dos]), dim=-1
)
global_tensor_label_cdf = torch.cumsum(
label["dos"].reshape([-1, self.numb_dos]), dim=-1
)
diff = global_tensor_pred_cdf - global_tensor_label_cdf
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_cdf = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_cdf_loss"] = l2_global_loss_cdf.detach()
loss += pref_cdf * l2_global_loss_cdf
rmse_global_cdf = l2_global_loss_cdf.sqrt() / atom_num
more_loss["rmse_global_cdf"] = rmse_global_cdf.detach()
return loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_ados or self.has_acdf:
label_requirement.append(
DataRequirementItem(
"atom_dos",
ndof=self.numb_dos,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_dos or self.has_cdf:
label_requirement.append(
DataRequirementItem(
"dos",
ndof=self.numb_dos,
atomic=False,
must=False,
high_prec=False,
)
)
return label_requirement
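As a quick sanity check of the new loss, here is a minimal usage sketch (not part of the PR): it constructs DOSLoss directly and evaluates it on random tensors whose dict keys and shapes follow the forward() implementation above. The frame count, atom count, numb_dos, and learning rate are arbitrary illustrative values, and the tensors are placed on env.DEVICE to match the loss accumulator.

import torch

from deepmd.pt.loss import DOSLoss
from deepmd.pt.utils import env

nframes, nat, numb_dos = 2, 4, 10  # arbitrary sizes for illustration
loss_fn = DOSLoss(
    starter_learning_rate=1e-3,
    numb_dos=numb_dos,
    start_pref_dos=1.0,
    limit_pref_dos=1.0,
    start_pref_ados=1.0,
    limit_pref_ados=1.0,
)
model_pred = {
    "dos": torch.rand(nframes, nat, numb_dos, device=env.DEVICE),  # per-atom DOS prediction
    "global_dos": torch.rand(nframes, numb_dos, device=env.DEVICE),  # frame-level DOS prediction
}
label = {
    "atom_dos": torch.rand(nframes, nat, numb_dos, device=env.DEVICE),  # per-atom DOS labels
    "dos": torch.rand(nframes, numb_dos, device=env.DEVICE),  # global DOS labels
}
loss, more_loss = loss_fn.forward(model_pred, label, nat, learning_rate=1e-3)
print(loss, sorted(more_loss))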
4 changes: 3 additions & 1 deletion deepmd/pt/train/training.py
@@ -25,6 +25,7 @@
)
from deepmd.pt.loss import (
DenoiseLoss,
DOSLoss,
EnergySpinLoss,
EnergyStdLoss,
TensorLoss,
@@ -276,7 +277,8 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
return EnergyStdLoss(**loss_params)
elif loss_type == "dos":
loss_params["starter_learning_rate"] = start_lr
raise NotImplementedError()
loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size
return DOSLoss(**loss_params)
elif loss_type == "ener_spin":
loss_params["starter_learning_rate"] = start_lr
return EnergySpinLoss(**loss_params)
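For context, a hedged sketch of a loss_params dict that would reach the "dos" branch of get_loss above; the key names mirror the DOSLoss constructor arguments and the numeric values are purely illustrative. As the diff shows, get_loss itself injects starter_learning_rate and numb_dos (read from the model output definition) before constructing the loss.

# Illustrative only: a loss configuration that would select the new DOS loss.
loss_params = {
    "type": "dos",  # drives the dispatch in get_loss
    "start_pref_dos": 1.0,
    "limit_pref_dos": 1.0,
    "start_pref_cdf": 0.0,
    "limit_pref_cdf": 0.0,
    "start_pref_ados": 1.0,
    "limit_pref_ados": 1.0,
    "start_pref_acdf": 0.0,
    "limit_pref_acdf": 0.0,
}
# get_loss(loss_params, start_lr, ntypes, model) would then add
# "starter_learning_rate" and "numb_dos" and return DOSLoss(**loss_params).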