-
Notifications
You must be signed in to change notification settings - Fork 520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): support neural networks #4156
Conversation
Start a JAX support. Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
WalkthroughWalkthroughThis pull request introduces significant updates to the DeepMD library by integrating the JAX numerical computing library into various components. Changes include modifications to workflow files to install JAX, the addition of new backend support with a dedicated JAX implementation, and enhancements to neural network functionalities to accommodate JAX arrays. New utility functions for converting between NumPy and JAX arrays are also introduced, alongside tests to ensure compatibility with JAX. Changes
Possibly related PRs
Suggested labels
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
|
||
|
||
@overload | ||
def to_jax_array(array: np.ndarray) -> jnp.ndarray: ... |
Check notice
Code scanning / CodeQL
Statement has no effect Note
|
||
|
||
@overload | ||
def to_jax_array(array: None) -> None: ... |
Check notice
Code scanning / CodeQL
Statement has no effect Note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
Outside diff range and nitpick comments (3)
deepmd/backend/jax.py (1)
66-110
: Implement the pending backend methodsThe methods
entry_point_hook
,deep_eval
,neighbor_stat
,serialize_hook
, anddeserialize_hook
currently raiseNotImplementedError
. If these methods are essential for the JAX backend functionality, consider implementing them to ensure complete integration. If they are slated for future development, adding comments to indicate their planned implementation could enhance maintainability.Do you want me to help implement these methods or open GitHub issues to track their development?
source/tests/consistent/test_type_embedding.py (1)
116-123
: Enhance the exception message ineval_jax
methodThe exception raised when the output is a NumPy array could provide more context to aid in debugging.
Consider updating the error message to include the unexpected type:
- raise ValueError("Output is numpy array") + raise ValueError(f"Unexpected output type {type(x)}; expected non-NumPy array")source/tests/consistent/common.py (1)
86-88
: Improve clarity and consistency of the JAX skipping commentThe inline comment on line 86 could be rephrased for clarity and to match the commenting style used elsewhere.
Consider revising the comment for better readability:
-# we may usually skip jax before jax is fully supported +"""Whether to skip the JAX model. We typically skip JAX tests until JAX support is fully integrated. """
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (17)
- .github/workflows/test_cuda.yml (1 hunks)
- .github/workflows/test_python.yml (1 hunks)
- deepmd/backend/jax.py (1 hunks)
- deepmd/dpmodel/common.py (2 hunks)
- deepmd/dpmodel/utils/network.py (4 hunks)
- deepmd/dpmodel/utils/type_embed.py (2 hunks)
- deepmd/jax/init.py (1 hunks)
- deepmd/jax/common.py (1 hunks)
- deepmd/jax/env.py (1 hunks)
- deepmd/jax/utils/init.py (1 hunks)
- deepmd/jax/utils/network.py (1 hunks)
- deepmd/jax/utils/type_embed.py (1 hunks)
- pyproject.toml (1 hunks)
- source/tests/common/dpmodel/array_api/test_activation_functions.py (1 hunks)
- source/tests/consistent/common.py (8 hunks)
- source/tests/consistent/test_activation.py (4 hunks)
- source/tests/consistent/test_type_embedding.py (4 hunks)
Files skipped from review due to trivial changes (3)
- deepmd/jax/init.py
- deepmd/jax/utils/init.py
- source/tests/common/dpmodel/array_api/test_activation_functions.py
Additional context used
GitHub Check: CodeQL
deepmd/jax/common.py
[notice] 15-15: Statement has no effect
This statement has no effect.
[notice] 19-19: Statement has no effect
This statement has no effect.
Additional comments not posted (46)
deepmd/jax/env.py (4)
1-1
: LGTM!Using an SPDX license identifier is a good practice for clearly specifying the license of the file.
4-4
: LGTM!Setting
XLA_PYTHON_CLIENT_PREALLOCATE
to "false" is a good optimization to avoid unnecessary memory pre-allocation in JAX.
9-9
: Verify the need for 64-bit precision and consider the performance impact.Enabling 64-bit floating-point precision in JAX can improve numerical accuracy but may impact performance. Please ensure that the decision to enable 64-bit precision aligns with the project's requirements and consider the potential performance trade-offs.
11-14
: LGTM!Using
__all__
to explicitly specify the public API of the module is a good practice. It clearly communicates the intended exports and helps maintain a clean and well-defined interface.deepmd/jax/utils/type_embed.py (1)
15-21
: LGTM!The
TypeEmbedNet
class correctly extendsTypeEmbedNetDP
and overrides the__setattr__
method to handle JAX-specific attribute assignments. The changes are in line with the PR objective of integrating JAX into the library.The class is importing the necessary utilities and there are no apparent issues with the logic or syntax. The changes seem to be localized to this class and do not appear to have any negative impact on the rest of the codebase.
deepmd/jax/common.py (1)
14-37
: LGTM!The
to_jax_array
function is well-implemented and documented. It correctly handles the conversion of NumPy arrays to JAX arrays, including the case when the input isNone
.Regarding the static analysis hints:
The overload statements at lines 15 and 19 are used for type hinting and do not have any runtime effect. They are not actual statements, so it's safe to ignore these hints.
Tools
GitHub Check: CodeQL
[notice] 15-15: Statement has no effect
This statement has no effect.
[notice] 19-19: Statement has no effect
This statement has no effect.deepmd/dpmodel/common.py (1)
66-81
: LGTM!The
to_numpy_array
function looks good:
- It correctly handles the case when the input is
None
.- It uses
np.asarray
which is the recommended way to convert an array-like object to a NumPy array.- The type hints are correct and improve the readability of the function.
- The function is well-documented with a docstring that follows the NumPy docstring format.
This utility function can be used in other parts of the codebase to ensure consistent conversion of arrays to NumPy arrays.
source/tests/consistent/test_activation.py (5)
2-2
: LGTM!The import changes look good:
sys
module import is valid.INSTALLED_JAX
import is valid and likely used for conditional JAX-related tests.Also applies to: 16-16
33-36
: LGTM!The JAX-related import changes look good:
- The conditional import based on
INSTALLED_JAX
is a good practice.- Importing
jnp
fromdeepmd.jax.env
is valid for JAX-related operations.
67-78
: LGTM!The new test method
test_arary_api_strict
looks good:
- The Python version check and conditional skip is a good practice.
- Using
array_api_strict
helps ensure compatibility with the array API standard.- Setting the
array_api_strict
flags based on thearray_api_version
of the tested function is correct.- The test logic of applying the activation function and comparing with the reference is valid.
80-85
: LGTM!The new test method
test_jax_consistent_with_ref
looks good:
- The conditional skip based on
INSTALLED_JAX
is a good practice for JAX-specific tests.- Converting the input from NumPy to JAX using
jnp.from_dlpack
is correct.- Applying the activation function to the JAX array is the main test logic and is valid.
- Asserting that the result is a JAX array is a good additional check.
- Converting the result back to NumPy for comparison with the reference is correct.
Line range hint
1-85
: Great work on the test file!The file is well-structured and organized, making it easy to understand and maintain. Some notable good practices:
- Parameterization of tests over valid activation functions ensures thorough testing.
- Conditional skipping of tests based on library availability ensures compatibility.
- Test methods cover important aspects of consistency across different frameworks and APIs.
Keep up the good work!
.github/workflows/test_cuda.yml (1)
54-54
: Looks good! Verify JAX integration across the codebase.The addition of
jax
to the list of extras aligns with the goal of integrating JAX into the library. Please ensure that the rest of the codebase is updated to leverage JAX effectively and maintain compatibility.To verify the JAX integration, consider running the following script:
.github/workflows/test_python.yml (1)
31-31
: LGTM!The addition of the
jax
extra in the pip install command is consistent with the PR objectives of supporting JAX in the DeepMD library. This change is necessary to install the required dependencies for JAX support.pyproject.toml (1)
135-137
: LGTM!The new dependency declaration for the
jax
library under thecu12
section looks good. The version constraint ensures a minimum compatible version is used for Python 3.10 and above.deepmd/jax/utils/network.py (1)
1-29
: LGTM!The integration of JAX into the neural network utilities is well-executed. The
NativeLayer
class effectively ensures that the attributesw
,b
, andidt
are converted to JAX arrays upon assignment, maintaining consistency within the JAX computational framework. The creation ofNativeNet
,EmbeddingNet
, andFittingNet
using the provided factory functions is correctly structured and should facilitate the building of complex neural network models with JAX support.source/tests/consistent/test_type_embedding.py (4)
16-16
: AddedINSTALLED_JAX
to importsIncluding
INSTALLED_JAX
in the imports allows for conditional functionality based on JAX installation.
34-40
: Implemented conditional imports for JAX supportThe addition of conditional imports for JAX ensures that JAX-specific code is only executed when JAX is installed, maintaining compatibility.
74-74
: Addedjax_class
attribute to support JAX backendThe inclusion of
jax_class = TypeEmbedNetJAX
maintains consistency with other backend classes and facilitates testing with the JAX backend.
76-76
: Introducedskip_jax
flag to conditionally skip JAX testsThe
skip_jax
flag correctly determines whether to skip JAX-related tests based on the installation status of JAX.deepmd/dpmodel/utils/type_embed.py (5)
8-13
: Imports added for array API compatibilityThe necessary imports
array_api_compat
andsupport_array_api
are correctly added to support array API compatibility.
99-99
: Decorator applied for array API supportThe
@support_array_api(version="2022.12")
decorator is appropriately applied to thecall
method to ensure compatibility with the specified array API version.
102-103
: Retrieving array namespace for flexible array handlingThe code correctly retrieves the array namespace using
array_api_compat.array_namespace(sample_array)
, allowing for consistent handling of arrays across different backends.
105-105
: Using array API'seye
functionReplacing
np.eye
withxp.eye
ensures compatibility with different array backends in accordance with the array API standards.
109-110
: Updating padding logic for array API compatibilityThe padding operation is appropriately updated to use
xp.zeros
andxp.concatenate
, ensuring consistency with the array API approach.source/tests/consistent/common.py (7)
38-38
: JAX backend availability check added correctlyThe
INSTALLED_JAX
variable is successfully introduced to check the availability of the JAX backend.
61-61
: ExportedINSTALLED_JAX
in__all__
Adding
INSTALLED_JAX
to the__all__
list ensures it is properly exported and accessible when the module is imported elsewhere.
76-77
: Addedjax_class
variable for JAX model classIntroducing
jax_class
aligns with the existing structure for other backends liketf_class
andpt_class
, promoting consistency.
175-175
: AddedJAX
entry toRefBackend
enumIncluding
JAX
in theRefBackend
enum extends the reference backend options appropriately.
252-253
: Verify JAX as a reference backend is fully supportedWhen
skip_jax
isFalse
, and JAX is selected as the reference backend, ensure that all necessary methods are implemented to prevent test failures.Run the following script to check the implementation status of required methods for JAX as a reference backend:
#!/bin/bash # Description: Check for implementations of methods required when using JAX as the reference backend. # List methods that require implementation for JAX rg --type py 'def.*(self,.*):' -A 2 | rg 'eval_jax|get_jax_ret_serialization_from_cls|extract_ret' # Confirm that these methods do not raise `NotImplementedError`
236-240
: Ensureget_jax_ret_serialization_from_cls
handles unimplementedeval_jax
Since
eval_jax
raisesNotImplementedError
, callingget_jax_ret_serialization_from_cls
may result in unexpected errors. Verify that calls to this method are appropriately handled or deferred untileval_jax
is implemented.Run the following script to locate calls to
get_jax_ret_serialization_from_cls
and check for exception handling:#!/bin/bash # Description: Find all calls to `get_jax_ret_serialization_from_cls` and assess error handling. # Search for calls to `get_jax_ret_serialization_from_cls` rg --type py '\.get_jax_ret_serialization_from_cls\(' -A 5 # Ensure that calls are safeguarded against `NotImplementedError`
159-168
: Implementeval_jax
method or handleNotImplementedError
appropriatelyThe
eval_jax
method currently raisesNotImplementedError
, which may cause tests to fail ifskip_jax
is set toFalse
. Ensure that this is intentional and that tests handle this exception properly.Run the following script to identify where
eval_jax
is called and verify exception handling:deepmd/dpmodel/utils/network.py (14)
18-18
: Importarray_api_compat
for array API supportThe addition of
import array_api_compat
is necessary for array API compatibility and is correctly implemented.
26-31
: Add required imports for array API supportImports for
support_array_api
andto_numpy_array
are added fromdeepmd.dpmodel.array_api
anddeepmd.dpmodel.common
, respectively. This ensures that array API support and proper array conversion are available.
115-117
: Ensure consistent serialization withto_numpy_array
Using
to_numpy_array
forself.w
,self.b
, andself.idt
before serialization ensures that the data is consistently converted to NumPy arrays, enhancing cross-backend compatibility.
225-225
: Decoratecall
method with@support_array_api
Applying
@support_array_api(version="2022.12")
to thecall
method enables array API support, allowing the method to be compatible with different array backends following the specified version.
241-241
: Retrieve array namespace for backend compatibilityUsing
xp = array_api_compat.array_namespace(x)
correctly retrieves the array namespace corresponding to the inputx
, ensuring that subsequent operations use the appropriate backend functions.
244-246
: Use array API matrix multiplication and additionReplacing
np.matmul
withxp.matmul
and handling biases with array API functions ensures that matrix operations are compatible across different backends.
254-254
: Implement array API concatenationUsing
xp.concatenate
ensures that concatenation operations adhere to the array API standards, maintaining compatibility.
258-258
: Decorateget_activation_fn
with@support_array_api
Applying
@support_array_api(version="2022.12")
toget_activation_fn
ensures that the activation functions returned are compatible with the array API standard.
263-266
: Updatetanh
activation function for array APIThe
tanh
activation function now usesxp.tanh(x)
, ensuring compatibility with different array backends via the array API.
271-273
: Updaterelu
activation function for array APIThe
relu
activation function usesxp
operations, enhancing compatibility with various backends through the array API.
279-285
: Updategelu
activation function for array APIThe
gelu
activation function is updated to usexp
operations, ensuring array API compliance.
291-295
: Updaterelu6
activation function for array APIThe
relu6
activation function now usesxp.where
andxp.full_like
for array API compatibility.
301-303
: Updatesoftplus
activation function for array APIThe
softplus
activation function usesxp.log
andxp.exp
, ensuring compatibility with the array API standard.
309-311
: Updatesigmoid
activation function for array APIThe
sigmoid
activation function now usesxp.exp
to ensure array API compliance.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4156 +/- ##
==========================================
+ Coverage 83.41% 83.42% +0.01%
==========================================
Files 532 537 +5
Lines 52048 52146 +98
Branches 3046 3046
==========================================
+ Hits 43416 43504 +88
- Misses 7684 7692 +8
- Partials 948 950 +2 ☔ View full report in Codecov by Sentry. |
Summary by CodeRabbit
Release Notes
New Features
JAXBackend
class for backend integration with JAX.Bug Fixes
Tests
Chores