Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): neighbor stat #4258

Merged
merged 2 commits into from
Oct 29, 2024
Merged

feat(jax): neighbor stat #4258

merged 2 commits into from
Oct 29, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Oct 26, 2024

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced NeighborStat and NeighborStatOP classes for enhanced neighbor statistics computation.
    • Added AutoBatchSize class to manage automatic batch sizing in deep learning applications.
  • Improvements

    • Enhanced JAXBackend functionality with implemented properties for neighbor statistics and serialization.
    • Refactored neighbor counting logic for better clarity and modularity.
  • Tests

    • Updated unit tests for neighbor_stat to support multiple backends (TensorFlow, PyTorch, NumPy, JAX).
    • Removed outdated test files to streamline testing processes.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

coderabbitai bot commented Oct 26, 2024

📝 Walkthrough

Walkthrough

The pull request introduces several changes across multiple files, primarily enhancing the functionality of the JAXBackend class and related components. Key updates include the implementation of previously unimplemented properties like neighbor_stat, serialize_hook, and deserialize_hook. Additionally, a new AutoBatchSize class is introduced for managing batch sizes in deep learning applications. The NeighborStat class is also added to compute neighbor statistics, while existing tests for the neighbor_stat function are modified or removed to reflect these changes.

Changes

File Change Summary
deepmd/backend/jax.py Modified JAXBackend class to update features attribute using bitwise OR operations; implemented neighbor_stat, serialize_hook, and deserialize_hook.
deepmd/dpmodel/utils/neighbor_stat.py Enhanced NeighborStatOP class to integrate array_api_compat for flexible array handling; refactored call method for improved clarity and modularity.
deepmd/jax/utils/auto_batch_size.py Introduced AutoBatchSize class with methods for checking GPU availability and handling out-of-memory errors.
deepmd/jax/utils/neighbor_stat.py Added NeighborStat class to compute neighbor statistics with methods for iteration and execution of core operations.
source/tests/common/dpmodel/test_neighbor_stat.py Deleted file containing unit tests for neighbor_stat.
source/tests/consistent/test_neighbor_stat.py Modified test structure, renamed test_neighbor_stat to run_neighbor_stat, added backend parameterization and new test methods for various backends.
source/tests/tf/test_neighbor_stat.py Deleted file containing unit tests for neighbor_stat in TensorFlow.

Possibly related PRs

  • feat(jax): support neural networks #4156: The changes in the main PR regarding the neighbor_stat property and its implementation in the JAXBackend class are related to the modifications in the NeighborStat class introduced in the retrieved PR, which also focuses on neighbor statistics and integrates with JAX functionalities.

Suggested reviewers

  • wanghan-iapcm

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Outside diff range and nitpick comments (4)
deepmd/jax/utils/auto_batch_size.py (1)

44-59: Simplify the OOM error detection logic.

The implementation correctly handles OOM errors, but the code can be simplified while maintaining readability.

Apply this diff to simplify the return statement:

    def is_oom_error(self, e: Exception) -> bool:
        # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
        # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
        # (the meaningless error message should be considered as a bug in cusolver)
-        if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
-            "RESOURCE_EXHAUSTED:" in e.args[0]
-        ):
-            return True
-        return False
+        return isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
+            "RESOURCE_EXHAUSTED:" in e.args[0]
+        )
🧰 Tools
🪛 Ruff

55-59: Return the condition directly

Inline condition

(SIM103)

deepmd/backend/jax.py (1)

Line range hint 26-120: Consider enhancing error handling for unsupported features.

The backend implementation is well-structured, but consider adding runtime checks or decorators to gracefully handle cases where features are enabled but not fully implemented. This would prevent silent failures and provide better error messages.

Example approach:

  1. Create a decorator that checks if a feature is both enabled and implemented
  2. Apply it to feature-specific properties
  3. Provide clear error messages indicating which features are not fully supported

Would you like me to provide an example implementation of such a decorator?

source/tests/consistent/test_neighbor_stat.py (2)

Line range hint 50-75: LGTM! Well-structured test implementation with comprehensive test cases

The refactoring to run_neighbor_stat with a backend parameter improves code reuse. The test logic is thorough with proper validation of neighbor statistics across different cutoff radii and mixed type configurations.

Consider adding a docstring to document the method's purpose and parameters:

 def run_neighbor_stat(self, backend):
