From 96bfbbf2e3d5df995f8a610a72c20a963d38dcff Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Dec 2024 10:28:55 +0200 Subject: [PATCH 1/4] ENH: add pad --- src/array_api_extra/__init__.py | 3 +- src/array_api_extra/_funcs.py | 56 ++++++++++++++++++++++++++++- src/array_api_extra/_lib/_compat.py | 2 ++ tests/test_funcs.py | 22 ++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 26dbc28..03d839c 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc +from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad __version__ = "0.4.1.dev0" @@ -14,4 +14,5 @@ "kron", "setdiff1d", "sinc", + "pad", ] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 30e6dc1..01d8f17 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -3,7 +3,9 @@ import warnings from ._lib import _compat, _utils -from ._lib._compat import array_namespace +from ._lib._compat import ( + array_namespace, is_torch_namespace, is_array_api_strict_namespace +) from ._lib._typing import Array, ModuleType __all__ = [ @@ -14,6 +16,7 @@ "kron", "setdiff1d", "sinc", + "pad", ] @@ -538,3 +541,54 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), ) return xp.sin(y) / y + + +def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs): + """ + Pad the input array. + + Parameters + ---------- + x : array + Input array + pad_width: int + Pad the input array with this many elements from each side + mode: str, optional + Only "constant" mode is currently supported. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + constant_values: python scalar, optional + Use this value to pad the input. Default is zero. + + Returns + ------- + array + The input array, padded with ``pad_width`` elements equal to ``constant_values`` + """ + # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse + # http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045 + + if mode != 'constant': + raise NotImplementedError() + + value = kwargs.get("constant_values", 0) + if kwargs and list(kwargs.keys()) != ['constant_values']: + raise ValueError(f"Unknown kwargs: {kwargs}") + + if xp is None: + xp = array_namespace(x) + + if is_array_api_strict_namespace(xp): + padded = xp.full( + tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype + ) + padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x + return padded + elif is_torch_namespace(xp): + pad_width = xp.asarray(pad_width) + pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) + pad_width = xp.flip(pad_width, axis=(0,)).flatten() + return xp.nn.functional.pad(x, tuple(pad_width), value=value) + + else: + return xp.pad(x, pad_width, mode=mode, **kwargs) diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index de7a220..353c1a3 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -11,6 +11,8 @@ from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] array_namespace, # pyright: ignore[reportUnknownVariableType] device, + is_torch_namespace, + is_array_api_strict_namespace, ) __all__ = [ diff --git a/tests/test_funcs.py b/tests/test_funcs.py index d6e8930..c9814e0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -15,6 +15,7 @@ kron, setdiff1d, sinc, + pad, ) from array_api_extra._lib._typing import Array @@ -385,3 +386,24 @@ def test_device(self): def test_xp(self): assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + + +class TestPad: + def test_simple(self): + a = xp.arange(1, 4) + padded = pad(a, 2) + assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0])) + + def test_fill_value(self): + a = xp.arange(1, 4) + padded = pad(a, 2, constant_values=42) + assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42])) + + def test_ndim(self): + a = xp.reshape(xp.arange(2*3*4), (2, 3, 4)) + padded = pad(a, 2) + assert padded.shape == (6, 7, 8) + + def test_typo(self): + with pytest.raises(ValueError, match="Unknown"): + pad(xp.arange(2), pad_width=3, oops=3) From 5aac0b48df10fa1f53989558a9f9f4f9c392d711 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 26 Dec 2024 14:03:59 +0000 Subject: [PATCH 2/4] remove delegation for now --- src/array_api_extra/__init__.py | 13 +++++-- src/array_api_extra/_funcs.py | 54 +++++++++++++---------------- src/array_api_extra/_lib/_compat.py | 2 -- tests/test_funcs.py | 8 ++--- 4 files changed, 37 insertions(+), 40 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 03d839c..83808e0 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,15 @@ """Extra array functions built on top of the array API standard.""" -from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad +from ._funcs import ( + atleast_nd, + cov, + create_diagonal, + expand_dims, + kron, + pad, + setdiff1d, + sinc, +) __version__ = "0.4.1.dev0" @@ -12,7 +21,7 @@ "create_diagonal", "expand_dims", "kron", + "pad", "setdiff1d", "sinc", - "pad", ] diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 01d8f17..a1eb8b8 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -4,7 +4,7 @@ from ._lib import _compat, _utils from ._lib._compat import ( - array_namespace, is_torch_namespace, is_array_api_strict_namespace + array_namespace, ) from ._lib._typing import Array, ModuleType @@ -14,9 +14,9 @@ "create_diagonal", "expand_dims", "kron", + "pad", "setdiff1d", "sinc", - "pad", ] @@ -543,52 +543,46 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: return xp.sin(y) / y -def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs): +def pad( + x: Array, + pad_width: int, + mode: str = "constant", + *, + xp: ModuleType | None = None, + constant_values: bool | int | float | complex = 0, +) -> Array: """ Pad the input array. Parameters ---------- x : array - Input array - pad_width: int - Pad the input array with this many elements from each side - mode: str, optional + Input array. + pad_width : int + Pad the input array with this many elements from each side. + mode : str, optional Only "constant" mode is currently supported. xp : array_namespace, optional The standard-compatible namespace for `x`. Default: infer. - constant_values: python scalar, optional + constant_values : python scalar, optional Use this value to pad the input. Default is zero. Returns ------- array - The input array, padded with ``pad_width`` elements equal to ``constant_values`` + The input array, + padded with ``pad_width`` elements equal to ``constant_values``. """ - # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse - # http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045 - - if mode != 'constant': + if mode != "constant": raise NotImplementedError() - value = kwargs.get("constant_values", 0) - if kwargs and list(kwargs.keys()) != ['constant_values']: - raise ValueError(f"Unknown kwargs: {kwargs}") + value = constant_values if xp is None: xp = array_namespace(x) - if is_array_api_strict_namespace(xp): - padded = xp.full( - tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype - ) - padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x - return padded - elif is_torch_namespace(xp): - pad_width = xp.asarray(pad_width) - pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) - pad_width = xp.flip(pad_width, axis=(0,)).flatten() - return xp.nn.functional.pad(x, tuple(pad_width), value=value) - - else: - return xp.pad(x, pad_width, mode=mode, **kwargs) + padded = xp.full( + tuple(x + 2 * pad_width for x in x.shape), fill_value=value, dtype=x.dtype + ) + padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x + return padded diff --git a/src/array_api_extra/_lib/_compat.py b/src/array_api_extra/_lib/_compat.py index 353c1a3..de7a220 100644 --- a/src/array_api_extra/_lib/_compat.py +++ b/src/array_api_extra/_lib/_compat.py @@ -11,8 +11,6 @@ from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs] array_namespace, # pyright: ignore[reportUnknownVariableType] device, - is_torch_namespace, - is_array_api_strict_namespace, ) __all__ = [ diff --git a/tests/test_funcs.py b/tests/test_funcs.py index c9814e0..270d1f0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -13,9 +13,9 @@ create_diagonal, expand_dims, kron, + pad, setdiff1d, sinc, - pad, ) from array_api_extra._lib._typing import Array @@ -400,10 +400,6 @@ def test_fill_value(self): assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42])) def test_ndim(self): - a = xp.reshape(xp.arange(2*3*4), (2, 3, 4)) + a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4)) padded = pad(a, 2) assert padded.shape == (6, 7, 8) - - def test_typo(self): - with pytest.raises(ValueError, match="Unknown"): - pad(xp.arange(2), pad_width=3, oops=3) From faaa2b62fcc865dde81177b146704a5f84d6509e Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 26 Dec 2024 14:13:42 +0000 Subject: [PATCH 3/4] tweaks --- src/array_api_extra/_funcs.py | 10 +++++----- tests/test_funcs.py | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index a1eb8b8..ca1f6ed 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -3,9 +3,7 @@ import warnings from ._lib import _compat, _utils -from ._lib._compat import ( - array_namespace, -) +from ._lib._compat import array_namespace from ._lib._typing import Array, ModuleType __all__ = [ @@ -561,7 +559,8 @@ def pad( pad_width : int Pad the input array with this many elements from each side. mode : str, optional - Only "constant" mode is currently supported. + Only "constant" mode is currently supported, which pads with + the value passed to `constant_values`. xp : array_namespace, optional The standard-compatible namespace for `x`. Default: infer. constant_values : python scalar, optional @@ -574,7 +573,8 @@ def pad( padded with ``pad_width`` elements equal to ``constant_values``. """ if mode != "constant": - raise NotImplementedError() + msg = "Only `'constant'` mode is currently supported" + raise NotImplementedError(msg) value = constant_values diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 270d1f0..49b3b61 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -403,3 +403,8 @@ def test_ndim(self): a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4)) padded = pad(a, 2) assert padded.shape == (6, 7, 8) + + def test_mode_not_implemented(self): + a = xp.arange(3) + with pytest.raises(NotImplementedError, match="Only `'constant'`"): + pad(a, 2, mode="edge") From 112def54364e135e36c9468b2d0cd3634eff2441 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 26 Dec 2024 14:19:17 +0000 Subject: [PATCH 4/4] add xp, device tests --- src/array_api_extra/_funcs.py | 5 ++++- tests/test_funcs.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index ca1f6ed..369319e 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -582,7 +582,10 @@ def pad( xp = array_namespace(x) padded = xp.full( - tuple(x + 2 * pad_width for x in x.shape), fill_value=value, dtype=x.dtype + tuple(x + 2 * pad_width for x in x.shape), + fill_value=value, + dtype=x.dtype, + device=_compat.device(x), ) padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x return padded diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 49b3b61..938a4f3 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -408,3 +408,11 @@ def test_mode_not_implemented(self): a = xp.arange(3) with pytest.raises(NotImplementedError, match="Only `'constant'`"): pad(a, 2, mode="edge") + + def test_device(self): + device = xp.Device("device1") + a = xp.asarray(0.0, device=device) + assert pad(a, 2).device == device + + def test_xp(self): + assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))