diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 26dbc28..83808e0 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,15 @@ """Extra array functions built on top of the array API standard.""" -from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc +from ._funcs import ( + atleast_nd, + cov, + create_diagonal, + expand_dims, + kron, + pad, + setdiff1d, + sinc, +) __version__ = "0.4.1.dev0" @@ -12,6 +21,7 @@ "create_diagonal", "expand_dims", "kron", + "pad", "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 30e6dc1..369319e 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -12,6 +12,7 @@ "create_diagonal", "expand_dims", "kron", + "pad", "setdiff1d", "sinc", ] @@ -538,3 +539,53 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), ) return xp.sin(y) / y + + +def pad( + x: Array, + pad_width: int, + mode: str = "constant", + *, + xp: ModuleType | None = None, + constant_values: bool | int | float | complex = 0, +) -> Array: + """ + Pad the input array. + + Parameters + ---------- + x : array + Input array. + pad_width : int + Pad the input array with this many elements from each side. + mode : str, optional + Only "constant" mode is currently supported, which pads with + the value passed to `constant_values`. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + constant_values : python scalar, optional + Use this value to pad the input. Default is zero. + + Returns + ------- + array + The input array, + padded with ``pad_width`` elements equal to ``constant_values``. + """ + if mode != "constant": + msg = "Only `'constant'` mode is currently supported" + raise NotImplementedError(msg) + + value = constant_values + + if xp is None: + xp = array_namespace(x) + + padded = xp.full( + tuple(x + 2 * pad_width for x in x.shape), + fill_value=value, + dtype=x.dtype, + device=_compat.device(x), + ) + padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x + return padded diff --git a/tests/test_funcs.py b/tests/test_funcs.py index d6e8930..938a4f3 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -13,6 +13,7 @@ create_diagonal, expand_dims, kron, + pad, setdiff1d, sinc, ) @@ -385,3 +386,33 @@ def test_device(self): def test_xp(self): assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + + +class TestPad: + def test_simple(self): + a = xp.arange(1, 4) + padded = pad(a, 2) + assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0])) + + def test_fill_value(self): + a = xp.arange(1, 4) + padded = pad(a, 2, constant_values=42) + assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42])) + + def test_ndim(self): + a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4)) + padded = pad(a, 2) + assert padded.shape == (6, 7, 8) + + def test_mode_not_implemented(self): + a = xp.arange(3) + with pytest.raises(NotImplementedError, match="Only `'constant'`"): + pad(a, 2, mode="edge") + + def test_device(self): + device = xp.Device("device1") + a = xp.asarray(0.0, device=device) + assert pad(a, 2).device == device + + def test_xp(self): + assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))