-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: dpmodel energy loss & consistent tests #4531
Conversation
Fix deepmodeling#4105. Fix deepmodeling#4429. Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot reviewed 5 out of 10 changed files in this pull request and generated no comments.
Files not reviewed (5)
- deepmd/dpmodel/loss/ener.py: Evaluated as low risk
- deepmd/dpmodel/loss/loss.py: Evaluated as low risk
- deepmd/pt/loss/ener.py: Evaluated as low risk
- deepmd/pt/loss/loss.py: Evaluated as low risk
- deepmd/tf/loss/ener.py: Evaluated as low risk
📝 WalkthroughWalkthroughThis pull request introduces enhancements to the loss modules across various backends (TensorFlow, PyTorch, and DeepMD) by adding serialization and deserialization capabilities. A new abstract base class Changes
Assessment against linked issues
Possibly related PRs
Suggested Labels
Suggested Reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (8)
deepmd/dpmodel/loss/loss.py (2)
21-30
: Consider expanding docstring details for thecall
method.Currently, the docstring briefly states "Calculate loss from model results and labeled results," but it doesn't clarify the keys or shape expectations of the returned dictionary. Elaborating on the output structure and required dictionary keys can improve clarity for implementers.
37-55
: Validate the method name to match its core purpose.
display_if_exist
sets loss toNaN
if a property is absent, which may not be immediately obvious from the function name. Consider renaming it to more clearly reflect its filtering behavior, such asmask_unavailable_loss
.deepmd/dpmodel/loss/ener.py (1)
283-283
: Use boolean checks rather than numerical comparison.
self.has_gf
is a boolean, so comparing it to> 0
may cause unexpected behavior or confusion.- if self.has_gf > 0: + if self.has_gf:deepmd/tf/loss/loss.py (2)
98-113
: Reflect uniform serialization approach across backends.Like in
deepmd/dpmodel/loss/ener.py
, consider adding version and class annotation keys, e.g.,"@version"
or"@class"
, to maintain uniform serialization across all frameworks. This consistency simplifies cross-backend model loading.
132-149
: Clarifyinit_variables
design.Static analysis suggests making this method abstract or providing a minimal base implementation. If the plan is for derived classes to override it, consider using
@abstractmethod
. Otherwise, confirm that an empty default is intentional and that it shouldn't raiseNotImplementedError
.🧰 Tools
🪛 Ruff (0.8.2)
132-149:
Loss.init_variables
is an empty method in an abstract base class, but has no abstract decorator(B027)
deepmd/pt/loss/ener.py (2)
418-446
: Use consistent class naming in serialized output
Currently, theserialize()
method returns"@class": "EnergyLoss"
. Consider aligning this string identifier with the actual class name (EnergyStdLoss
) to help avoid confusion across multiple codebases that implement serialization.- "@class": "EnergyLoss", + "@class": "EnergyStdLoss",
447-465
: Enhance error handling on deserialization
You might want to verify that all required fields exist indata
before callingcls(**data)
. An explicit check can help detect missing fields at runtime.+ required_fields = ["starter_learning_rate", "start_pref_e", "limit_pref_e", ...] + for field in required_fields: + if field not in data: + raise KeyError(f"Missing required field '{field}' in serialized data")deepmd/tf/loss/ener.py (1)
408-441
: Improve clarity in serialization
Serializing with"@class": "EnergyLoss"
might be less descriptive than"EnerStdLoss"
here, given this is theEnerStdLoss
class. Consider updating the string for future maintainability.- "@class": "EnergyLoss", + "@class": "EnerStdLoss",
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
deepmd/dpmodel/loss/__init__.py
(1 hunks)deepmd/dpmodel/loss/ener.py
(1 hunks)deepmd/dpmodel/loss/loss.py
(1 hunks)deepmd/pt/loss/ener.py
(2 hunks)deepmd/pt/loss/loss.py
(1 hunks)deepmd/tf/loss/ener.py
(2 hunks)deepmd/tf/loss/loss.py
(1 hunks)source/tests/consistent/loss/__init__.py
(1 hunks)source/tests/consistent/loss/common.py
(1 hunks)source/tests/consistent/loss/test_ener.py
(1 hunks)
✅ Files skipped from review due to trivial changes (3)
- source/tests/consistent/loss/common.py
- deepmd/dpmodel/loss/init.py
- source/tests/consistent/loss/init.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/tf/loss/loss.py
132-149: Loss.init_variables
is an empty method in an abstract base class, but has no abstract decorator
(B027)
🔇 Additional comments (7)
deepmd/dpmodel/loss/loss.py (1)
76-85
: Ensure consistent serialization format across derived classes.
serialize
is abstract here, and derived classes (e.g.,EnergyLoss
) implement it using keys like"@class"
and"@version"
. For consistency, document or enforce the same structure in all subclasses, so third-party tools can reliably interpret the serialized data.
[verification]deepmd/dpmodel/loss/ener.py (2)
63-66
: Confirm usage ofRuntimeError
.Raising
RuntimeError
whennumb_generalized_coord < 1
is explicit and correct for your domain. However, consider providing a more detailed error message to guide users on recommended input constraints or fallback solutions for generalized forces.
70-227
:⚠️ Potential issuePrevent confusion around
natoms
usage.
natoms
is typed as an integer but is treated like an array in lines 143-144 (natoms[0]
). Make surenatoms
is consistently passed as a list or a similar sequence. If multiple frames are expected, clarify this in the docstring and function arguments to avoid type mismatches or logic errors.- force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3]) - force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3]) + force_reshape_nframes = xp.reshape(force, [-1, natoms * 3]) + force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms * 3])Likely invalid or redundant comment.
deepmd/pt/loss/loss.py (1)
67-92
: Implement or document defaultserialize
anddeserialize
behavior.Currently, both methods raise
NotImplementedError
. If this class is intended as an abstract base, mark them as abstract methods. Otherwise, clarify the intended usage by either providing a skeleton implementation or adding docstrings indicating how derived classes should implement serialization.deepmd/pt/loss/ener.py (1)
21-23
: Maintain consistent naming for version checks
The newly added import fromdeepmd.utils.version
looks good, ensuring uniform version compatibility checks.deepmd/tf/loss/ener.py (2)
19-21
: Leverage version compatibility checks
The introduction ofcheck_version_compatibility
aligns this file with other modules, ensuring consistent handling of serialization versions.
442-462
: 🛠️ Refactor suggestionVerify presence of required data fields
Before passingdata
into the constructor, ensure it contains the necessary fields. A structured validation step can reduce runtime exceptions if users supply incomplete data.+ required_fields = ["starter_learning_rate", "start_pref_e", "limit_pref_e", ...] + for field in required_fields: + if field not in data: + raise KeyError(f"Required field '{field}' missing in the serialized data")
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (3)
source/tests/consistent/loss/test_ener.py (1)
57-86
: Consider expanding test coverage for edge cases.The test setup is well-structured, but could benefit from additional test cases:
- Test with zero or negative preference values
- Test with extreme learning rates
- Test with empty or invalid input arrays
deepmd/dpmodel/loss/ener.py (2)
20-67
: Add input validation for preference values.Consider adding validation to ensure that:
- All preference values are non-negative
- Learning rate is positive
relative_f
is positive when provided
291-337
: Enhance error handling in deserialization.Consider adding validation for required fields and type checking in the deserialization method to handle malformed input gracefully.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/dpmodel/loss/ener.py
(1 hunks)source/tests/consistent/loss/test_ener.py
(1 hunks)
🔇 Additional comments (3)
source/tests/consistent/loss/test_ener.py (2)
87-120
: Expand negative or edge case coverage.Currently, the setup uses random arrays and sets
find_*
flags to1.0
. Add tests for scenarios where certain keys (e.g.,"atom_pref"
,"find_energy"
) are missing or zero, ensuring the loss logic handles partially available labels gracefully.
230-241
: LGTM!The helper methods and properties are well-implemented with appropriate tolerance values for floating-point comparisons.
deepmd/dpmodel/loss/ener.py (1)
213-290
: LGTM!The label requirements are well-defined with appropriate flags for each data type.
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pd/loss/ener.py (1)
458-475
: Consider validating the '@Class' field indeserialize()
.
Currently, you remove"@class"
after verifying the version, but there’s no check to ensure the serialized data belongs to the correct class. If you ever extend or share serialization logic across multiple classes, consider validating that"@class"
matches"EnergyStdLoss"
before popping it, to guard against user or developer error.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pd/loss/ener.py
(2 hunks)
🔇 Additional comments (1)
deepmd/pd/loss/ener.py (1)
21-23
: Ensure consistent version checks across the codebase.
These lines introducecheck_version_compatibility
. While this usage appears straightforward, be sure that any older references to version checks throughout the codebase also rely on the same mechanism for uniformity.Run the following script to search for other version checks and confirm consistent practices:
✅ Verification successful
Version compatibility checks are consistently implemented across the codebase
The search results show that
check_version_compatibility
is used consistently throughout the codebase in deserialize methods to validate version compatibility. The implementation follows a standard pattern:
- All modules use the same version checking mechanism from
deepmd.utils.version
- The function is called with consistent arguments format: current version, maximum supported version, and minimum supported version
- The checks are performed at the start of deserialization before any data processing
- The version is consistently popped from data dictionary with "@Version" key
- Default values and version ranges are appropriate for each module's needs
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Verify consistent usage of check_version_compatibility across the repo. rg -A 5 "check_version_compatibility"Length of output: 63039
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4531 +/- ##
==========================================
- Coverage 84.57% 84.55% -0.03%
==========================================
Files 675 677 +2
Lines 63695 63899 +204
Branches 3488 3486 -2
==========================================
+ Hits 53872 54031 +159
- Misses 8698 8742 +44
- Partials 1125 1126 +1 ☔ View full report in Codecov by Sentry. |
Fix #4105. Fix #4429.
Summary by CodeRabbit
Release Notes
New Features
EnergyLoss
for computing energy-related loss metrics.Documentation
Tests
TestEner
for evaluating energy loss functions.