Skip to content

Commit

Permalink
feat (tf/pt): add atomic weights to tensor loss (deepmodeling#4466)
Browse files Browse the repository at this point in the history
Interfaces are of particular interest in many studies. However, the
configurations in the training set to represent the interface normally
also include large parts of the bulk material. As a result, the final
model would prefer the bulk information while the interfacial
information is less learnt. It is difficult to simply improve the
proportion of interfaces in the configurations since the electronic
structures of the interface might only be reasonable with a certain
thickness of bulk materials. Therefore, I wonder whether it is possible
to define weights for atomic quantities in loss functions. This allows
us to add higher weights for the atomic information for the regions of
interest and probably makes the model "more focused" on the region of
interest.
In this PR, I add the keyword `enable_atomic_weight` to the loss
function of the tensor model. In principle, it could be generalised to
any atomic quantity, e.g., atomic forces.
I would like to know the developers' comments/suggestions about this
feature. I can add support for other loss functions and finish unit
tests once we agree on this feature.

Best. 




<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced an optional parameter for atomic weights in loss
calculations, enhancing flexibility in the `TensorLoss` class.
- Added a suite of unit tests for the `TensorLoss` functionality,
ensuring consistency between TensorFlow and PyTorch implementations.

- **Bug Fixes**
- Updated logic for local loss calculations to ensure correct
application of atomic weights based on user input.

- **Documentation**
- Improved clarity of documentation for several function arguments,
including the addition of a new argument related to atomic weights.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
ChiahsinChu authored Dec 18, 2024
1 parent 104fc36 commit c0914e1
Show file tree
Hide file tree
Showing 4 changed files with 519 additions and 3 deletions.
22 changes: 22 additions & 0 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
pref_atomic: float = 0.0,
pref: float = 0.0,
inference=False,
enable_atomic_weight: bool = False,
**kwargs,
) -> None:
r"""Construct a loss for local and global tensors.
Expand All @@ -40,6 +41,8 @@ def __init__(
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
enable_atomic_weight : bool
If true, atomic weight will be used in the loss calculation.
**kwargs
Other keyword arguments.
"""
Expand All @@ -50,6 +53,7 @@ def __init__(
self.local_weight = pref_atomic
self.global_weight = pref
self.inference = inference
self.enable_atomic_weight = enable_atomic_weight

assert (
self.local_weight >= 0.0 and self.global_weight >= 0.0
Expand Down Expand Up @@ -85,6 +89,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
"""
model_pred = model(**input_dict)
del learning_rate, mae

if self.enable_atomic_weight:
atomic_weight = label["atom_weight"].reshape([-1, 1])
else:
atomic_weight = 1.0

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if (
Expand All @@ -103,6 +113,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
diff = (local_tensor_pred - local_tensor_label).reshape(
[-1, self.tensor_size]
)
diff = diff * atomic_weight
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss = torch.mean(torch.square(diff))
Expand Down Expand Up @@ -171,4 +182,15 @@ def label_requirement(self) -> list[DataRequirementItem]:
high_prec=False,
)
)
if self.enable_atomic_weight:
label_requirement.append(
DataRequirementItem(
"atomic_weight",
ndof=1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
)
)
return label_requirement
24 changes: 23 additions & 1 deletion deepmd/tf/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, jdata, **kwarg) -> None:
# YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight
self.local_weight = jdata.get("pref_atomic", None)
self.global_weight = jdata.get("pref", None)
self.enable_atomic_weight = jdata.get("enable_atomic_weight", False)

assert (
self.local_weight is not None and self.global_weight is not None
Expand All @@ -66,9 +67,18 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
"global_loss": global_cvt_2_tf_float(0.0),
}

if self.enable_atomic_weight:
atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1])
else:
atomic_weight = global_cvt_2_tf_float(1.0)

if self.local_weight > 0.0:
diff = tf.reshape(polar, [-1, self.tensor_size]) - tf.reshape(
atomic_polar_hat, [-1, self.tensor_size]
)
diff = diff * atomic_weight
local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix
tf.square(self.scale * diff), name="l2_" + suffix
)
more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic)
l2_loss += self.local_weight * local_loss
Expand Down Expand Up @@ -163,4 +173,16 @@ def label_requirement(self) -> list[DataRequirementItem]:
type_sel=self.type_sel,
)
)
if self.enable_atomic_weight:
data_requirements.append(
DataRequirementItem(
"atom_weight",
1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
type_sel=self.type_sel,
)
)
return data_requirements
12 changes: 10 additions & 2 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,8 +2511,9 @@ def loss_property():
def loss_tensor():
# doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]."
# doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
doc_enable_atomic_weight = "If true, the atomic loss will be reweighted."
return [
Argument(
"pref", [float, int], optional=False, default=None, doc=doc_global_weight
Expand All @@ -2524,6 +2525,13 @@ def loss_tensor():
default=None,
doc=doc_local_weight,
),
Argument(
"enable_atomic_weight",
bool,
optional=True,
default=False,
doc=doc_enable_atomic_weight,
),
]


Expand Down
Loading

0 comments on commit c0914e1

Please sign in to comment.