Skip to content

Commit

Permalink
ENH: pad: add delegation
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Dec 26, 2024
1 parent 169f21d commit d17fd2f
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_diagonal
expand_dims
kron
pad
setdiff1d
sinc
```
2 changes: 1 addition & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import (
from ._lib._funcs import (
atleast_nd,
cov,
create_diagonal,
Expand Down
59 changes: 59 additions & 0 deletions src/array_api_extra/_delegators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Delegators to existing implementations for Public API Functions."""

from ._lib import _funcs
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
is_jax_namespace,
is_numpy_namespace,
is_torch_namespace,
)
from ._lib._utils._typing import Array, ModuleType


def pad(
x: Array,
pad_width: int,
mode: str = "constant",
*,
constant_values: bool | int | float | complex = 0,
xp: ModuleType | None = None,
) -> Array:
"""
Pad the input array.
Parameters
----------
x : array
Input array.
pad_width : int
Pad the input array with this many elements from each side.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
xp = array_namespace(x) if xp is None else xp

value = constant_values

# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
if is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, (pad_width,), value=value)

if is_numpy_namespace(x) or is_jax_namespace(xp) or is_cupy_namespace(xp):
return xp.pad(x, pad_width, mode, constant_values=value)

return _funcs.pad(x, pad_width, mode, constant_values=value, xp=xp)
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Modules housing private functions."""
"""Array-agnostic implementations for the public API."""
19 changes: 0 additions & 19 deletions src/array_api_extra/_lib/_compat.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import warnings

from ._lib import _compat, _utils
from ._lib._compat import array_namespace
from ._lib._typing import Array, ModuleType
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace
from ._utils._typing import Array, ModuleType

__all__ = [
"atleast_nd",
Expand Down Expand Up @@ -136,7 +136,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = _utils.mean(m, axis=1, xp=xp)
avg = _helpers.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

if fact <= 0:
Expand Down Expand Up @@ -449,7 +449,7 @@ def setdiff1d(
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
Expand Down Expand Up @@ -546,8 +546,8 @@ def pad(
pad_width: int,
mode: str = "constant",
*,
xp: ModuleType | None = None,
constant_values: bool | int | float | complex = 0,
xp: ModuleType | None = None,
) -> Array:
"""
Pad the input array.
Expand All @@ -561,10 +561,10 @@ def pad(
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Modules housing private utility functions."""
31 changes: 31 additions & 0 deletions src/array_api_extra/_lib/_utils/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Acquire helpers from array-api-compat."""
# Allow packages that vendor both `array-api-extra` and
# `array-api-compat` to override the import location

try:
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device, # pyright: ignore[reportUnknownVariableType]
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]
is_jax_namespace, # pyright: ignore[reportUnknownVariableType]
is_numpy_namespace, # pyright: ignore[reportUnknownVariableType]
is_torch_namespace, # pyright: ignore[reportUnknownVariableType]
)
except ImportError:
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]
is_jax_namespace, # pyright: ignore[reportUnknownVariableType]
is_numpy_namespace, # pyright: ignore[reportUnknownVariableType]
is_torch_namespace, # pyright: ignore[reportUnknownVariableType]
)

__all__ = [
"array_namespace",
"device",
"is_cupy_namespace",
"is_jax_namespace",
"is_numpy_namespace",
"is_torch_namespace",
]
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def array_namespace(
use_compat: bool | None = None,
) -> ArrayModule: ... # numpydoc ignore=GL08
def device(x: Array, /) -> Device: ... # numpydoc ignore=GL08
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Utility functions used by `array_api_extra/_funcs.py`."""
"""Helper functions used by `array_api_extra/_funcs.py`."""

from . import _compat
from ._typing import Array, ModuleType
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
setdiff1d,
sinc,
)
from array_api_extra._lib._typing import Array
from array_api_extra._lib._utils._typing import Array


class TestAtLeastND:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pytest
from numpy.testing import assert_array_equal

from array_api_extra._lib._typing import Array
from array_api_extra._lib._utils import in1d
from array_api_extra._lib._utils._helpers import in1d
from array_api_extra._lib._utils._typing import Array


# some test coverage already provided by TestSetDiff1D
Expand Down

0 comments on commit d17fd2f

Please sign in to comment.