-
Notifications
You must be signed in to change notification settings - Fork 519
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat (tf/pt): add atomic weights to tensor loss (#4466)
Interfaces are of particular interest in many studies. However, the configurations in the training set to represent the interface normally also include large parts of the bulk material. As a result, the final model would prefer the bulk information while the interfacial information is less learnt. It is difficult to simply improve the proportion of interfaces in the configurations since the electronic structures of the interface might only be reasonable with a certain thickness of bulk materials. Therefore, I wonder whether it is possible to define weights for atomic quantities in loss functions. This allows us to add higher weights for the atomic information for the regions of interest and probably makes the model "more focused" on the region of interest. In this PR, I add the keyword `enable_atomic_weight` to the loss function of the tensor model. In principle, it could be generalised to any atomic quantity, e.g., atomic forces. I would like to know the developers' comments/suggestions about this feature. I can add support for other loss functions and finish unit tests once we agree on this feature. Best. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced an optional parameter for atomic weights in loss calculations, enhancing flexibility in the `TensorLoss` class. - Added a suite of unit tests for the `TensorLoss` functionality, ensuring consistency between TensorFlow and PyTorch implementations. - **Bug Fixes** - Updated logic for local loss calculations to ensure correct application of atomic weights based on user input. - **Documentation** - Improved clarity of documentation for several function arguments, including the addition of a new argument related to atomic weights. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
- Loading branch information
1 parent
104fc36
commit c0914e1
Showing
4 changed files
with
519 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.