Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add pad #71

Merged
merged 4 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad

__version__ = "0.4.1.dev0"

Expand All @@ -14,4 +14,5 @@
"kron",
"setdiff1d",
"sinc",
"pad",
]
56 changes: 55 additions & 1 deletion src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import warnings

from ._lib import _compat, _utils
from ._lib._compat import array_namespace
from ._lib._compat import (
array_namespace, is_torch_namespace, is_array_api_strict_namespace
)
from ._lib._typing import Array, ModuleType

__all__ = [
Expand All @@ -14,6 +16,7 @@
"kron",
"setdiff1d",
"sinc",
"pad",
]


Expand Down Expand Up @@ -538,3 +541,54 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs):
"""
Pad the input array.

Parameters
----------
x : array
Input array
pad_width: int
Pad the input array with this many elements from each side
mode: str, optional
Only "constant" mode is currently supported.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values: python scalar, optional
Use this value to pad the input. Default is zero.

Returns
-------
array
The input array, padded with ``pad_width`` elements equal to ``constant_values``
"""
# xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
# http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045

if mode != 'constant':
raise NotImplementedError()

value = kwargs.get("constant_values", 0)
if kwargs and list(kwargs.keys()) != ['constant_values']:
raise ValueError(f"Unknown kwargs: {kwargs}")

if xp is None:
xp = array_namespace(x)

if is_array_api_strict_namespace(xp):
padded = xp.full(
tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype
)
padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x
return padded
elif is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, tuple(pad_width), value=value)

else:
return xp.pad(x, pad_width, mode=mode, **kwargs)
2 changes: 2 additions & 0 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_torch_namespace,
is_array_api_strict_namespace,
)

__all__ = [
Expand Down
22 changes: 22 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
kron,
setdiff1d,
sinc,
pad,
)
from array_api_extra._lib._typing import Array

Expand Down Expand Up @@ -385,3 +386,24 @@ def test_device(self):

def test_xp(self):
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestPad:
def test_simple(self):
a = xp.arange(1, 4)
padded = pad(a, 2)
assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0]))

def test_fill_value(self):
a = xp.arange(1, 4)
padded = pad(a, 2, constant_values=42)
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))

def test_ndim(self):
a = xp.reshape(xp.arange(2*3*4), (2, 3, 4))
padded = pad(a, 2)
assert padded.shape == (6, 7, 8)

def test_typo(self):
with pytest.raises(ValueError, match="Unknown"):
pad(xp.arange(2), pad_width=3, oops=3)
Loading