Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): support neural networks #4156

Merged
merged 3 commits into from
Sep 23, 2024
Merged

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Sep 23, 2024

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced JAX support, enhancing functionality and compatibility with JAX library.
    • Added new JAXBackend class for backend integration with JAX.
    • New functions for converting between NumPy and JAX arrays.
  • Bug Fixes

    • Improved compatibility of neural network layers with array API standards.
  • Tests

    • Added tests for JAX functionality and consistency checks against reference outputs.
    • Enhanced testing framework for activation functions and type embeddings.
  • Chores

    • Updated dependency requirements to include JAX library.

Start a JAX support.

Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

coderabbitai bot commented Sep 23, 2024

Walkthrough

Walkthrough

This pull request introduces significant updates to the DeepMD library by integrating the JAX numerical computing library into various components. Changes include modifications to workflow files to install JAX, the addition of new backend support with a dedicated JAX implementation, and enhancements to neural network functionalities to accommodate JAX arrays. New utility functions for converting between NumPy and JAX arrays are also introduced, alongside tests to ensure compatibility with JAX.

Changes

File(s) Change Summary
.github/workflows/test_cuda.yml, .github/workflows/test_python.yml Updated installation commands to include the jax package in both workflow files.
deepmd/backend/jax.py Introduced JAXBackend class for JAX support, including methods to check availability.
deepmd/dpmodel/common.py Added to_numpy_array function for converting inputs to NumPy arrays.
deepmd/dpmodel/utils/network.py, deepmd/dpmodel/utils/type_embed.py Enhanced compatibility with array API standards, including updates to matrix operations and activation functions.
deepmd/jax/common.py, deepmd/jax/env.py, deepmd/jax/utils/network.py, deepmd/jax/utils/type_embed.py Introduced new files and classes for JAX integration, including utility functions and network constructs.
pyproject.toml Added dependency for jax with version constraints for Python 3.10 and above.
source/tests/common/dpmodel/array_api/test_activation_functions.py, source/tests/consistent/common.py, source/tests/consistent/test_activation.py, source/tests/consistent/test_type_embedding.py Added tests and modified existing tests to support JAX functionality and ensure consistency.

Possibly related PRs

Suggested labels

Python, Examples


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

Share
Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    -- I pushed a fix in commit <commit_id>, please review it.
    -- Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    -- @coderabbitai generate unit testing code for this file.
    -- @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    -- @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    -- @coderabbitai read src/utils.ts and generate unit testing code.
    -- @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    -- @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.



@overload
def to_jax_array(array: np.ndarray) -> jnp.ndarray: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def to_jax_array(array: None) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Outside diff range and nitpick comments (3)
deepmd/backend/jax.py (1)

66-110: Implement the pending backend methods

The methods entry_point_hook, deep_eval, neighbor_stat, serialize_hook, and deserialize_hook currently raise NotImplementedError. If these methods are essential for the JAX backend functionality, consider implementing them to ensure complete integration. If they are slated for future development, adding comments to indicate their planned implementation could enhance maintainability.

Do you want me to help implement these methods or open GitHub issues to track their development?

source/tests/consistent/test_type_embedding.py (1)

116-123: Enhance the exception message in eval_jax method

The exception raised when the output is a NumPy array could provide more context to aid in debugging.

Consider updating the error message to include the unexpected type:

-                raise ValueError("Output is numpy array")
+                raise ValueError(f"Unexpected output type {type(x)}; expected non-NumPy array")
source/tests/consistent/common.py (1)

86-88: Improve clarity and consistency of the JAX skipping comment

The inline comment on line 86 could be rephrased for clarity and to match the commenting style used elsewhere.

Consider revising the comment for better readability:

-# we may usually skip jax before jax is fully supported
+"""Whether to skip the JAX model.

We typically skip JAX tests until JAX support is fully integrated.
"""
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 6010c73 and a89f1eb.

