Skip to content

Commit

Permalink
feat(jax/array-api): dpa1 (deepmodeling#4160)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Updated method for converting input to NumPy arrays, enhancing
performance and compatibility with array-like structures.
- Simplified handling of weight, bias, and identity variables for
improved compatibility with array backends.
- Introduced new network classes and enhanced network management
functionalities.
	- Added support for the new `array_api_strict` backend in testing.

- **Bug Fixes**
- Fixed serialization process to ensure accurate conversion of weights
and biases.

- **Tests**
- Added tests to validate the new functionalities and ensure
compatibility across various backends, including JAX and Array API
Strict.

- **Chores**
- Continued improvements to project structure and dependencies for
better maintainability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent 9a15bc0 commit 3939786
Show file tree
Hide file tree
Showing 29 changed files with 1,022 additions and 173 deletions.
40 changes: 40 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""

import array_api_compat


def support_array_api(version: str) -> callable:
"""Mark a function as supporting the specific version of the array API.
Expand All @@ -27,3 +29,41 @@ def set_version(func: callable) -> callable:
return func

return set_version


# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816
# but it hasn't been released yet
# below is a pure Python implementation of take_along_axis
# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
def xp_swapaxes(a, axis1, axis2):
xp = array_api_compat.array_namespace(a)
axes = list(range(a.ndim))
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
a = xp.permute_dims(a, axes)
return a


def xp_take_along_axis(arr, indices, axis):
xp = array_api_compat.array_namespace(arr)
arr = xp_swapaxes(arr, axis, -1)
indices = xp_swapaxes(indices, axis, -1)

m = arr.shape[-1]
n = indices.shape[-1]

shape = list(arr.shape)
shape.pop(-1)
shape = [*shape, n]

arr = xp.reshape(arr, (-1,))
if n != 0:
indices = xp.reshape(indices, (-1, n))
else:
indices = xp.reshape(indices, (0, 0))

offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
indices = xp.reshape(offset + indices, (-1,))

out = xp.take(arr, indices)
out = xp.reshape(out, shape)
return xp_swapaxes(out, axis, -1)
Loading

0 comments on commit 3939786

Please sign in to comment.