-
Notifications
You must be signed in to change notification settings - Fork 520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): neighbor stat #4258
feat(jax): neighbor stat #4258
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 WalkthroughWalkthroughThe pull request introduces several changes across multiple files, primarily enhancing the functionality of the Changes
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
🧹 Outside diff range and nitpick comments (4)
deepmd/jax/utils/auto_batch_size.py (1)
44-59
: Simplify the OOM error detection logic.The implementation correctly handles OOM errors, but the code can be simplified while maintaining readability.
Apply this diff to simplify the return statement:
def is_oom_error(self, e: Exception) -> bool: # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error, # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924 # (the meaningless error message should be considered as a bug in cusolver) - if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and ( - "RESOURCE_EXHAUSTED:" in e.args[0] - ): - return True - return False + return isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and ( + "RESOURCE_EXHAUSTED:" in e.args[0] + )🧰 Tools
🪛 Ruff
55-59: Return the condition directly
Inline condition
(SIM103)
deepmd/backend/jax.py (1)
Line range hint
26-120
: Consider enhancing error handling for unsupported features.The backend implementation is well-structured, but consider adding runtime checks or decorators to gracefully handle cases where features are enabled but not fully implemented. This would prevent silent failures and provide better error messages.
Example approach:
- Create a decorator that checks if a feature is both enabled and implemented
- Apply it to feature-specific properties
- Provide clear error messages indicating which features are not fully supported
Would you like me to provide an example implementation of such a decorator?
source/tests/consistent/test_neighbor_stat.py (2)
Line range hint
50-75
: LGTM! Well-structured test implementation with comprehensive test casesThe refactoring to
run_neighbor_stat
with a backend parameter improves code reuse. The test logic is thorough with proper validation of neighbor statistics across different cutoff radii and mixed type configurations.Consider adding a docstring to document the method's purpose and parameters:
def run_neighbor_stat(self, backend): + """Test neighbor statistics computation for different backends. + + Args: + backend (str): The backend to use ('tensorflow', 'pytorch', 'numpy', or 'jax') + """ for rcut in (0.0, 1.0, 2.0, 4.0):
76-89
: LGTM! Well-organized backend-specific test methodsThe backend-specific test methods are well-structured with proper conditional execution using
skipUnless
. Each backend is consistently handled.Consider grouping the backend tests together by moving the numpy test (
test_neighbor_stat_dp
) next to other backend tests for better organization:- @unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed") - def test_neighbor_stat_tf(self): - self.run_neighbor_stat("tensorflow") - - @unittest.skipUnless(INSTALLED_PT, "pytorch is not installed") - def test_neighbor_stat_pt(self): - self.run_neighbor_stat("pytorch") - - def test_neighbor_stat_dp(self): - self.run_neighbor_stat("numpy") - - @unittest.skipUnless(INSTALLED_JAX, "jax is not installed") - def test_neighbor_stat_jax(self): - self.run_neighbor_stat("jax") + def test_neighbor_stat_dp(self): + """Test neighbor statistics with numpy backend.""" + self.run_neighbor_stat("numpy") + + @unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed") + def test_neighbor_stat_tf(self): + """Test neighbor statistics with tensorflow backend.""" + self.run_neighbor_stat("tensorflow") + + @unittest.skipUnless(INSTALLED_PT, "pytorch is not installed") + def test_neighbor_stat_pt(self): + """Test neighbor statistics with pytorch backend.""" + self.run_neighbor_stat("pytorch") + + @unittest.skipUnless(INSTALLED_JAX, "jax is not installed") + def test_neighbor_stat_jax(self): + """Test neighbor statistics with jax backend.""" + self.run_neighbor_stat("jax")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (7)
- deepmd/backend/jax.py (2 hunks)
- deepmd/dpmodel/utils/neighbor_stat.py (2 hunks)
- deepmd/jax/utils/auto_batch_size.py (1 hunks)
- deepmd/jax/utils/neighbor_stat.py (1 hunks)
- source/tests/common/dpmodel/test_neighbor_stat.py (0 hunks)
- source/tests/consistent/test_neighbor_stat.py (4 hunks)
- source/tests/tf/test_neighbor_stat.py (0 hunks)
💤 Files with no reviewable changes (2)
- source/tests/common/dpmodel/test_neighbor_stat.py
- source/tests/tf/test_neighbor_stat.py
🧰 Additional context used
🪛 Ruff
deepmd/jax/utils/auto_batch_size.py
55-59: Return the condition directly
Inline condition
(SIM103)
🔇 Additional comments (8)
deepmd/jax/utils/auto_batch_size.py (2)
1-23
: LGTM! Well-structured class definition with proper documentation.The class is properly defined with clear inheritance and well-documented parameters.
24-33
: LGTM! Clean constructor implementation.The constructor properly initializes the parent class with sensible default values.
deepmd/backend/jax.py (3)
85-89
: LGTM! Clean implementation of neighbor_stat property.The implementation follows best practices with lazy imports to avoid circular dependencies.
Line range hint
92-120
: Verify serialization utilities existence.The implementation looks clean, but let's verify the existence of the required serialization utilities.
#!/bin/bash # Description: Verify the existence of serialization utilities # Check for the serialization module and its functions rg -l "serialize_from_file|deserialize_to_file" --type py # Look for any existing serialization tests fd -e py test.*serial
36-38
: Verify ENTRY_POINT feature implementation.The
ENTRY_POINT
feature is enabled, but its correspondingentry_point_hook
property raisesNotImplementedError
. This inconsistency could cause runtime issues.source/tests/consistent/test_neighbor_stat.py (1)
15-19
: LGTM! Clean import organizationThe new imports for backend availability flags are well-organized and follow Python conventions.
deepmd/jax/utils/neighbor_stat.py (1)
78-78
: Verify the Aggregation FunctionsWhen using
np.max
andnp.min
, ensure that theaxis
parameter is correctly set based on the dimensions ofmax_nnei
andminrr2
. Verify that these aggregations produce the expected scalar values.Run the following script to confirm the shapes and values:
✅ Verification successful
Looking at the code, I can now generate a final response since I have enough context to verify the aggregation functions.
Aggregation Functions Are Correctly Used
The
axis=0
parameter innp.max(max_nnei, axis=0)
is correct because:
max_nnei
is returned fromNeighborStatOP
and converted to numpy array- The aggregation is performed over batches of data, where
axis=0
correctly reduces across the batch dimensionminrr2
is correctly reduced usingnp.min
without an axis parameter since we want the global minimum distance🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Description: Check the shapes of 'max_nnei' and 'minrr2' returned from '_execute' method. # Test: Print the shapes. Expect: Correct dimensions for aggregation. ast-grep --pattern $'def _execute($_) { $$$ }' | while read -r func_def; do echo "Analyzing function: $func_def" # Assuming the script can extract and print variable shapes echo "Shape of max_nnei:" # Placeholder command to print shape echo "Shape of minrr2:" # Placeholder command to print shape doneLength of output: 66
Script:
#!/bin/bash # Let's first check the content of the file to understand the context better cat deepmd/jax/utils/neighbor_stat.pyLength of output: 2792
deepmd/dpmodel/utils/neighbor_stat.py (1)
72-72
: Usage ofarray_namespace
is appropriateThe use of
array_api_compat.array_namespace
to obtain a consistent array API is correct and enhances compatibility across different array backends.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
deepmd/jax/utils/neighbor_stat.py (2)
86-96
: Add 'Returns' Section to DocstringThe docstring of the
_execute
method lacks a 'Returns' section describing the outputs of the function. Including this information improves code readability and maintainability.Apply this diff to update the docstring:
""" Execute the operation. Parameters ---------- coord The coordinates of atoms. atype The atom types. cell The cell. + Returns + ------- + tuple[np.ndarray, np.ndarray] + minrr2 : np.ndarray + The minimal squared distances. + max_nnei : np.ndarray + The maximum number of neighbors. """
80-85
: Add Return Type Hint to_execute
MethodFor consistency and to enhance code readability, consider adding a return type hint to the
_execute
method.Apply this diff:
def _execute( self, coord: np.ndarray, atype: np.ndarray, cell: Optional[np.ndarray], - ): + ) -> tuple[np.ndarray, np.ndarray]: """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- deepmd/jax/utils/neighbor_stat.py (1 hunks)
🧰 Additional context used
📓 Learnings (1)
deepmd/jax/utils/neighbor_stat.py (1)
Learnt from: njzjz PR: deepmodeling/deepmd-kit#4258 File: deepmd/jax/utils/neighbor_stat.py:98-101 Timestamp: 2024-10-26T02:09:01.365Z Learning: The function `to_jax_array` in `deepmd/jax/common.py` can handle `None` values, so it's safe to pass `None` to it without additional checks.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4258 +/- ##
==========================================
- Coverage 84.23% 84.22% -0.01%
==========================================
Files 548 550 +2
Lines 51425 51469 +44
Branches 3051 3051
==========================================
+ Hits 43317 43352 +35
- Misses 7150 7154 +4
- Partials 958 963 +5 ☔ View full report in Codecov by Sentry. |
Summary by CodeRabbit
Release Notes
New Features
NeighborStat
andNeighborStatOP
classes for enhanced neighbor statistics computation.AutoBatchSize
class to manage automatic batch sizing in deep learning applications.Improvements
JAXBackend
functionality with implemented properties for neighbor statistics and serialization.Tests
neighbor_stat
to support multiple backends (TensorFlow, PyTorch, NumPy, JAX).