Files selected for processing (17)
  • .github/workflows/test_cuda.yml (1 hunks)
  • .github/workflows/test_python.yml (1 hunks)
  • deepmd/backend/jax.py (1 hunks)
  • deepmd/dpmodel/common.py (2 hunks)
  • deepmd/dpmodel/utils/network.py (4 hunks)
  • deepmd/dpmodel/utils/type_embed.py (2 hunks)
  • deepmd/jax/init.py (1 hunks)
  • deepmd/jax/common.py (1 hunks)
  • deepmd/jax/env.py (1 hunks)
  • deepmd/jax/utils/init.py (1 hunks)
  • deepmd/jax/utils/network.py (1 hunks)
  • deepmd/jax/utils/type_embed.py (1 hunks)
  • pyproject.toml (1 hunks)
  • source/tests/common/dpmodel/array_api/test_activation_functions.py (1 hunks)
  • source/tests/consistent/common.py (8 hunks)
  • source/tests/consistent/test_activation.py (4 hunks)
  • source/tests/consistent/test_type_embedding.py (4 hunks)
Files skipped from review due to trivial changes (3)
  • deepmd/jax/init.py
  • deepmd/jax/utils/init.py
  • source/tests/common/dpmodel/array_api/test_activation_functions.py
Additional context used
GitHub Check: CodeQL
deepmd/jax/common.py

[notice] 15-15: Statement has no effect
This statement has no effect.


[notice] 19-19: Statement has no effect
This statement has no effect.

Additional comments not posted (46)
deepmd/jax/env.py (4)

1-1: LGTM!

Using an SPDX license identifier is a good practice for clearly specifying the license of the file.


4-4: LGTM!

Setting XLA_PYTHON_CLIENT_PREALLOCATE to "false" is a good optimization to avoid unnecessary memory pre-allocation in JAX.


9-9: Verify the need for 64-bit precision and consider the performance impact.

Enabling 64-bit floating-point precision in JAX can improve numerical accuracy but may impact performance. Please ensure that the decision to enable 64-bit precision aligns with the project's requirements and consider the potential performance trade-offs.


11-14: LGTM!

Using __all__ to explicitly specify the public API of the module is a good practice. It clearly communicates the intended exports and helps maintain a clean and well-defined interface.

deepmd/jax/utils/type_embed.py (1)

15-21: LGTM!

The TypeEmbedNet class correctly extends TypeEmbedNetDP and overrides the __setattr__ method to handle JAX-specific attribute assignments. The changes are in line with the PR objective of integrating JAX into the library.

The class is importing the necessary utilities and there are no apparent issues with the logic or syntax. The changes seem to be localized to this class and do not appear to have any negative impact on the rest of the codebase.

deepmd/jax/common.py (1)

14-37: LGTM!

The to_jax_array function is well-implemented and documented. It correctly handles the conversion of NumPy arrays to JAX arrays, including the case when the input is None.

Regarding the static analysis hints:

The overload statements at lines 15 and 19 are used for type hinting and do not have any runtime effect. They are not actual statements, so it's safe to ignore these hints.

Tools
GitHub Check: CodeQL

[notice] 15-15: Statement has no effect
This statement has no effect.


[notice] 19-19: Statement has no effect
This statement has no effect.

deepmd/dpmodel/common.py (1)

66-81: LGTM!

The to_numpy_array function looks good:

  • It correctly handles the case when the input is None.
  • It uses np.asarray which is the recommended way to convert an array-like object to a NumPy array.
  • The type hints are correct and improve the readability of the function.
  • The function is well-documented with a docstring that follows the NumPy docstring format.

This utility function can be used in other parts of the codebase to ensure consistent conversion of arrays to NumPy arrays.

source/tests/consistent/test_activation.py (5)

2-2: LGTM!

The import changes look good:

  • sys module import is valid.
  • INSTALLED_JAX import is valid and likely used for conditional JAX-related tests.

Also applies to: 16-16


33-36: LGTM!

The JAX-related import changes look good:

  • The conditional import based on INSTALLED_JAX is a good practice.
  • Importing jnp from deepmd.jax.env is valid for JAX-related operations.

67-78: LGTM!

The new test method test_arary_api_strict looks good:

  • The Python version check and conditional skip is a good practice.
  • Using array_api_strict helps ensure compatibility with the array API standard.
  • Setting the array_api_strict flags based on the array_api_version of the tested function is correct.
  • The test logic of applying the activation function and comparing with the reference is valid.

80-85: LGTM!

