Skip to content

Commit

Permalink
add paddle support in array-api-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 26, 2024
1 parent ee25aae commit 8e5cc94
Show file tree
Hide file tree
Showing 19 changed files with 2,088 additions and 97 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/array-api-tests-paddle.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: Array API Tests (Paddle Latest)

on: [push, pull_request]

jobs:
array-api-tests-paddle:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: paddle
extra-env-vars: |
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
79 changes: 79 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_paddle_array(x):
"""
Return True if `x` is a Paddle tensor.
This function does not import Paddle if it has not already been imported
and is therefore cheap to use.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_dask_array
is_jax_array
is_pydata_sparse_array
"""
# Avoid importing paddle if it isn't already
if 'paddle' not in sys.modules:
return False

import paddle

# TODO: Should we reject ndarray subclasses?
return paddle.is_tensor(x)

def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.
Expand Down Expand Up @@ -252,6 +279,7 @@ def is_array_api_obj(x):
or is_dask_array(x) \
or is_jax_array(x) \
or is_pydata_sparse_array(x) \
or is_paddle_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():
Expand Down Expand Up @@ -319,6 +347,27 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_paddle_namespace(xp) -> bool:
"""
Returns True if `xp` is a Paddle namespace.
This includes both Paddle itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'paddle', _compat_module_name() + '.paddle'}


def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
Expand Down Expand Up @@ -543,6 +592,14 @@ def your_function(x, y):
else:
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif is_paddle_array(x):
if _use_compat:
_check_api_version(api_version)
from .. import paddle as paddle_namespace
namespaces.add(paddle_namespace)
else:
import paddle
namespaces.add(paddle)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
Expand Down Expand Up @@ -660,6 +717,16 @@ def device(x: Array, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner)
elif is_paddle_array(x):
raw_place_str = str(x.place)
if "gpu_pinned" in raw_place_str:
return "cpu"
elif "cpu" in raw_place_str:
return "cpu"
elif "gpu" in raw_place_str:
return "gpu"
raise NotImplementedError(f"Unsupported device {raw_place_str}")

return x.device

# Prevent shadowing, used below
Expand Down Expand Up @@ -709,6 +776,14 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)

def _paddle_to_device(x, device, /, stream=None):
if stream is not None:
raise NotImplementedError(
"paddle.Tensor.to() do not support stream argument yet"
)
return x.to(device)


def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
Expand Down Expand Up @@ -781,6 +856,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# In JAX v0.4.31 and older, this import adds to_device method to x.
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
elif is_paddle_array(x):
return _paddle_to_device(x, device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
Expand Down Expand Up @@ -819,6 +896,8 @@ def size(x):
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_paddle_array",
"is_paddle_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
Expand Down
28 changes: 28 additions & 0 deletions array_api_compat/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from paddle import * # noqa: F403

# Several names are not included in the above import *
import paddle

for n in dir(paddle):
if (
n.startswith("_")
or n.endswith("_")
or "gpu" in n
or "cpu" in n
or "backward" in n
):
continue
exec(n + " = paddle." + n)
exec("asarray = paddle.to_tensor")

# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")

__import__(__package__ + ".fft")

from ..common._helpers import * # noqa: F403

__array_api_version__ = "2023.12"
Loading

0 comments on commit 8e5cc94

Please sign in to comment.