feat(pt): support spin virial #4545
base: devel
📝 Walkthrough
This pull request introduces enhanced support for virial calculations in the DeepMD-kit framework, specifically focusing on spin-related models. The changes span multiple files across the project, adding new functionality to handle virial outputs, coordinate corrections, and spin-related computations. The modifications enable more comprehensive energy and virial loss calculations, with updates to model processing, loss computation, and testing frameworks.
Actionable comments posted: 4
🧹 Nitpick comments (12)
deepmd/pt/model/model/spin_model.py (3)
58-59: Enhance code readability by reshaping the tensor in a single step
In line 58, consider reshaping `self.virtual_scale_mask.to(atype.device)[atype]` directly without wrapping it in parentheses for better readability.
379-381: Avoid unnecessary computation by not calling `process_spin_input` during stat computation
In the `compute_or_load_stat` method, calling `process_spin_input` may introduce unnecessary computational overhead if `coord_corr` is not used. Consider modifying the code to exclude `coord_corr` when it's not needed.
591-594: Consider consistency in handling `do_grad_c` checks
In the `forward` method, ensure that the handling of `do_grad_c("energy")` and subsequent assignments align with the changes made in `translated_output_def`. This maintains consistency across the methods.
source/tests/pt/model/test_autodiff.py (4)
144-144: Initialize the `spin` variable only when necessary
The `spin` variable is initialized even when `test_spin` is `False`. Consider moving the initialization inside the conditional block to optimize performance.
Apply this diff to adjust the initialization:
- spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
Move the initialization to after line 150, within the `if test_spin` block.
148-148: Ensure `spin` is only converted to NumPy when necessary
Similar to the previous comment, the conversion of `spin` to a NumPy array should be conditional on `test_spin` to avoid unnecessary computations.
151-154: Simplify the conditional assignment of `test_keys`
The assignment of `test_keys` can be streamlined for clarity.
Apply this diff to simplify the code:
-    if not test_spin:
-        test_keys = ["energy", "force", "virial"]
-    else:
-        test_keys = ["energy", "force", "force_mag", "virial"]
+    test_keys = ["energy", "force", "virial"]
+    if test_spin:
+        test_keys.append("force_mag")
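The two forms in the diff above are not byte-for-byte equivalent: the appended version places "force_mag" last instead of third. The following minimal sketch compares them; the function names `keys_if_else` and `keys_append` are hypothetical and exist only for this illustration.

```python
def keys_if_else(test_spin):
    # Original form: two separate literal lists.
    if not test_spin:
        return ["energy", "force", "virial"]
    return ["energy", "force", "force_mag", "virial"]


def keys_append(test_spin):
    # Suggested form: one base list, extended for spin models.
    keys = ["energy", "force", "virial"]
    if test_spin:
        keys.append("force_mag")
    return keys


# Same members in both forms; only the position of "force_mag" differs.
same_members = set(keys_if_else(True)) == set(keys_append(True))
same_order = keys_if_else(True) == keys_append(True)
```

If the keys are only used to look up entries in a result dictionary, as in the `ret = {key: ... for key in test_keys}` line quoted later in this review, the ordering difference should be harmless.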
263-268: Add a newline for code style consistency
Include a blank line after the class definition to follow PEP 8 style guidelines for better readability.
Apply this diff:
 class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
+
     def setUp(self) -> None:
         model_params = copy.deepcopy(model_spin)
deepmd/pt/model/model/transform_output.py (3)
159-159: Update function documentation to include new parameter
The `fit_output_to_model_output` function has a new parameter `extended_coord_corr`. Update the docstring to describe this parameter and its role in the computation.
195-195: Avoid using `# noqa` comments for line length
Instead of using `# noqa: RUF005` to suppress line length warnings, refactor the code to comply with style guidelines for better maintainability.
Apply this diff to split the line:
- ).view(list(dc.shape[:-2]) + [1, 9])  # noqa: RUF005
+ )
+ dc = dc.view(list(dc.shape[:-2]) + [1, 9])
Line range hint 226-226: Consider adding type annotations for function returns
Adding type annotations to functions enhances code clarity and aids in static analysis. Consider specifying the return types for the functions in this module.
deepmd/pt/loss/ener_spin.py (1)
271-286: LGTM! The virial loss calculation is well implemented.
The implementation follows the established pattern for loss calculations, with proper scaling and optional MAE computation. The code is clean and well-structured.
Consider extracting the common pattern of loss calculation (L2, MAE, scaling) into a helper method to reduce code duplication across energy, force, and virial loss calculations.
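One possible shape for such a helper is sketched below in plain Python. The name `scaled_loss` and its parameters (`pref`, `atom_norm`, `with_mae`) are assumptions made for this illustration and are not taken from `ener_spin.py`; the real code operates on tensors rather than lists.

```python
import math


def scaled_loss(pred, label, pref, atom_norm=1.0, with_mae=False):
    """Return (scaled L2 loss, metrics) for one quantity (hypothetical helper)."""
    diffs = [p - l for p, l in zip(pred, label)]
    # Mean squared error drives the training loss.
    mse = sum(d * d for d in diffs) / len(diffs)
    metrics = {"rmse": math.sqrt(mse)}
    if with_mae:
        # Mean absolute error is reported only when requested.
        metrics["mae"] = sum(abs(d) for d in diffs) / len(diffs)
    return atom_norm * pref * mse, metrics


# Example call: one component matches, the other is off by 2.
loss, metrics = scaled_loss([1.0, 2.0], [1.0, 4.0], pref=2.0, with_mae=True)
```

Energy, force, and virial terms could then each call the same helper with their own prefactor and normalization, assuming they share this structure.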
source/tests/pt/model/test_ener_spin_model.py (1)
118-118: Document the purpose of the ignored return values.
The additional return values (marked with `_`) from `process_spin_input` and `process_spin_input_lower` are silently ignored. Consider adding a comment explaining what these values represent and why they can be safely ignored in these tests.
Also applies to: 177-177
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
deepmd/pt/loss/ener_spin.py (1 hunks)
deepmd/pt/model/model/make_model.py (7 hunks)
deepmd/pt/model/model/spin_model.py (11 hunks)
deepmd/pt/model/model/transform_output.py (2 hunks)
source/api_c/include/deepmd.hpp (2 hunks)
source/api_c/src/c_api.cc (1 hunks)
source/api_cc/src/DeepSpinPT.cc (4 hunks)
source/tests/pt/model/test_autodiff.py (3 hunks)
source/tests/pt/model/test_ener_spin_model.py (2 hunks)
source/tests/universal/common/cases/model/utils.py (3 hunks)
source/tests/universal/pt/model/test_model.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (21)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test C++ (false)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (13)
deepmd/pt/model/model/spin_model.py (4)
63-64: Ensure proper alignment of coordinate corrections
The concatenation of tensors in `coord_corr` must maintain correct alignment with the corresponding atoms. Verify that `torch.zeros_like(coord)` and `-spin_dist` are correctly ordered, ensuring that the coordinate corrections apply to the appropriate atoms.
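The ordering concern can be illustrated with a plain-Python sketch. It assumes that virtual atoms are appended after the real atoms and sit at `coord + spin_dist`; both helper names below are hypothetical, and nested lists stand in for the tensors used in `spin_model.py`.

```python
def build_coord_updated(coord, spin_dist):
    # Real atoms first, then one virtual atom per real atom (assumed layout).
    virtual = [[c + s for c, s in zip(cr, sr)] for cr, sr in zip(coord, spin_dist)]
    return coord + virtual


def build_coord_corr(coord, spin_dist):
    # Zeros for the real atoms, -spin_dist for the virtual atoms, in the same order.
    zeros = [[0.0, 0.0, 0.0] for _ in coord]
    negated = [[-s for s in row] for row in spin_dist]
    return zeros + negated


coord = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
spin_dist = [[0.0, 0.0, 0.5], [0.0, 0.2, 0.0]]
updated = build_coord_updated(coord, spin_dist)
corr = build_coord_corr(coord, spin_dist)

# With matching order, adding the correction maps every virtual atom back
# onto the real atom it belongs to.
shifted = [[u + c for u, c in zip(ur, cr)] for ur, cr in zip(updated, corr)]
```

If the two halves of either list were swapped, `shifted` would no longer repeat the real coordinates twice, which is the kind of misalignment the comment asks to rule out.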
92-95: Validate the consistency of virtual atom handling
When creating `extended_coord_corr`, confirm that the virtual atoms are correctly accounted for, and that the concatenation preserves the intended structure. This is crucial for accurate virial calculations involving spin corrections.
410-412: Handle the new output `coord_corr_for_virial` appropriately
Ensure that all downstream methods that receive `coord_corr_for_virial` can handle this new parameter without errors. Verify that `self.backbone_model.forward_common` accepts `coord_corr_for_virial` as an argument.
631-636: Verify accurate squeezing of tensors and assignment
In the `forward_lower` method, confirm that the `squeeze` operations correctly reduce tensor dimensions and that the results are assigned to the appropriate keys in `model_predict`.
source/tests/pt/model/test_autodiff.py (1)
150-150: Conditionally handle `spin` and `test_spin` variables
Verify that all usages of `spin` and `test_spin` within the `VirialTest` class are properly guarded by conditionals to prevent errors when `test_spin` is `False`.
deepmd/pt/model/model/transform_output.py (1)
191-196: Ensure tensor shapes are compatible during matrix multiplication
In the computation of `dc_corr`, validate that the shapes of the tensors involved in the matrix multiplication are compatible to prevent runtime errors.
deepmd/pt/model/model/make_model.py (2)
138-138: LGTM! The parameter addition is well documented.
The new `coord_corr_for_virial` parameter is properly documented with clear shape information.
Also applies to: 157-159
187-194: LGTM! The coordinate correction handling is robust.
The implementation properly handles both cases:
- When correction is provided: converts to correct dtype and gathers extended coordinates
- When correction is not provided: sets extended_coord_corr to None
source/api_cc/src/DeepSpinPT.cc (1)
254-254: LGTM! The virial output handling is properly implemented.
The implementation correctly:
- Retrieves virial tensor from outputs
- Converts tensor to appropriate float type
- Transfers data to CPU
- Assigns to output vector
Also applies to: 269-273, 417-417, 432-435
source/tests/universal/pt/model/test_model.py (1)
716-717: Note the temporary nature of this flag.
The added `test_spin_virial` flag enables spin virial testing for this backend. As indicated by the comment, this can be removed once other backends support spin virial calculations.
Let's verify which backends currently lack spin virial support:
✅ Verification successful
Temporary flag is correctly annotated
The comment accurately reflects the current state of the codebase. The spin virial implementation is currently only available in the PyTorch (PT) backend, and the flag is appropriately marked as temporary until other backends (TensorFlow, PaddlePaddle) implement this functionality.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for spin virial implementations across backends
rg -A 5 "spin.*virial" source/api_cc/src/
Length of output: 43
Script:
#!/bin/bash
# Broader search for spin virial across the codebase
echo "=== Searching for spin virial across the codebase ==="
rg -i "spin.*virial" --type py
echo -e "\n=== Searching for related test files ==="
fd "test.*model.*py$"
echo -e "\n=== Looking for backend-specific implementations ==="
rg -i "spin.*virial" --type py -A 5 source/tests/
Length of output: 5891
source/tests/universal/common/cases/model/utils.py (1)
895-898: Improved control flow for spin virial testing.
The condition ensures virial computation only happens when appropriate: either for non-spin models or when spin virial testing is explicitly enabled.
source/api_c/src/c_api.cc (1)
865-869: Re-enabled virial data propagation.
The previously commented-out code for copying virial data has been restored, allowing proper propagation of virial values in spin model deviation computations.
source/api_c/include/deepmd.hpp (1)
2605-2607: Re-enabled virial data propagation in C++ interface.
The previously commented-out code for copying virial data has been restored, maintaining consistency with the C API changes.
@@ -408,6 +419,7 @@ def forward_common(
     fparam=fparam,
     aparam=aparam,
     do_atomic_virial=do_atomic_virial,
+    coord_corr_for_virial=coord_corr_for_virial,
Confirm the compatibility of the new argument `coord_corr_for_virial`
Check if the backbone model's `forward_common` method is designed to accept `coord_corr_for_virial`. If not, update the backbone model accordingly or modify the call to prevent runtime errors.
@@ -469,6 +482,7 @@
     do_atomic_virial=do_atomic_virial,
     comm_dict=comm_dict,
     extra_nlist_sort=extra_nlist_sort,
+    extended_coord_corr=extended_coord_corr_for_virial,
Ensure `extended_coord_corr` is accepted by `forward_common_lower`
Similar to the previous comment, verify that `self.backbone_model.forward_common_lower` accepts `extended_coord_corr` as an argument. This prevents potential issues during model execution.
if self.do_grad_c("energy"):
    output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
    output_def["virial"].squeeze(-2)
    output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
    output_def["atom_virial"].squeeze(-3)
Assign the result of `squeeze` operations to reduce tensor dimensions
The `squeeze` methods in lines 560 and 562 do not modify tensors in place. Assign the results to ensure the dimensions are correctly reduced.
Apply this diff to fix the issue:
- output_def["virial"].squeeze(-2)
+ output_def["virial"] = output_def["virial"].squeeze(-2)
- output_def["atom_virial"].squeeze(-3)
+ output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
 if self.do_grad_c("energy"):
-    output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
-    output_def["virial"].squeeze(-2)
-    output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
-    output_def["atom_virial"].squeeze(-3)
+    output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
+    output_def["virial"] = output_def["virial"].squeeze(-2)
+    output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
+    output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3)
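Whether the reassignment in the suggestion above is correct depends on what kind of object `output_def["virial"]` holds, which the excerpt does not show. `torch.Tensor.squeeze` returns a new tensor and leaves the original unchanged, so for a tensor the assignment is needed. If the object is instead a definition whose `squeeze` mutates its shape in place and returns `None`, the assignment would overwrite the entry with `None`. The sketch below contrasts the two behaviors using a hypothetical `ShapeDef` class that is not part of DeepMD-kit.

```python
from copy import deepcopy


class ShapeDef:
    """Hypothetical stand-in for an object that only records a shape."""

    def __init__(self, shape):
        self.shape = list(shape)

    def squeeze_inplace(self, dim):
        # Mutates self.shape and returns None, in the style of list.sort().
        if -len(self.shape) <= dim < len(self.shape) and self.shape[dim] == 1:
            self.shape.pop(dim)

    def squeezed(self, dim):
        # Returns a modified copy and leaves self untouched,
        # in the style of torch.Tensor.squeeze.
        out = deepcopy(self)
        out.squeeze_inplace(dim)
        return out


a = ShapeDef([1, 9])
returned = a.squeeze_inplace(-2)   # a.shape changes; the return value is None

b = ShapeDef([1, 9])
c = b.squeezed(-2)                 # b.shape is unchanged; c carries the new shape
```

Checking the return type of the actual `squeeze` method before committing the suggestion would settle which pattern applies.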
    spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
)
# detach
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
Ensure compatibility of tensor devices
When creating tensors within the `np_infer` function, ensure that all tensors are on the same device to prevent device mismatch errors, especially when `env.DEVICE` differs from `"cpu"`.
Apply this diff to correct the device assignment:
- spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
+ spins=torch.tensor(spin, device=new_cell.device).unsqueeze(0),
Committable suggestion skipped: line range outside the PR's diff.