feat(pt): support DeepEval.eval_descriptor
#4214
Conversation
Fix deepmodeling#4112. Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough: The changes introduce a new method, `DeepEval.eval_descriptor`.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/infer/deep_eval.py (1)
602-647: LGTM! Consider some improvements for robustness and consistency.

The implementation of `eval_descriptor` looks good and correctly implements the descriptor evaluation using a hook mechanism. Here are a few suggestions to consider:

1. **Error handling**: Consider adding try-except blocks to handle potential errors gracefully, especially around lines 643-646 where the hook is set and unset.
2. **Performance optimization**: The method calls `self.eval`, which might perform unnecessary computations if only the descriptor is needed. If descriptor evaluation is a frequent operation, consider implementing a more optimized path.
3. **Consistency**: On line 642, `self.dp.model["Default"]` is used directly, while other methods in the class use `self.dp.to(DEVICE)`. For consistency, consider using the same approach here.

Here's a potential refactor addressing these points:

```python
def eval_descriptor(
    self,
    coords: np.ndarray,
    cells: Optional[np.ndarray],
    atom_types: np.ndarray,
    fparam: Optional[np.ndarray] = None,
    aparam: Optional[np.ndarray] = None,
    **kwargs: Any,
) -> np.ndarray:
    """Evaluate descriptors by using this DP."""
    model = self.dp.to(DEVICE).model["Default"]
    try:
        model.set_eval_descriptor_hook(True)
        self.eval(coords, cells, atom_types, fparam=fparam, aparam=aparam, **kwargs)
        descriptor = model.eval_descriptor()
    except Exception as e:
        raise RuntimeError("Error during descriptor evaluation") from e
    finally:
        model.set_eval_descriptor_hook(False)
    return to_numpy_array(descriptor)
```

This refactored version includes error handling, ensures consistent device placement, and resets the hook state even if an error occurs.
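For context, the enable → evaluate → collect → disable flow used above can be sketched with a toy model. `ToyModel`, its `forward`, and the doubled-values "descriptor" are hypothetical stand-ins, not the DeePMD-kit classes:

```python
# Toy sketch of the hook mechanism: enable the hook, run a normal evaluation
# (which caches the descriptor as a side effect), read the cache, and always
# disable the hook again in a finally block.
# ToyModel is a hypothetical stand-in, not the real DeePMD-kit model.
class ToyModel:
    def __init__(self):
        self.enable_eval_descriptor_hook = False
        self.eval_descriptor_list = []

    def set_eval_descriptor_hook(self, enable):
        self.enable_eval_descriptor_hook = enable
        self.eval_descriptor_list = []

    def forward(self, frame):
        descriptor = [x * 2.0 for x in frame]  # stand-in for the real descriptor
        if self.enable_eval_descriptor_hook:
            self.eval_descriptor_list.append(descriptor)
        return sum(descriptor)  # stand-in for the normal model output

    def eval_descriptor(self):
        return self.eval_descriptor_list


model = ToyModel()
try:
    model.set_eval_descriptor_hook(True)
    model.forward([1.0, 2.0, 3.0])  # ordinary eval; hook captures the descriptor
    captured = model.eval_descriptor()
finally:
    model.set_eval_descriptor_hook(False)  # hook state is reset even on error

print(captured)  # [[2.0, 4.0, 6.0]]
```

Note that `captured` keeps a reference to the old cache list, so resetting the hook afterwards does not destroy the returned descriptors.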
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
- deepmd/pt/infer/deep_eval.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
- deepmd/pt/model/model/dp_model.py (2 hunks)
- source/tests/infer/test_models.py (0 hunks)
💤 Files with no reviewable changes (1)
- source/tests/infer/test_models.py
🧰 Additional context used
🔇 Additional comments (6)
deepmd/pt/model/model/dp_model.py (4)
6-7: LGTM: torch import added correctly.

The `torch` import is necessary for the new methods using torch functionalities and is correctly placed at the top of the file.
58-62: LGTM: set_eval_descriptor_hook method implemented correctly.

The `set_eval_descriptor_hook` method is well implemented:

- Correctly uses the `@torch.jit.export` decorator for TorchScript compatibility.
- Provides a clear docstring explaining its purpose.
- Follows the single responsibility principle by delegating to `self.atomic_model`.
63-66: LGTM: eval_descriptor method implemented correctly.

The `eval_descriptor` method is well implemented:

- Correctly uses the `@torch.jit.export` decorator for TorchScript compatibility.
- Provides a clear, concise docstring.
- Follows the single responsibility principle by delegating to `self.atomic_model`.
- Specifies the return type as `torch.Tensor`, which is good for type hinting.
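The delegation pattern these comments describe can be illustrated with a minimal plain-Python sketch. `ToyDPModel` and `ToyAtomicModel` are toy stand-ins, not the real classes, and the `@torch.jit.export` decorator present in the actual code is omitted here:

```python
# Minimal sketch of the wrapper-delegates-to-atomic-model pattern: the outer
# model keeps no hook state of its own and simply forwards both calls.
class ToyAtomicModel:
    def __init__(self):
        self.enable_eval_descriptor_hook = False
        self.eval_descriptor_list = []

    def set_eval_descriptor_hook(self, enable: bool) -> None:
        """Set the hook for evaluating descriptor and clear the cache."""
        self.enable_eval_descriptor_hook = enable
        self.eval_descriptor_list = []

    def eval_descriptor(self):
        return self.eval_descriptor_list


class ToyDPModel:
    def __init__(self):
        self.atomic_model = ToyAtomicModel()

    # In the real code this method carries @torch.jit.export.
    def set_eval_descriptor_hook(self, enable: bool) -> None:
        """Delegate hook control to the atomic model."""
        self.atomic_model.set_eval_descriptor_hook(enable)

    def eval_descriptor(self):
        """Delegate descriptor retrieval to the atomic model."""
        return self.atomic_model.eval_descriptor()


dp = ToyDPModel()
dp.set_eval_descriptor_hook(True)
print(dp.atomic_model.enable_eval_descriptor_hook)  # True
```

Keeping the state on the atomic model means the wrapper stays a thin pass-through, which is what makes the single-responsibility observation above hold.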
57-66: Consider adding tests for the new methods.

The implementation of `set_eval_descriptor_hook` and `eval_descriptor` looks good. To ensure robustness and prevent future regressions, consider adding unit tests for these new methods if not already done.

deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
65-67: Initialization of new attributes is appropriate.

The new attributes `enable_eval_descriptor_hook` and `eval_descriptor_list` are properly initialized in the `__init__` method. This setup allows for flexible control over descriptor evaluation during the model's forward passes.
70-74: Verify the reset behavior in `set_eval_descriptor_hook`.

The `set_eval_descriptor_hook` method resets `eval_descriptor_list` every time it is called, regardless of whether the hook is being enabled or disabled. Is this the intended behavior? If the cache should only be cleared when enabling the hook, consider modifying the implementation to reset the list only when `enable` is `True`.

Suggested consideration:

```python
def set_eval_descriptor_hook(self, enable: bool) -> None:
    """Set the hook for evaluating descriptor and clear the cache if enabling."""
    self.enable_eval_descriptor_hook = enable
    if enable:
        self.eval_descriptor_list = []
```

This ensures that the descriptor list is only cleared when the hook is enabled, preserving any cached descriptors if the hook is being disabled.
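The behavioral difference between the two reset policies can be seen in a small standalone sketch. The dict-based state and helper names here are illustrative only, not the PR's code:

```python
# Compares an always-clear policy with the suggested clear-on-enable policy.
def set_hook_always_clear(state, enable):
    state["enabled"] = enable
    state["cache"] = []          # cleared on both enable and disable


def set_hook_clear_on_enable(state, enable):
    state["enabled"] = enable
    if enable:
        state["cache"] = []      # cleared only when enabling


# Disabling with the always-clear policy discards cached descriptors...
state_a = {"enabled": True, "cache": ["descriptor_frame_0"]}
set_hook_always_clear(state_a, False)
print(state_a["cache"])  # []

# ...while clear-on-enable preserves them until the next enable.
state_b = {"enabled": True, "cache": ["descriptor_frame_0"]}
set_hook_clear_on_enable(state_b, False)
print(state_b["cache"])  # ['descriptor_frame_0']
```

Which policy is right depends on whether callers are expected to read the cache after disabling the hook, as the `DeepEval.eval_descriptor` flow in this PR does before resetting it.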
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##            devel    #4214   +/-   ##
=======================================
  Coverage   83.50%   83.50%
=======================================
  Files         541      541
  Lines       52459    52483     +24
  Branches     3047     3047
=======================================
+ Hits        43804    43825     +21
  Misses       7710     7710
- Partials      945      948      +3
```

☔ View full report in Codecov by Sentry.
Fix #4112.
Summary by CodeRabbit

- New Features
- Bug Fixes
- Tests