The new test method test_jax_consistent_with_ref looks good:

  • The conditional skip based on INSTALLED_JAX is a good practice for JAX-specific tests.
  • Converting the input from NumPy to JAX using jnp.from_dlpack is correct.
  • Applying the activation function to the JAX array is the main test logic and is valid.
  • Asserting that the result is a JAX array is a good additional check.
  • Converting the result back to NumPy for comparison with the reference is correct.

Line range hint 1-85: Great work on the test file!

The file is well-structured and organized, making it easy to understand and maintain. Some notable good practices:

  • Parameterization of tests over valid activation functions ensures thorough testing.
  • Conditional skipping of tests based on library availability ensures compatibility.
  • Test methods cover important aspects of consistency across different frameworks and APIs.

Keep up the good work!

.github/workflows/test_cuda.yml (1)

54-54: Looks good! Verify JAX integration across the codebase.

The addition of jax to the list of extras aligns with the goal of integrating JAX into the library. Please ensure that the rest of the codebase is updated to leverage JAX effectively and maintain compatibility.

To verify the JAX integration, consider running the following script:

.github/workflows/test_python.yml (1)

31-31: LGTM!

The addition of the jax extra in the pip install command is consistent with the PR objectives of supporting JAX in the DeepMD library. This change is necessary to install the required dependencies for JAX support.

pyproject.toml (1)

135-137: LGTM!

The new dependency declaration for the jax library under the cu12 section looks good. The version constraint ensures a minimum compatible version is used for Python 3.10 and above.

deepmd/jax/utils/network.py (1)

1-29: LGTM!

The integration of JAX into the neural network utilities is well-executed. The NativeLayer class effectively ensures that the attributes w, b, and idt are converted to JAX arrays upon assignment, maintaining consistency within the JAX computational framework. The creation of NativeNet, EmbeddingNet, and FittingNet using the provided factory functions is correctly structured and should facilitate the building of complex neural network models with JAX support.

source/tests/consistent/test_type_embedding.py (4)

16-16: Added INSTALLED_JAX to imports

Including INSTALLED_JAX in the imports allows for conditional functionality based on JAX installation.


34-40: Implemented conditional imports for JAX support

The addition of conditional imports for JAX ensures that JAX-specific code is only executed when JAX is installed, maintaining compatibility.


74-74: Added jax_class attribute to support JAX backend

The inclusion of jax_class = TypeEmbedNetJAX maintains consistency with other backend classes and facilitates testing with the JAX backend.


76-76: Introduced skip_jax flag to conditionally skip JAX tests

The skip_jax flag correctly determines whether to skip JAX-related tests based on the installation status of JAX.

deepmd/dpmodel/utils/type_embed.py (5)

8-13: Imports added for array API compatibility

The necessary imports array_api_compat and support_array_api are correctly added to support array API compatibility.


99-99: Decorator applied for array API support

The @support_array_api(version="2022.12") decorator is appropriately applied to the call method to ensure compatibility with the specified array API version.


102-103: Retrieving array namespace for flexible array handling

The code correctly retrieves the array namespace using array_api_compat.array_namespace(sample_array), allowing for consistent handling of arrays across different backends.


105-105: Using array API's eye function

Replacing np.eye with xp.eye ensures compatibility with different array backends in accordance with the array API standards.


109-110: Updating padding logic for array API compatibility

The padding operation is appropriately updated to use xp.zeros and xp.concatenate, ensuring consistency with the array API approach.

source/tests/consistent/common.py (7)

38-38: JAX backend availability check added correctly

The INSTALLED_JAX variable is successfully introduced to check the availability of the JAX backend.


61-61: Exported INSTALLED_JAX in __all__

Adding INSTALLED_JAX to the __all__ list ensures it is properly exported and accessible when the module is imported elsewhere.


76-77: Added jax_class variable for JAX model class

Introducing jax_class aligns with the existing structure for other backends like tf_class and pt_class, promoting consistency.


175-175: Added JAX entry to RefBackend enum

Including JAX in the RefBackend enum extends the reference backend options appropriately.


252-253: Verify JAX as a reference backend is fully supported

When skip_jax is False, and JAX is selected as the reference backend, ensure that all necessary methods are implemented to prevent test failures.

Run the following script to check the implementation status of required methods for JAX as a reference backend:

#!/bin/bash
# Description: Check for implementations of methods required when using JAX as the reference backend.