+    """Test neighbor statistics computation for different backends.
+    
+    Args:
+        backend (str): The backend to use ('tensorflow', 'pytorch', 'numpy', or 'jax')
+    """
     for rcut in (0.0, 1.0, 2.0, 4.0):

76-89: LGTM! Well-organized backend-specific test methods

The backend-specific test methods are well-structured with proper conditional execution using skipUnless. Each backend is consistently handled.

Consider grouping the backend tests together by moving the numpy test (test_neighbor_stat_dp) next to other backend tests for better organization:

-    @unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed")
-    def test_neighbor_stat_tf(self):
-        self.run_neighbor_stat("tensorflow")
-
-    @unittest.skipUnless(INSTALLED_PT, "pytorch is not installed")
-    def test_neighbor_stat_pt(self):
-        self.run_neighbor_stat("pytorch")
-
-    def test_neighbor_stat_dp(self):
-        self.run_neighbor_stat("numpy")
-
-    @unittest.skipUnless(INSTALLED_JAX, "jax is not installed")
-    def test_neighbor_stat_jax(self):
-        self.run_neighbor_stat("jax")
+    def test_neighbor_stat_dp(self):
+        """Test neighbor statistics with numpy backend."""
+        self.run_neighbor_stat("numpy")
+
+    @unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed")
+    def test_neighbor_stat_tf(self):
+        """Test neighbor statistics with tensorflow backend."""
+        self.run_neighbor_stat("tensorflow")
+
+    @unittest.skipUnless(INSTALLED_PT, "pytorch is not installed")
+    def test_neighbor_stat_pt(self):
+        """Test neighbor statistics with pytorch backend."""
+        self.run_neighbor_stat("pytorch")
+
+    @unittest.skipUnless(INSTALLED_JAX, "jax is not installed")
+    def test_neighbor_stat_jax(self):
+        """Test neighbor statistics with jax backend."""
+        self.run_neighbor_stat("jax")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between a66afd3 and dd73acb.

📒 Files selected for processing (7)
  • deepmd/backend/jax.py (2 hunks)
  • deepmd/dpmodel/utils/neighbor_stat.py (2 hunks)
  • deepmd/jax/utils/auto_batch_size.py (1 hunks)
  • deepmd/jax/utils/neighbor_stat.py (1 hunks)
  • source/tests/common/dpmodel/test_neighbor_stat.py (0 hunks)
  • source/tests/consistent/test_neighbor_stat.py (4 hunks)
  • source/tests/tf/test_neighbor_stat.py (0 hunks)
💤 Files with no reviewable changes (2)
  • source/tests/common/dpmodel/test_neighbor_stat.py
  • source/tests/tf/test_neighbor_stat.py
🧰 Additional context used
🪛 Ruff
deepmd/jax/utils/auto_batch_size.py

55-59: Return the condition directly

Inline condition

(SIM103)

🔇 Additional comments (8)
deepmd/jax/utils/auto_batch_size.py (2)

1-23: LGTM! Well-structured class definition with proper documentation.

The class is properly defined with clear inheritance and well-documented parameters.


24-33: LGTM! Clean constructor implementation.

The constructor properly initializes the parent class with sensible default values.

deepmd/backend/jax.py (3)

85-89: LGTM! Clean implementation of neighbor_stat property.

The implementation follows best practices with lazy imports to avoid circular dependencies.


Line range hint 92-120: Verify serialization utilities existence.

The implementation looks clean, but let's verify the existence of the required serialization utilities.

#!/bin/bash
# Description: Verify the existence of serialization utilities

# Check for the serialization module and its functions
rg -l "serialize_from_file|deserialize_to_file" --type py

# Look for any existing serialization tests
fd -e py test.*serial

36-38: Verify ENTRY_POINT feature implementation.

The ENTRY_POINT feature is enabled, but its corresponding entry_point_hook property raises NotImplementedError. This inconsistency could cause runtime issues.

source/tests/consistent/test_neighbor_stat.py (1)

15-19: LGTM! Clean import organization

The new imports for backend availability flags are well-organized and follow Python conventions.

deepmd/jax/utils/neighbor_stat.py (1)

78-78: Verify the Aggregation Functions

When using np.max and np.min, ensure that the axis parameter is correctly set based on the dimensions of max_nnei and minrr2. Verify that these aggregations produce the expected scalar values.

Run the following script to confirm the shapes and values:

✅ Verification successful

Looking at the code, I can now generate a final response since I have enough context to verify the aggregation functions.

