feat(pt): add dpa3 alpha descriptor (deepmodeling#4476)
This PR adds an early, experimental preview version of DPA3. Significant
changes may occur in subsequent updates; please use it with caution.
iProzd authored Dec 24, 2024
2 parents 104fc36 + 1309e26 commit 4e65d8b
Showing 12 changed files with 2,478 additions and 31 deletions.
114 changes: 114 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


class RepFlowArgs:
def __init__(
self,
n_dim: int = 128,
e_dim: int = 64,
a_dim: int = 64,
nlayers: int = 6,
e_rcut: float = 6.0,
e_rcut_smth: float = 5.0,
e_sel: int = 120,
a_rcut: float = 4.0,
a_rcut_smth: float = 3.5,
a_sel: int = 20,
a_compress_rate: int = 0,
axis_neuron: int = 4,
update_angle: bool = True,
update_style: str = "res_residual",
update_residual: float = 0.1,
update_residual_init: str = "const",
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Parameters
----------
n_dim : int, optional
The dimension of node representation.
e_dim : int, optional
The dimension of edge representation.
a_dim : int, optional
The dimension of angle representation.
nlayers : int, optional
Number of repflow layers.
e_rcut : float, optional
The edge cut-off radius.
e_rcut_smth : float, optional
Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth.
e_sel : int, optional
Maximum possible number of selected edge neighbors.
a_rcut : float, optional
The angle cut-off radius.
a_rcut_smth : float, optional
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
a_sel : int, optional
Maximum possible number of selected angle neighbors.
a_compress_rate : int, optional
The compression rate for angular messages. The default value is 0, indicating no compression.
If a non-zero integer c is provided, the node and edge dimensions will be compressed
to n_dim/c and e_dim/2c, respectively, within the angular message.
axis_neuron : int, optional
The number of dimensions of the submatrix in the symmetrization ops.
update_angle : bool, optional
Whether to update the angle rep. If not, only the node and edge reps will be used.
update_style : str, optional
Style to update a representation.
Supported options are:
-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n)
-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)
-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + rn*u_n)
where `r1`, `r2` ... `rn` are residual weights defined by `update_residual`
and `update_residual_init`.
update_residual : float, optional
When updating using the residual mode, the initial std of the residual vector weights.
update_residual_init : str, optional
When updating using the residual mode, the initialization mode of the residual vector weights.
"""
self.n_dim = n_dim
self.e_dim = e_dim
self.a_dim = a_dim
self.nlayers = nlayers
self.e_rcut = e_rcut
self.e_rcut_smth = e_rcut_smth
self.e_sel = e_sel
self.a_rcut = a_rcut
self.a_rcut_smth = a_rcut_smth
self.a_sel = a_sel
self.a_compress_rate = a_compress_rate
self.axis_neuron = axis_neuron
self.update_angle = update_angle
self.update_style = update_style
self.update_residual = update_residual
self.update_residual_init = update_residual_init

def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(key)

def serialize(self) -> dict:
return {
"n_dim": self.n_dim,
"e_dim": self.e_dim,
"a_dim": self.a_dim,
"nlayers": self.nlayers,
"e_rcut": self.e_rcut,
"e_rcut_smth": self.e_rcut_smth,
"e_sel": self.e_sel,
"a_rcut": self.a_rcut,
"a_rcut_smth": self.a_rcut_smth,
"a_sel": self.a_sel,
"a_compress_rate": self.a_compress_rate,
"axis_neuron": self.axis_neuron,
"update_angle": self.update_angle,
"update_style": self.update_style,
"update_residual": self.update_residual,
"update_residual_init": self.update_residual_init,
}

@classmethod
def deserialize(cls, data: dict) -> "RepFlowArgs":
return cls(**data)
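
A minimal usage sketch of the RepFlowArgs container added above, based only on the constructor, __getitem__, serialize, and deserialize shown in this diff (the overridden parameter values are arbitrary examples):

from deepmd.dpmodel.descriptor.dpa3 import RepFlowArgs

# Configure the repflow block, overriding a couple of defaults.
args = RepFlowArgs(n_dim=128, e_dim=64, e_rcut=5.5, e_rcut_smth=4.5)

# Dict-style read access is provided by __getitem__.
assert args["e_rcut"] == 5.5

# serialize() returns a plain dict, so a round trip restores an equivalent object.
restored = RepFlowArgs.deserialize(args.serialize())
assert restored.e_rcut_smth == args.e_rcut_smth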
71 changes: 40 additions & 31 deletions deepmd/pt/loss/ener.py
@@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
energy_pred = energy_pred * atom_norm
energy_label = energy_label * atom_norm
l1_ener_loss = F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="sum",
reduction="mean",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
).detach(),
l1_ener_loss.detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)
# if mae:
# mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
# more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
# mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
# more_loss["mae_e_all"] = self.display_if_exist(
# mae_e_all.detach(), find_energy
# )

if (
(self.has_f or self.has_pf or self.relative_f or self.has_gf)
@@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
rmse_f.detach(), find_force
)
else:
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean")
more_loss["mae_f"] = self.display_if_exist(
l1_force_loss.mean().detach(), find_force
l1_force_loss.detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
# l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if mae:
mae_f = torch.mean(torch.abs(diff_f))
more_loss["mae_f"] = self.display_if_exist(
mae_f.detach(), find_force
)
# if mae:
# mae_f = torch.mean(torch.abs(diff_f))
# more_loss["mae_f"] = self.display_if_exist(
# mae_f.detach(), find_force
# )

if self.has_pf and "atom_pref" in label:
atom_pref = label["atom_pref"]
@@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
virial_label = label["virial"]
virial_pred = model_pred["virial"].reshape(-1, 9)
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
if not self.use_l1_all:
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
)
else:
l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean")
more_loss["mae_v"] = self.display_if_exist(
l1_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION)
# if mae:
# mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
# more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
atom_ener = model_pred["atom_energy"]
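The hunks above switch the use_l1_all branches of the energy, force, and virial terms to reduction="mean" (with per-atom normalization of the energies) and comment out the old `if mae:` blocks. Below is a standalone sketch of the new energy reduction, using hypothetical tensors rather than the loss class itself; nframes, natoms, atom_norm, pref_e, and the tensor shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

# Hypothetical batch: 4 frames of 192 atoms each (shapes chosen for illustration).
nframes, natoms = 4, 192
energy_pred = torch.randn(nframes, 1)
energy_label = torch.randn(nframes, 1)
atom_norm = 1.0 / natoms
pref_e = 0.02

# Per-atom normalization followed by a mean L1 loss, as in the new branch,
# replacing the previous reduction="sum" over unnormalized energies.
energy_pred = energy_pred * atom_norm
energy_label = energy_label * atom_norm
l1_ener_loss = F.l1_loss(
    energy_pred.reshape(-1),
    energy_label.reshape(-1),
    reduction="mean",
)
loss = pref_e * l1_ener_loss

# The same scalar is reused directly for the displayed "mae_e" metric.
mae_e = l1_ener_loss.detach()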
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
@@ -13,6 +13,9 @@
from .dpa2 import (
DescrptDPA2,
)
from .dpa3 import (
DescrptDPA3,
)
from .env_mat import (
prod_env_mat,
)
@@ -49,6 +52,7 @@
"DescrptBlockSeTTebd",
"DescrptDPA1",
"DescrptDPA2",
"DescrptDPA3",
"DescrptHybrid",
"DescrptSeA",
"DescrptSeAttenV2",
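Since the __init__.py above now re-exports the new class, the descriptor should be importable directly from the package namespace; a minimal check, assuming a PyTorch-enabled deepmd-kit installation:

from deepmd.pt.model.descriptor import DescrptDPA3

print(DescrptDPA3.__name__)  # prints "DescrptDPA3"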