diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index db0a1af..3545aa7 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: def pad( x: Array, - pad_width: int, + pad_width: int | tuple[int, int] | list[tuple[int, int],...], mode: str = "constant", *, xp: ModuleType | None = None, @@ -568,8 +568,12 @@ def pad( ---------- x : array Input array. - pad_width : int + pad_width : int or tuple of ints or list of pairs of ints Pad the input array with this many elements from each side. + If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, + each pair applies to the corresponding axis of ``x``. + A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim`` + copies of this tuple. mode : str, optional Only "constant" mode is currently supported, which pads with the value passed to `constant_values`. @@ -590,16 +594,43 @@ def pad( value = constant_values + # make pad_width a list of length-2 tuples of ints + if isinstance(pad_width, int): + pad_width = [(pad_width, pad_width)] * x.ndim + + if isinstance(pad_width, tuple): + pad_width = [pad_width] * x.ndim + if xp is None: xp = array_namespace(x) + slices = [] + newshape = [] + for ax, w_tpl in enumerate(pad_width): + if len(w_tpl) != 2: + msg = f"expect a 2-tuple (before, after), got {w_tpl}." + raise ValueError(msg) + + sh = x.shape[ax] + if w_tpl[0] == 0 and w_tpl[1] == 0: + sl = slice(None, None, None) + else: + start, stop = w_tpl + stop = None if stop == 0 else -stop + + sl = slice(start, stop, None) + sh += w_tpl[0] + w_tpl[1] + + newshape.append(sh) + slices.append(sl) + padded = xp.full( - tuple(x + 2 * pad_width for x in x.shape), + tuple(newshape), fill_value=value, dtype=x.dtype, device=_compat.device(x), ) - padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x + padded[tuple(slices)] = x return padded diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 938a4f3..67b004d 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -416,3 +416,19 @@ def test_device(self): def test_xp(self): assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3)) + + def test_tuple_width(self): + a = xp.reshape(xp.arange(12), (3, 4)) + padded = pad(a, (1, 0)) + assert padded.shape == (4, 5) + + padded = pad(a, (1, 2)) + assert padded.shape == (6, 7) + + def test_list_of_tuples_width(self): + a = xp.reshape(xp.arange(12), (3, 4)) + padded = pad(a, [(1, 0), (0, 2)]) + assert padded.shape == (4, 6) + + padded = pad(a, [(1, 0), (0, 0)]) + assert padded.shape == (4, 4)