Skip to content

Commit

Permalink
ENH: add pad (#71)
Browse files Browse the repository at this point in the history
* ENH: add pad

* remove delegation for now

* tweaks

* add xp, device tests

---------

Co-authored-by: Lucas Colley <[email protected]>
  • Loading branch information
ev-br and lucascolley authored Dec 26, 2024
1 parent 6df1916 commit 169f21d
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
from ._funcs import (
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
pad,
setdiff1d,
sinc,
)

__version__ = "0.4.1.dev0"

Expand All @@ -12,6 +21,7 @@
"create_diagonal",
"expand_dims",
"kron",
"pad",
"setdiff1d",
"sinc",
]
51 changes: 51 additions & 0 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"create_diagonal",
"expand_dims",
"kron",
"pad",
"setdiff1d",
"sinc",
]
Expand Down Expand Up @@ -538,3 +539,53 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def pad(
x: Array,
pad_width: int,
mode: str = "constant",
*,
xp: ModuleType | None = None,
constant_values: bool | int | float | complex = 0,
) -> Array:
"""
Pad the input array.
Parameters
----------
x : array
Input array.
pad_width : int
Pad the input array with this many elements from each side.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
Returns
-------
array
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
if mode != "constant":
msg = "Only `'constant'` mode is currently supported"
raise NotImplementedError(msg)

value = constant_values

if xp is None:
xp = array_namespace(x)

padded = xp.full(
tuple(x + 2 * pad_width for x in x.shape),
fill_value=value,
dtype=x.dtype,
device=_compat.device(x),
)
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
return padded
31 changes: 31 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
create_diagonal,
expand_dims,
kron,
pad,
setdiff1d,
sinc,
)
Expand Down Expand Up @@ -385,3 +386,33 @@ def test_device(self):

def test_xp(self):
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestPad:
def test_simple(self):
a = xp.arange(1, 4)
padded = pad(a, 2)
assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0]))

def test_fill_value(self):
a = xp.arange(1, 4)
padded = pad(a, 2, constant_values=42)
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))

def test_ndim(self):
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
padded = pad(a, 2)
assert padded.shape == (6, 7, 8)

def test_mode_not_implemented(self):
a = xp.arange(3)
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
pad(a, 2, mode="edge")

def test_device(self):
device = xp.Device("device1")
a = xp.asarray(0.0, device=device)
assert pad(a, 2).device == device

def test_xp(self):
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))

0 comments on commit 169f21d

Please sign in to comment.