# List methods that require implementation for JAX
rg --type py 'def.*(self,.*):' -A 2 | rg 'eval_jax|get_jax_ret_serialization_from_cls|extract_ret'

# Confirm that these methods do not raise `NotImplementedError`

236-240: Ensure get_jax_ret_serialization_from_cls handles unimplemented eval_jax

Since eval_jax raises NotImplementedError, calling get_jax_ret_serialization_from_cls may result in unexpected errors. Verify that calls to this method are appropriately handled or deferred until eval_jax is implemented.

Run the following script to locate calls to get_jax_ret_serialization_from_cls and check for exception handling:

#!/bin/bash
# Description: Find all calls to `get_jax_ret_serialization_from_cls` and assess error handling.

# Search for calls to `get_jax_ret_serialization_from_cls`
rg --type py '\.get_jax_ret_serialization_from_cls\(' -A 5

# Ensure that calls are safeguarded against `NotImplementedError`

159-168: Implement eval_jax method or handle NotImplementedError appropriately

The eval_jax method currently raises NotImplementedError, which may cause tests to fail if skip_jax is set to False. Ensure that this is intentional and that tests handle this exception properly.

Run the following script to identify where eval_jax is called and verify exception handling:

deepmd/dpmodel/utils/network.py (14)

18-18: Import array_api_compat for array API support

The addition of import array_api_compat is necessary for array API compatibility and is correctly implemented.


26-31: Add required imports for array API support

Imports for support_array_api and to_numpy_array are added from deepmd.dpmodel.array_api and deepmd.dpmodel.common, respectively. This ensures that array API support and proper array conversion are available.


115-117: Ensure consistent serialization with to_numpy_array

Using to_numpy_array for self.w, self.b, and self.idt before serialization ensures that the data is consistently converted to NumPy arrays, enhancing cross-backend compatibility.


225-225: Decorate call method with @support_array_api

Applying @support_array_api(version="2022.12") to the call method enables array API support, allowing the method to be compatible with different array backends following the specified version.


241-241: Retrieve array namespace for backend compatibility

Using xp = array_api_compat.array_namespace(x) correctly retrieves the array namespace corresponding to the input x, ensuring that subsequent operations use the appropriate backend functions.


244-246: Use array API matrix multiplication and addition

Replacing np.matmul with xp.matmul and handling biases with array API functions ensures that matrix operations are compatible across different backends.


254-254: Implement array API concatenation

Using xp.concatenate ensures that concatenation operations adhere to the array API standards, maintaining compatibility.


258-258: Decorate get_activation_fn with @support_array_api

Applying @support_array_api(version="2022.12") to get_activation_fn ensures that the activation functions returned are compatible with the array API standard.


263-266: Update tanh activation function for array API

The tanh activation function now uses xp.tanh(x), ensuring compatibility with different array backends via the array API.


271-273: Update relu activation function for array API

The relu activation function uses xp operations, enhancing compatibility with various backends through the array API.


279-285: Update gelu activation function for array API

The gelu activation function is updated to use xp operations, ensuring array API compliance.


291-295: Update relu6 activation function for array API

The relu6 activation function now uses xp.where and xp.full_like for array API compatibility.


301-303: Update softplus activation function for array API

The softplus activation function uses xp.log and xp.exp, ensuring compatibility with the array API standard.


309-311: Update sigmoid activation function for array API

The sigmoid activation function now uses xp.exp to ensure array API compliance.

deepmd/backend/jax.py Show resolved Hide resolved
source/tests/consistent/common.py Show resolved Hide resolved
source/tests/consistent/common.py Show resolved Hide resolved
Copy link

codecov bot commented Sep 23, 2024

Codecov Report

Attention: Patch coverage is 92.52336% with 8 lines in your changes missing coverage. Please review.

Project coverage is 83.42%. Comparing base (6010c73) to head (a89f1eb).
Report is 2 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/backend/jax.py 75.00% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4156      +/-   ##
==========================================
+ Coverage   83.41%   83.42%   +0.01%     
==========================================
  Files         532      537       +5     
  Lines       52048    52146      +98     
  Branches     3046     3046              
==========================================
+ Hits        43416    43504      +88     
- Misses       7684     7692       +8     
- Partials      948      950       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants