Skip to content

Commit

Permalink
ENH: allow list/tuple pad_width in pad
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Jan 7, 2025
1 parent a96dffb commit ef0d484
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
39 changes: 35 additions & 4 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

def pad(
x: Array,
pad_width: int,
pad_width: int | tuple[int, int] | list[tuple[int, int]],
mode: str = "constant",
*,
xp: ModuleType | None = None,
Expand All @@ -568,8 +568,12 @@ def pad(
----------
x : array
Input array.
pad_width : int
pad_width : int or tuple of ints or list of pairs of ints
Pad the input array with this many elements from each side.
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
each pair applies to the corresponding axis of ``x``.
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
copies of this tuple.
mode : str, optional
Only "constant" mode is currently supported, which pads with
the value passed to `constant_values`.
Expand All @@ -590,16 +594,43 @@ def pad(

value = constant_values

# make pad_width a list of length-2 tuples of ints
if isinstance(pad_width, int):
pad_width = [(pad_width, pad_width)] * x.ndim

if isinstance(pad_width, tuple):
pad_width = [pad_width] * x.ndim

if xp is None:
xp = array_namespace(x)

slices = []
newshape = []
for ax, w_tpl in enumerate(pad_width):
if len(w_tpl) != 2:
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
raise ValueError(msg)

sh = x.shape[ax]
if w_tpl[0] == 0 and w_tpl[1] == 0:
sl = slice(None, None, None)
else:
start, stop = w_tpl
stop = None if stop == 0 else -stop

sl = slice(start, stop, None)
sh += w_tpl[0] + w_tpl[1]

newshape.append(sh)
slices.append(sl)

padded = xp.full(
tuple(x + 2 * pad_width for x in x.shape),
tuple(newshape),
fill_value=value,
dtype=x.dtype,
device=_compat.device(x),
)
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
padded[tuple(slices)] = x
return padded


Expand Down
19 changes: 19 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,22 @@ def test_device(self):

def test_xp(self):
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))

def test_tuple_width(self):
a = xp.reshape(xp.arange(12), (3, 4))
padded = pad(a, (1, 0))
assert padded.shape == (4, 5)

padded = pad(a, (1, 2))
assert padded.shape == (6, 7)

with pytest.raises(ValueError, match="expect a 2-tuple"):
pad(a, [(1, 2, 3)])

def test_list_of_tuples_width(self):
a = xp.reshape(xp.arange(12), (3, 4))
padded = pad(a, [(1, 0), (0, 2)])
assert padded.shape == (4, 6)

padded = pad(a, [(1, 0), (0, 0)])
assert padded.shape == (4, 4)

0 comments on commit ef0d484

Please sign in to comment.