Aggregation Functions Are Correctly Used

The axis=0 parameter in np.max(max_nnei, axis=0) is correct because:

  • max_nnei is returned from NeighborStatOP and converted to numpy array
  • The aggregation is performed over batches of data, where axis=0 correctly reduces across the batch dimension
  • minrr2 is correctly reduced using np.min without an axis parameter since we want the global minimum distance
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check the shapes of 'max_nnei' and 'minrr2' returned from '_execute' method.

# Test: Print the shapes. Expect: Correct dimensions for aggregation.
ast-grep --pattern $'def _execute($_) { $$$ }' | while read -r func_def; do
    echo "Analyzing function: $func_def"
    # Assuming the script can extract and print variable shapes
    echo "Shape of max_nnei:"
    # Placeholder command to print shape
    echo "Shape of minrr2:"
    # Placeholder command to print shape
done

Length of output: 66


Script:

#!/bin/bash
# Let's first check the content of the file to understand the context better
cat deepmd/jax/utils/neighbor_stat.py

Length of output: 2792

deepmd/dpmodel/utils/neighbor_stat.py (1)

72-72: Usage of array_namespace is appropriate

The use of array_api_compat.array_namespace to obtain a consistent array API is correct and enhances compatibility across different array backends.

deepmd/jax/utils/auto_batch_size.py Show resolved Hide resolved
deepmd/jax/utils/neighbor_stat.py Show resolved Hide resolved
deepmd/jax/utils/neighbor_stat.py Outdated Show resolved Hide resolved
deepmd/jax/utils/neighbor_stat.py Show resolved Hide resolved
deepmd/jax/utils/neighbor_stat.py Outdated Show resolved Hide resolved
deepmd/dpmodel/utils/neighbor_stat.py Show resolved Hide resolved
deepmd/dpmodel/utils/neighbor_stat.py Show resolved Hide resolved
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (2)
deepmd/jax/utils/neighbor_stat.py (2)

86-96: Add 'Returns' Section to Docstring

The docstring of the _execute method lacks a 'Returns' section describing the outputs of the function. Including this information improves code readability and maintainability.

Apply this diff to update the docstring:

         """
         Execute the operation.

         Parameters
         ----------
         coord
             The coordinates of atoms.
         atype
             The atom types.
         cell
             The cell.
+        Returns
+        -------
+        tuple[np.ndarray, np.ndarray]
+            minrr2 : np.ndarray
+                The minimal squared distances.
+            max_nnei : np.ndarray
+                The maximum number of neighbors.
         """

80-85: Add Return Type Hint to _execute Method

For consistency and to enhance code readability, consider adding a return type hint to the _execute method.

Apply this diff:

     def _execute(
         self,
         coord: np.ndarray,
         atype: np.ndarray,
         cell: Optional[np.ndarray],
-    ):
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between dd73acb and e2f6c4a.

📒 Files selected for processing (1)
  • deepmd/jax/utils/neighbor_stat.py (1 hunks)
🧰 Additional context used
📓 Learnings (1)
deepmd/jax/utils/neighbor_stat.py (1)
Learnt from: njzjz
PR: deepmodeling/deepmd-kit#4258
File: deepmd/jax/utils/neighbor_stat.py:98-101
Timestamp: 2024-10-26T02:09:01.365Z
Learning: The function `to_jax_array` in `deepmd/jax/common.py` can handle `None` values, so it's safe to pass `None` to it without additional checks.

deepmd/jax/utils/neighbor_stat.py Show resolved Hide resolved
Copy link

codecov bot commented Oct 26, 2024

Codecov Report

Attention: Patch coverage is 94.64286% with 3 lines in your changes missing coverage. Please review.

Project coverage is 84.22%. Comparing base (a66afd3) to head (e2f6c4a).
Report is 14 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/jax/utils/auto_batch_size.py 75.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4258      +/-   ##
==========================================
- Coverage   84.23%   84.22%   -0.01%     
==========================================
  Files         548      550       +2     
  Lines       51425    51469      +44     
  Branches     3051     3051              
==========================================
+ Hits        43317    43352      +35     
- Misses       7150     7154       +4     
- Partials      958      963       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@njzjz njzjz added this pull request to the merge queue Oct 29, 2024
Merged via the queue into deepmodeling:devel with commit 82aaa0d Oct 29, 2024
60 checks passed
@njzjz njzjz deleted the jax-neighbor-stat branch October 29, 2024 20:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants