From 7118894fae4c1d101a8262a61248c4208fe83d56 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 26 Nov 2024 12:45:05 +0800
Subject: [PATCH] update README

---
 README.md                        |  4 +-
 array_api_compat/torch/fft.py    | 26 +++++------
 array_api_compat/torch/linalg.py | 76 ++++++++++++++++----------------
 3 files changed, 54 insertions(+), 52 deletions(-)

diff --git a/README.md b/README.md
index 4b0b0c9c..5c30919d 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
 
 This is a small wrapper around common array libraries that is compatible with
 the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
-NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
+NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want
 support for other array libraries, or if you encounter any issues, please
 [open an issue](https://github.com/data-apis/array-api-compat/issues).
 
-See the documentation for more details https://data-apis.org/array-api-compat/
+See the documentation for more details
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 59c306af..3c9117ee 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -2,14 +2,14 @@
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
+    import torch
+    array = torch.Tensor
 from typing import Union, Sequence, Literal
 
-from paddle.fft import * # noqa: F403
-import paddle.fft
+from torch.fft import * # noqa: F403
+import torch.fft
 
-# Several paddle fft functions do not map axes to dim
+# Several torch fft functions do not map axes to dim
 
 def fftn(
     x: array,
@@ -20,7 +20,7 @@ def fftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def ifftn(
     x: array,
@@ -31,7 +31,7 @@ def ifftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def rfftn(
     x: array,
@@ -42,7 +42,7 @@ def rfftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def irfftn(
     x: array,
@@ -53,7 +53,7 @@ def irfftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def fftshift(
     x: array,
@@ -62,7 +62,7 @@ def fftshift(
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
 ) -> array:
-    return paddle.fft.fftshift(x, axes=axes, **kwargs)
+    return torch.fft.fftshift(x, dim=axes, **kwargs)
 
 def ifftshift(
     x: array,
@@ -71,10 +71,10 @@ def ifftshift(
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
 ) -> array:
-    return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+    return torch.fft.ifftshift(x, dim=axes, **kwargs)
 
 
-__all__ = paddle.fft.__all__ + [
+__all__ = torch.fft.__all__ + [
     "fftn",
     "ifftn",
     "rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
     "ifftshift",
 ]
 
-_all_ignore = ['paddle']
+_all_ignore = ['torch']
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index 5e4ee47b..e26198b9 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -2,84 +2,86 @@
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
-    from paddle import dtype as Dtype
+    import torch
+    array = torch.Tensor
+    from torch import dtype as Dtype
     from typing import Optional, Union, Tuple, Literal
     inf = float('inf')
 
 from ._aliases import _fix_promotion, sum
 
-from paddle.linalg import * # noqa: F403
+from torch.linalg import * # noqa: F403
 
-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
 
-# outer is implemented in paddle but aren't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but aren't in the linalg namespace
+from torch import outer
 # These functions are in both the main and linalg namespaces
 from ._aliases import matmul, matrix_transpose, tensordot
 
-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
         raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
     if not (x1.shape[axis] == x2.shape[axis] == 3):
         raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
-    x1, x2 = paddle.broadcast_tensors(x1, x2)
-    return paddle_linalg.cross(x1, x2, axis=axis)
+    x1, x2 = torch.broadcast_tensors(x1, x2)
+    return torch_linalg.cross(x1, x2, dim=axis)
 
 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     from ._aliases import isdtype
 
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
 
-    # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
     if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
-    # paddle.linalg.vecdot doesn't support integer dtypes
+    # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")
 
-        x1_ = paddle.moveaxis(x1, axis, -1)
-        x2_ = paddle.moveaxis(x2, axis, -1)
-        x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
-    return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
 
 def solve(x1: array, x2: array, /, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
     # whenever
     # 1. x1.ndim - 1 == x2.ndim
     # 2. x1.shape[:-1] == x2.shape
     #
     # See linalg_solve_is_vector_rhs in
     # aten/src/ATen/native/LinearAlgebraUtils.h and
-    # paddle_META_FUNC(_linalg_solve_ex) in
-    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+    # TORCH_META_FUNC(_linalg_solve_ex) in
+    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
     #
     # The easiest way to work around this is to prepend a size 1 dimension to
     # x2, since x2 is already one dimension less than x1.
     #
-    # See https://github.com/pypaddle/pypaddle/issues/52915
+    # See https://github.com/pytorch/pytorch/issues/52915
    if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
         x2 = x2[None]
-    return paddle.linalg.solve(x1, x2, **kwargs)
+    return torch.linalg.solve(x1, x2, **kwargs)
 
-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
 def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
 
 def vector_norm(
     x: array,
@@ -90,30 +92,30 @@ def vector_norm(
     ord: Union[int, float, Literal[inf, -inf]] = 2,
     **kwargs,
 ) -> array:
-    # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
         out = kwargs.get('out')
         if out is None:
             dtype = None
-            if x.dtype == paddle.complex64:
-                dtype = paddle.float32
-            elif x.dtype == paddle.complex128:
-                dtype = paddle.float64
+            if x.dtype == torch.complex64:
+                dtype = torch.float32
+            elif x.dtype == torch.complex128:
+                dtype = torch.float64
 
-            out = paddle.zeros_like(x, dtype=dtype)
+            out = torch.zeros_like(x, dtype=dtype)
 
         # The norm of a single scalar works out to abs(x) in every case except
-        # for p=0, which is x != 0.
+        # for ord=0, which is x != 0.
         if ord == 0:
             out[:] = (x != 0)
         else:
-            out[:] = paddle.abs(x)
+            out[:] = torch.abs(x)
         return out
-    return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+    return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
 
 __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
                         'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
 
-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']
 
 del linalg_all
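Note on the array_api_compat/torch/fft.py hunks above (reviewer commentary, not part of the patch): every restored wrapper follows the same pattern of accepting the array API standard's axes keyword and forwarding it to PyTorch's dim parameter, alongside switching the module references back from paddle to torch. A minimal standalone sketch of that pattern, assuming only that PyTorch is installed; it mirrors the patched fftshift wrapper:

    # Sketch of the axes -> dim translation performed by the wrappers in
    # array_api_compat/torch/fft.py (illustrative, not part of the patch).
    from typing import Sequence, Union

    import torch
    import torch.fft

    def fftshift(
        x: torch.Tensor,
        /,
        *,
        axes: Union[int, Sequence[int], None] = None,
        **kwargs,
    ) -> torch.Tensor:
        # The array API standard names this keyword `axes`; torch.fft.fftshift
        # names it `dim`, so the wrapper only renames the argument when forwarding.
        return torch.fft.fftshift(x, dim=axes, **kwargs)

The linalg.py hunks make the analogous rename for torch.linalg.cross and torch.linalg.vecdot (axis forwarded as dim) in addition to restoring the torch imports and helper names.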