Skip to content

Commit

Permalink
remove delegation for now
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Dec 26, 2024
1 parent 96bfbbf commit 5aac0b4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 40 deletions.
13 changes: 11 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
"""Extra array functions built on top of the array API standard."""

from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad
from ._funcs import (
atleast_nd,
cov,
create_diagonal,
expand_dims,
kron,
pad,
setdiff1d,
sinc,
)

__version__ = "0.4.1.dev0"

Expand All @@ -12,7 +21,7 @@
"create_diagonal",
"expand_dims",
"kron",
"pad",
"setdiff1d",
"sinc",
"pad",
]
54 changes: 24 additions & 30 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._lib import _compat, _utils
from ._lib._compat import (
array_namespace, is_torch_namespace, is_array_api_strict_namespace
array_namespace,
)
from ._lib._typing import Array, ModuleType

Expand All @@ -14,9 +14,9 @@
"create_diagonal",
"expand_dims",
"kron",
"pad",
"setdiff1d",
"sinc",
"pad",
]


Expand Down Expand Up @@ -543,52 +543,46 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.sin(y) / y


def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs):
def pad(
x: Array,
pad_width: int,
mode: str = "constant",
*,
xp: ModuleType | None = None,
constant_values: bool | int | float | complex = 0,
) -> Array:
"""
Pad the input array.
Parameters
----------
x : array
Input array
pad_width: int
Pad the input array with this many elements from each side
mode: str, optional
Input array.
pad_width : int
Pad the input array with this many elements from each side.
mode : str, optional
Only "constant" mode is currently supported.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
constant_values: python scalar, optional
constant_values : python scalar, optional
Use this value to pad the input. Default is zero.
Returns
-------
array
The input array, padded with ``pad_width`` elements equal to ``constant_values``
The input array,
padded with ``pad_width`` elements equal to ``constant_values``.
"""
# xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
# http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045

if mode != 'constant':
if mode != "constant":
raise NotImplementedError()

Check warning on line 577 in src/array_api_extra/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_funcs.py#L577

Added line #L577 was not covered by tests

value = kwargs.get("constant_values", 0)
if kwargs and list(kwargs.keys()) != ['constant_values']:
raise ValueError(f"Unknown kwargs: {kwargs}")
value = constant_values

if xp is None:
xp = array_namespace(x)

if is_array_api_strict_namespace(xp):
padded = xp.full(
tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype
)
padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x
return padded
elif is_torch_namespace(xp):
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
return xp.nn.functional.pad(x, tuple(pad_width), value=value)

else:
return xp.pad(x, pad_width, mode=mode, **kwargs)
padded = xp.full(
tuple(x + 2 * pad_width for x in x.shape), fill_value=value, dtype=x.dtype
)
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
return padded
2 changes: 0 additions & 2 deletions src/array_api_extra/_lib/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
array_namespace, # pyright: ignore[reportUnknownVariableType]
device,
is_torch_namespace,
is_array_api_strict_namespace,
)

__all__ = [
Expand Down
8 changes: 2 additions & 6 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
create_diagonal,
expand_dims,
kron,
pad,
setdiff1d,
sinc,
pad,
)
from array_api_extra._lib._typing import Array

Expand Down Expand Up @@ -400,10 +400,6 @@ def test_fill_value(self):
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))

def test_ndim(self):
a = xp.reshape(xp.arange(2*3*4), (2, 3, 4))
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
padded = pad(a, 2)
assert padded.shape == (6, 7, 8)

def test_typo(self):
with pytest.raises(ValueError, match="Unknown"):
pad(xp.arange(2), pad_width=3, oops=3)

0 comments on commit 5aac0b4

Please sign in to comment.