Skip to content

Commit

Permalink
Run more tests on array-api-strict and sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 6, 2025
1 parent beac55b commit b87e0aa
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
8 changes: 4 additions & 4 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pytest

wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["jax.numpy"]
all_libraries = wrapped_libraries + [
"array_api_strict", "jax.numpy", "sparse"
]

# `sparse` added array API support as of Python 3.10.
if sys.version_info >= (3, 10):
Expand All @@ -20,9 +22,7 @@ def import_(library, wrapper=False):
jax_numpy = import_module("jax.numpy")
if not hasattr(jax_numpy, "__array_api_version__"):
library = 'jax.experimental.array_api'
elif library.startswith('sparse'):
library = 'sparse'
else:
elif library in wrapped_libraries:
library = 'array_api_compat.' + library

return import_module(library)
5 changes: 4 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
def test_all(library):
import_(library, wrapper=True)
if library == "common":
import array_api_compat.common # noqa: F401
else:
import_(library, wrapper=True)

for mod_name in sys.modules:
if not mod_name.startswith('array_api_compat.' + library):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
is_array_api_strict_namespace,
)

from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
Expand Down Expand Up @@ -30,6 +31,7 @@
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
'array_api_strict': 'is_array_api_strict_namespace',
}


Expand Down Expand Up @@ -71,7 +73,12 @@ def test_xp_is_array_generics(library):
is_func = globals()[func]
if is_func(x0):
matches.append(library2)
assert matches in ([library], ["numpy"])

if library == "array_api_strict":
# There is no is_array_api_strict_array() function
assert matches == []
else:
assert matches in ([library], ["numpy"])


@pytest.mark.parametrize("library", all_libraries)
Expand Down

0 comments on commit b87e0aa

Please sign in to comment.