diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index 6086dcf..1e81c98 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -23,8 +23,8 @@ def is_cupy_namespace(xp: ModuleType, /) -> bool: ... def is_dask_namespace(xp: ModuleType, /) -> bool: ... def is_jax_namespace(xp: ModuleType, /) -> bool: ... def is_numpy_namespace(xp: ModuleType, /) -> bool: ... +def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... def is_torch_namespace(xp: ModuleType, /) -> bool: ... def is_jax_array(x: object, /) -> bool: ... -def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... def is_writeable_array(x: object, /) -> bool: ... def size(x: Array, /) -> int | None: ... diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index d94b73f..3824937 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -7,6 +7,7 @@ def test_vendor_compat(): array_namespace, device, is_cupy_namespace, + is_dask_namespace, is_jax_array, is_jax_namespace, is_pydata_sparse_namespace, @@ -18,12 +19,13 @@ def test_vendor_compat(): x = xp.asarray([1, 2, 3]) assert array_namespace(x) is xp device(x) - assert not is_jax_array(x) - assert is_writeable_array(x) assert not is_cupy_namespace(xp) + assert not is_dask_namespace(xp) + assert not is_jax_array(x) assert not is_jax_namespace(xp) assert not is_pydata_sparse_namespace(xp) assert not is_torch_namespace(xp) + assert is_writeable_array(x) assert size(x) == 3