fix(pt): use eval mode in the C++ interface #4402
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
Copilot wasn't able to review any files in this pull request.
Files not reviewed (2)
- source/api_cc/src/DeepPotPT.cc: Language not supported
- source/api_cc/src/DeepSpinPT.cc: Language not supported
📝 Walkthrough

The pull request modifies source/api_cc/src/DeepPotPT.cc and source/api_cc/src/DeepSpinPT.cc to set the loaded TorchScript module to evaluation mode before inference.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DeepPotPT.cc (2)
Line range hint 187-208: Add validation for message passing tensors.

While the message passing implementation is thorough, consider adding validation for tensor dimensions and types before using them in computations.

```diff
 if (do_message_passing) {
   int nswap = lmp_list.nswap;
+  if (nswap < 0) {
+    throw deepmd::deepmd_exception("Invalid nswap value");
+  }
   torch::Tensor sendproc_tensor =
       torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
   torch::Tensor recvproc_tensor =
       torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
```
Line range hint 512-538: Enhance the error message for unimplemented methods.

The error message for mixed type computation could be more descriptive to help users understand the limitation and potential alternatives.

```diff
- throw deepmd::deepmd_exception("computew_mixed_type is not implemented");
+ throw deepmd::deepmd_exception(
+     "computew_mixed_type is not yet implemented. Please use the standard "
+     "compute methods with consistent types.");
```

source/api_cc/src/DeepSpinPT.cc (1)
Line range hint 271-273: Track unsupported spin model features.

The comments indicate that certain spin model features (e.g., virial calculations) are not supported yet. This should be tracked for future implementation.
Would you like me to create a GitHub issue to track the implementation of these missing features? The issue would help track:
- Implementation of virial calculations for spin models
- Implementation of atom virial calculations for spin models
Also applies to: 386-388
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- source/api_cc/src/DeepPotPT.cc (1 hunks)
- source/api_cc/src/DeepSpinPT.cc (1 hunks)
🔇 Additional comments (4)
source/api_cc/src/DeepPotPT.cc (1)
Line range hint 15-31: LGTM! Improved error handling with specific PyTorch exception types.
The error handling is now more granular and provides clearer error messages for different types of PyTorch exceptions.
source/api_cc/src/DeepSpinPT.cc (3)
Line range hint 13-29: LGTM! Comprehensive error handling implementation.
The error handling implementation properly catches and translates different types of PyTorch exceptions, providing clear error messages for debugging.
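The catch-and-translate pattern praised here can be sketched in miniature: a wrapper runs a callable and rethrows backend exceptions as a single domain exception type. The following is an illustrative Python sketch with hypothetical names (`DeepmdError`, `translate_error`, `flaky_backend_call`), not the actual C++ implementation, which catches PyTorch-specific exception types:

```python
class DeepmdError(Exception):
    """Hypothetical domain exception, standing in for deepmd::deepmd_exception."""

def translate_error(fn, *args):
    """Run fn, rethrowing backend errors as a single domain exception.

    Sketch of the catch-and-translate pattern only; names are illustrative.
    """
    try:
        return fn(*args)
    except RuntimeError as e:  # stand-in for a backend-specific error type
        raise DeepmdError(f"backend error: {e}") from e

def flaky_backend_call():
    raise RuntimeError("tensor shape mismatch")

caught = ""
try:
    translate_error(flaky_backend_call)
except DeepmdError as e:
    caught = str(e)
print(caught)  # backend error: tensor shape mismatch
```

Callers then only need to handle one exception type, while the original backend message is preserved in the chained cause.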
91-91: LGTM! Correctly sets the model to evaluation mode.

Setting `module.eval()` is essential for inference tasks as it disables training-specific behaviors like dropout, ensuring consistent predictions.
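Why eval mode matters can be shown with a toy dropout layer: in training mode the output is randomly perturbed, while eval mode passes inputs through unchanged. This is a pure-Python illustration with made-up names (`TinyDropout`), not the TorchScript module touched by this PR:

```python
import random

class TinyDropout:
    """Toy dropout layer illustrating train vs. eval behavior (illustrative only)."""

    def __init__(self, p=0.5, seed=0):
        self.p = p
        self.training = True
        self.rng = random.Random(seed)

    def eval(self):
        self.training = False
        return self

    def forward(self, xs):
        if self.training:
            # Training mode: randomly zero elements, so output varies with the RNG.
            return [0.0 if self.rng.random() < self.p else x / (1 - self.p)
                    for x in xs]
        # Eval mode: identity pass-through, so predictions are deterministic.
        return list(xs)

layer = TinyDropout()
train_out = layer.forward([1.0, 2.0, 3.0])   # perturbed by dropout
eval_out = layer.eval().forward([1.0, 2.0, 3.0])
print(eval_out)  # [1.0, 2.0, 3.0]
```

Without the `eval()` call, repeated inference on the same input could yield different results, which is exactly the bug class this PR avoids.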
Line range hint 449-517: LGTM! Well-implemented template specializations and wrappers.
The template specializations and wrapper methods are properly implemented with:
- Consistent error handling across all variants
- Type safety for both float and double precision
- Proper forwarding to template methods
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff            @@
##            devel    #4402    +/-   ##
=======================================
  Coverage   84.59%   84.59%
=======================================
  Files         614      614
  Lines       57009    57010     +1
  Branches     3486     3487     +1
=======================================
+ Hits        48224    48225     +1
+ Misses       7660     7659     -1
- Partials     1125     1126     +1
```

☔ View full report in Codecov by Sentry.
for more information, see https://pre-commit.ci
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
source/api_cc/src/DeepSpinPT.cc (1)
Line range hint 293-294: Track implementation of spin model features.

There are several commented-out sections related to spin model features (virial calculations) that are marked as "not supported yet". Consider creating tracking issues for implementing these features.
Would you like me to help create GitHub issues to track the implementation of these spin model features?
Also applies to: 386-387, 417-418
🔇 Additional comments (5)
source/api_cc/src/DeepPotPT.cc (2)
91-91: Consider wrapping `module.eval()` in `translate_error`.

While setting eval mode is correct for inference, it should be wrapped in the error translation mechanism for consistent exception handling.
92-96: LGTM! Good practice for inference optimization.

Freezing the module with carefully selected preserved attributes is an excellent approach to optimize inference performance while maintaining access to essential methods.
Let's verify that all required methods are preserved and accessible:
✅ Verification successful

All preserved methods are actively used in the codebase.

The verification confirms that all preserved methods in the `preserved_attrs` list are actively used in both `DeepPotPT.cc` and `DeepSpinPT.cc`. Each method is called using `run_method()` and serves a specific purpose:
- `forward_lower`: used for core model inference
- `has_message_passing`: determines message passing configuration
- `get_rcut`, `get_ntypes`: retrieve model parameters
- `get_dim_fparam`, `get_dim_aparam`, `is_aparam_nall`: handle parameter dimensions
- `get_type_map`: manages type mapping
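The reason the `preserved_attrs` list must cover every `run_method()` call site can be sketched with a minimal dispatch mock. Names here (`FrozenModule`, `debug_dump`) are hypothetical; this is not the libtorch API, only an illustration of why a dropped method would fail at lookup time:

```python
class FrozenModule:
    """Toy stand-in for a frozen TorchScript module (illustrative only).

    Freezing keeps only the methods listed in preserved_attrs; calling a
    dropped method fails, which is why the list must cover every
    run_method() call site.
    """

    def __init__(self, methods, preserved_attrs):
        # Drop everything not explicitly preserved, mimicking freezing.
        self._methods = {name: fn for name, fn in methods.items()
                         if name in preserved_attrs}

    def run_method(self, name, *args):
        if name not in self._methods:
            raise RuntimeError(f"method '{name}' was not preserved")
        return self._methods[name](*args)

methods = {"get_rcut": lambda: 6.0,
           "get_ntypes": lambda: 2,
           "debug_dump": lambda: None}
module = FrozenModule(methods, preserved_attrs=["get_rcut", "get_ntypes"])
print(module.run_method("get_rcut"))  # 6.0
assert "debug_dump" not in module._methods  # dropped by "freezing"
```

This mirrors what the verification script checks: each preserved name must correspond to an actual `run_method("...")` call in the C++ sources.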
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify that all preserved methods are actually used in the codebase
# to ensure we're not preserving unnecessary methods or missing required ones

# Search for each preserved method usage
for method in "forward_lower" "has_message_passing" "get_rcut" "get_ntypes" "get_dim_fparam" "get_dim_aparam" "is_aparam_nall" "get_type_map"; do
  echo "Checking usage of $method:"
  rg "run_method\(\"$method\""
done
```

Length of output: 3754
source/api_cc/src/DeepSpinPT.cc (3)
Line range hint 13-29: LGTM: Comprehensive error handling implementation.

The error handling implementation properly catches and translates different types of PyTorch exceptions into DeepMD exceptions with clear error messages.
91-96: LGTM: Proper model initialization for inference.

The changes correctly set up the model for inference by:
- setting evaluation mode with `module.eval()`
- freezing the module while preserving essential methods
- maintaining a comprehensive list of preserved attributes

These changes will optimize inference performance while ensuring all required functionality remains accessible.
Line range hint 47-60: Verify GPU device management across different scenarios.
The code handles both CPU and GPU devices, but we should verify proper cleanup in edge cases:
- When GPU is not available
- During device transitions
- When errors occur during computation
✅ Verification successful
The previous search revealed proper device management code; let's also check the error handling and cleanup mechanisms to fully address the device management concerns.
Based on the comprehensive search results, I can now generate the final response:
GPU device management is properly implemented with robust error handling and cleanup.
The codebase shows proper device management across different scenarios:
- GPU availability check: `torch::cuda::is_available()` and `torch::cuda::device_count()`
- Fallback to CPU: `device = torch::Device(torch::kCPU)` when GPU is not available
- Error handling: comprehensive try-catch blocks with proper cleanup in the constructor and during device transitions
- Device cleanup: memory management through `DPErrcheck` and proper device selection with `DPSetDevice`

The implementation in DeepSpinPT.cc follows best practices for device management with proper initialization, error handling, and cleanup mechanisms.
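The check-then-fallback logic described above can be sketched as a small pure function. The helper name `pick_device` and its parameters are illustrative; the real C++ code queries `torch::cuda::is_available()` and `torch::cuda::device_count()` directly:

```python
def pick_device(requested_gpu_id, cuda_available, device_count):
    """Pick a compute device, falling back to CPU when no usable GPU exists.

    Illustrative sketch of the fallback pattern; not the actual DeePMD-kit code.
    """
    if cuda_available and 0 <= requested_gpu_id < device_count:
        return f"cuda:{requested_gpu_id}"
    # No GPU, or the requested rank maps outside the visible devices.
    return "cpu"

print(pick_device(0, cuda_available=False, device_count=0))  # cpu
print(pick_device(1, cuda_available=True, device_count=2))   # cuda:1
```

Keeping the decision in one place makes the CPU fallback testable independently of whether a GPU is present on the build machine.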
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check for proper GPU device management
# Test: Look for potential device leaks or cleanup issues

# Search for device management patterns
rg -A 5 "torch::Device|DPSetDevice|cuda::"

# Search for error handling around device operations
rg -A 5 "try.*catch.*device"
```

Length of output: 8304
Script:

```bash
#!/bin/bash
# Let's check for error handling and cleanup patterns in DeepSpinPT

# Look for error handling patterns in DeepSpinPT
ast-grep --pattern 'try {
  $$$
} catch ($$$) {
  $$$
}'

# Check for cleanup/destructor implementation
rg -A 5 "~DeepSpinPT|delete.*device|device.*delete"

# Check for device initialization and error handling patterns
rg -A 10 "init.*device|device.*init"
```

Length of output: 64560
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DeepSpinPT.cc (3)
Line range hint 99-108: Consider adding error logging for thread configuration.

The thread configuration silently catches exceptions. Consider logging these exceptions to help with debugging deployment issues.

```diff
 if (num_inter_nthreads) {
   try {
     at::set_num_interop_threads(num_inter_nthreads);
-  } catch (...) {
+  } catch (const std::exception& e) {
+    std::cerr << "Warning: Failed to set inter-op threads: " << e.what()
+              << std::endl;
   }
 }
 if (num_intra_nthreads) {
   try {
     at::set_num_threads(num_intra_nthreads);
-  } catch (...) {
+  } catch (const std::exception& e) {
+    std::cerr << "Warning: Failed to set intra-op threads: " << e.what()
+              << std::endl;
   }
 }
```
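The log-instead-of-swallow pattern suggested above can be sketched in a few lines. The setter name `set_num_threads_stub` is hypothetical; the point is only that a failed configuration produces a warning and a sane return value rather than vanishing silently:

```python
import logging
import sys

logging.basicConfig(stream=sys.stderr, level=logging.WARNING)

def set_num_threads_stub(n):
    """Hypothetical setter that rejects invalid thread counts."""
    if n <= 0:
        raise ValueError(f"invalid thread count: {n}")
    return n

def configure_threads(n):
    """Apply a thread setting, logging failures instead of swallowing them.

    Illustrative sketch of the catch-and-warn pattern from the review comment.
    """
    try:
        return set_num_threads_stub(n)
    except ValueError as e:
        logging.warning("Failed to set threads: %s", e)
        return None

print(configure_threads(4))   # 4
print(configure_threads(-1))  # None (warning logged to stderr)
```

With the bare `catch (...)` of the original code, a misconfigured deployment would run single-threaded with no diagnostic at all.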
Line range hint 167-187: Consider documenting message passing configuration.

The message passing implementation is comprehensive but could benefit from documentation explaining the expected format and requirements of the communication dictionary.
Spin model support is incomplete across multiple components.

Based on the search results, there are multiple areas where spin model support is either not implemented or disabled:

- Multiple Python modules raise `NotImplementedError` for spin support: `deepmd/tf/model/model.py`, `deepmd/pt/model/descriptor/` (se_a.py, dpa1.py, se_t.py), `deepmd/dpmodel/fitting/general_fitting.py`
- Several commented-out sections in `source/api_cc/src/DeepSpinPT.cc` marked as "spin model not supported yet"
- Partial implementations exist in the LAMMPS integration (`source/lmp/pair_deepspin.cpp`) with TODO comments for spin atomic forces
- Test files indicate planned spin model functionality that is currently disabled
🔗 Analysis chain
Line range hint 188-196: Verify spin model support status.

There are multiple commented-out sections marked as "spin model not supported yet". Consider adding a TODO or tracking issue to implement these features.
Would you like me to create a GitHub issue to track the implementation of these spin model features?
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Check for other spin-related TODOs or implementation gaps

# Search for spin-related TODO comments
rg -i "spin.*not.*support|spin.*todo"

# Search for commented-out spin-related code
rg "//.*spin"
```

Length of output: 3622
🚧 Files skipped from review as they are similar to previous changes (1)
- source/api_cc/src/DeepPotPT.cc
🔇 Additional comments (3)
source/api_cc/src/DeepSpinPT.cc (3)
Line range hint 13-29: LGTM: Comprehensive error handling implementation.
The error handling implementation properly catches and translates PyTorch-specific exceptions into domain-specific exceptions with clear error messages.
91-91: LGTM: Added evaluation mode for inference.
Setting the module to evaluation mode is crucial for inference tasks, ensuring proper behavior of layers like dropout and batch normalization.
Line range hint 449-492: LGTM: Well-structured template specializations and wrappers.
The template specializations and wrapper methods are well-implemented with proper error handling and type support.
Same as deepmodeling#4402

Signed-off-by: Jinzhe Zeng <[email protected]>
Same as #4402

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Enhanced model initialization by ensuring the model is set to evaluation mode immediately after loading, improving inference accuracy.
- **Bug Fixes**
  - Corrected the control flow during model setup to prevent potential issues during evaluation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>