
Commit

update README
HydrogenSulfate committed Nov 26, 2024
1 parent 8e5cc94 commit 7118894
Showing 3 changed files with 54 additions and 52 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -2,8 +2,8 @@
 
 This is a small wrapper around common array libraries that is compatible with
 the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
-NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
+NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want
 support for other array libraries, or if you encounter any issues, please [open
 an issue](https://github.com/data-apis/array-api-compat/issues).
 
-See the documentation for more details https://data-apis.org/array-api-compat/
+See the documentation for more details <https://data-apis.org/array-api-compat/>
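For orientation, the libraries listed in the README are all reached through the same entry point: `array_api_compat.array_namespace()` returns an Array API compatible namespace for whichever supported library the input arrays come from. A minimal sketch of that usage, shown here with NumPy purely for illustration (any of the listed libraries should work the same way):

```python
import numpy as np
import array_api_compat

def normalize(x):
    # Pick the wrapped namespace matching x (NumPy, CuPy, PyTorch, Paddle, ...)
    xp = array_api_compat.array_namespace(x)
    return x / xp.linalg.vector_norm(x)

print(normalize(np.asarray([3.0, 4.0])))  # [0.6 0.8]
```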
26 changes: 13 additions & 13 deletions array_api_compat/torch/fft.py
@@ -2,14 +2,14 @@
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
+    import torch
+    array = torch.Tensor
     from typing import Union, Sequence, Literal
 
-from paddle.fft import * # noqa: F403
-import paddle.fft
+from torch.fft import * # noqa: F403
+import torch.fft
 
-# Several paddle fft functions do not map axes to dim
+# Several torch fft functions do not map axes to dim
 
 def fftn(
     x: array,
@@ -20,7 +20,7 @@ def fftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def ifftn(
     x: array,
@@ -31,7 +31,7 @@ def ifftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def rfftn(
     x: array,
@@ -42,7 +42,7 @@ def rfftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def irfftn(
     x: array,
@@ -53,7 +53,7 @@ def irfftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def fftshift(
     x: array,
@@ -62,7 +62,7 @@ def fftshift(
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
 ) -> array:
-    return paddle.fft.fftshift(x, axes=axes, **kwargs)
+    return torch.fft.fftshift(x, dim=axes, **kwargs)
 
 def ifftshift(
     x: array,
@@ -71,10 +71,10 @@ def ifftshift(
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
 ) -> array:
-    return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+    return torch.fft.ifftshift(x, dim=axes, **kwargs)
 
 
-__all__ = paddle.fft.__all__ + [
+__all__ = torch.fft.__all__ + [
     "fftn",
     "ifftn",
     "rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
     "ifftshift",
 ]
 
-_all_ignore = ['paddle']
+_all_ignore = ['torch']
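A note on the torch/fft.py hunks above: the only substantive difference between these wrappers and the functions they wrap is the keyword spelling, since the Array API standard uses `axes` while `torch.fft` uses `dim`. A minimal sketch of the restored behaviour, assuming `torch` and this package are installed (the `compat_fft` alias is just for illustration):

```python
import torch
from array_api_compat.torch import fft as compat_fft

x = torch.randn(4, 8, 16)

# Plain torch spells the transform axes `dim` ...
expected = torch.fft.fftn(x, dim=(1, 2))

# ... while the compat wrapper accepts the spec's `axes` and forwards it to `dim`.
result = compat_fft.fftn(x, axes=(1, 2))

torch.testing.assert_close(result, expected)
```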
76 changes: 39 additions & 37 deletions array_api_compat/torch/linalg.py
@@ -2,84 +2,86 @@
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
-    from paddle import dtype as Dtype
+    import torch
+    array = torch.Tensor
+    from torch import dtype as Dtype
     from typing import Optional, Union, Tuple, Literal
     inf = float('inf')
 
 from ._aliases import _fix_promotion, sum
 
-from paddle.linalg import * # noqa: F403
+from torch.linalg import * # noqa: F403
 
-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
 
-# outer is implemented in paddle but aren't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but aren't in the linalg namespace
+from torch import outer
 # These functions are in both the main and linalg namespaces
 from ._aliases import matmul, matrix_transpose, tensordot
 
-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
 
-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
 # dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
         raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
     if not (x1.shape[axis] == x2.shape[axis] == 3):
         raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
-    x1, x2 = paddle.broadcast_tensors(x1, x2)
-    return paddle_linalg.cross(x1, x2, axis=axis)
+    x1, x2 = torch.broadcast_tensors(x1, x2)
+    return torch_linalg.cross(x1, x2, dim=axis)
 
 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     from ._aliases import isdtype
 
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
 
-    # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
     if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
-    # paddle.linalg.vecdot doesn't support integer dtypes
+    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
 
-        x1_ = paddle.moveaxis(x1, axis, -1)
-        x2_ = paddle.moveaxis(x2, axis, -1)
-        x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
 
         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
-    return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
 
 def solve(x1: array, x2: array, /, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
     # whenever
     # 1. x1.ndim - 1 == x2.ndim
     # 2. x1.shape[:-1] == x2.shape
     #
     # See linalg_solve_is_vector_rhs in
     # aten/src/ATen/native/LinearAlgebraUtils.h and
-    # paddle_META_FUNC(_linalg_solve_ex) in
-    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+    # TORCH_META_FUNC(_linalg_solve_ex) in
+    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
     #
     # The easiest way to work around this is to prepend a size 1 dimension to
     # x2, since x2 is already one dimension less than x1.
     #
-    # See https://github.com/pypaddle/pypaddle/issues/52915
+    # See https://github.com/pytorch/pytorch/issues/52915
     if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
         x2 = x2[None]
-    return paddle.linalg.solve(x1, x2, **kwargs)
+    return torch.linalg.solve(x1, x2, **kwargs)
 
-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
 def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
 def vector_norm(
     x: array,
@@ -90,30 +92,30 @@ def vector_norm(
     ord: Union[int, float, Literal[inf, -inf]] = 2,
     **kwargs,
 ) -> array:
-    # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
         out = kwargs.get('out')
         if out is None:
             dtype = None
-            if x.dtype == paddle.complex64:
-                dtype = paddle.float32
-            elif x.dtype == paddle.complex128:
-                dtype = paddle.float64
+            if x.dtype == torch.complex64:
+                dtype = torch.float32
+            elif x.dtype == torch.complex128:
+                dtype = torch.float64
 
-            out = paddle.zeros_like(x, dtype=dtype)
+            out = torch.zeros_like(x, dtype=dtype)
 
         # The norm of a single scalar works out to abs(x) in every case except
-        # for p=0, which is x != 0.
+        # for ord=0, which is x != 0.
         if ord == 0:
             out[:] = (x != 0)
         else:
-            out[:] = paddle.abs(x)
+            out[:] = torch.abs(x)
         return out
-    return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
 
 __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
                         'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
 
-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']
 
 del linalg_all
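One quirk in the vector_norm hunk that is easy to miss: the Array API defines `axis=()` as "reduce over no axes", so the result should be elementwise, but `torch.linalg.vector_norm` (which takes `dim=`, not `axis=`) treats an empty `dim` like a full reduction, hence the explicit `abs(x)` / `x != 0` fallback above. A small torch-only sketch of the two behaviours being distinguished:

```python
import torch

x = torch.tensor([3.0, -4.0])

# Full reduction: the 2-norm of the whole vector.
print(torch.linalg.vector_norm(x))   # tensor(5.)

# What the Array API expects for axis=(): no reduction, i.e. the norm of
# each scalar, which is just its absolute value ...
print(torch.abs(x))                  # tensor([3., 4.])

# ... except for ord=0, where the per-scalar "norm" degenerates to x != 0.
print((x != 0).to(x.dtype))          # tensor([1., 1.])
```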
