Skip to content

Commit

Permalink
Merge pull request #231 from crusaderky/test_size
Browse files Browse the repository at this point in the history
ENH: size() to return None on dask instead of nan
  • Loading branch information
ev-br authored Jan 9, 2025
2 parents beac55b + d947529 commit e5dd419
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
11 changes: 8 additions & 3 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,19 +788,24 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x.to_device(device, stream=stream)


def size(x: Array) -> int | None:
    """
    Return the total number of elements of x.

    This is equivalent to `x.size` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.

    This helper is included because PyTorch defines `size` in an
    :external+torch:meth:`incompatible way <torch.Tensor.size>`.
    It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
    the standard requires None.

    Only ``x.shape`` is read. The result is the product of all dimensions
    (``1`` for a zero-dimensional shape), or ``None`` as soon as any dimension
    is unknown, i.e. reported as either ``None`` or NaN.
    """
    # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
    if None in x.shape:
        return None
    out = math.prod(x.shape)
    # dask.array.Array.shape can contain NaN; a single NaN dimension makes the
    # whole product NaN, so checking the product covers every axis at once.
    return None if math.isnan(out) else out


def is_writeable_array(x) -> bool:
Expand Down
27 changes: 25 additions & 2 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device

from array_api_compat import (
device, is_array_api_obj, is_writeable_array, size, to_device
)
from ._helpers import import_, wrapped_libraries, all_libraries

import pytest
Expand Down Expand Up @@ -92,6 +93,28 @@ def test_is_writeable_array_numpy():
assert not is_writeable_array(x)


@pytest.mark.parametrize("library", all_libraries)
def test_size(library):
    """size() reports the element count of a small 1-D array for each library."""
    namespace = import_(library)
    three_items = namespace.asarray([1, 2, 3])
    n_elements = size(three_items)
    assert n_elements == 3


@pytest.mark.parametrize("library", all_libraries)
def test_size_none(library):
    """size() yields None (lazy backends) or 5 (eager backends) after masking."""
    if library == "sparse":
        pytest.skip("No arange(); no indexing by sparse arrays")

    namespace = import_(library)
    values = namespace.arange(10)
    # Boolean-mask selection: the resulting length depends on the data.
    masked = values[values < 5]

    # dask.array now has shape=(nan, ) and size=nan
    # ndonnx now has shape=(None, ) and size=None
    # Eager libraries have shape=(5, ) and size=5
    n_elements = size(masked)
    assert n_elements in (None, 5)


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down

0 comments on commit e5dd419

Please sign in to comment.