Skip to content

Commit

Permalink
TST: test is_*_namespace fns
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley committed Aug 17, 2024
1 parent 733d17c commit bdcd14b
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 18 deletions.
8 changes: 4 additions & 4 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def is_numpy_namespace(xp) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'numpy', _compat_module_name + '.numpy'}
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}

def is_cupy_namespace(xp) -> bool:
"""
Expand All @@ -296,7 +296,7 @@ def is_cupy_namespace(xp) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'cupy', _compat_module_name + '.cupy'}
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}

def is_torch_namespace(xp) -> bool:
"""
Expand All @@ -316,7 +316,7 @@ def is_torch_namespace(xp) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'torch', _compat_module_name + '.torch'}
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
Expand Down Expand Up @@ -355,7 +355,7 @@ def is_dask_namespace(xp):
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'dask.array', _compat_module_name + '.dask.array'}
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}

def is_jax_namespace(xp):
"""
Expand Down
44 changes: 34 additions & 10 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
is_dask_array, is_jax_array, is_pydata_sparse_array)
from array_api_compat import ( # noqa: F401
is_numpy_array, is_cupy_array, is_torch_array,
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
)

from array_api_compat import is_array_api_obj, device, to_device

Expand All @@ -10,7 +14,7 @@
import array
from numpy.testing import assert_allclose

is_functions = {
is_array_functions = {
'numpy': 'is_numpy_array',
'cupy': 'is_cupy_array',
'torch': 'is_torch_array',
Expand All @@ -19,18 +23,38 @@
'sparse': 'is_pydata_sparse_array',
}

@pytest.mark.parametrize('library', is_functions.keys())
@pytest.mark.parametrize('func', is_functions.values())
is_namespace_functions = {
'numpy': 'is_numpy_namespace',
'cupy': 'is_cupy_namespace',
'torch': 'is_torch_namespace',
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
}


@pytest.mark.parametrize('library', is_array_functions.keys())
@pytest.mark.parametrize('func', is_array_functions.values())
def test_is_xp_array(library, func):
lib = import_(library)
is_func = globals()[func]

x = lib.asarray([1, 2, 3])

assert is_func(x) == (func == is_functions[library])
assert is_func(x) == (func == is_array_functions[library])

assert is_array_api_obj(x)


@pytest.mark.parametrize('library', is_namespace_functions.keys())
@pytest.mark.parametrize('func', is_namespace_functions.values())
def test_is_xp_namespace(library, func):
lib = import_(library)
is_func = globals()[func]

assert is_func(lib) == (func == is_namespace_functions[library])


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down Expand Up @@ -64,8 +88,8 @@ def test_to_device_host(library):
assert_allclose(x, expected)


@pytest.mark.parametrize("target_library", is_functions.keys())
@pytest.mark.parametrize("source_library", is_functions.keys())
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
# Allow rest of test to execute instead of immediately xfailing
Expand All @@ -81,7 +105,7 @@ def test_asarray_cross_library(source_library, target_library, request):
pytest.skip(reason="`sparse` does not allow implicit densification")
src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[is_functions[target_library]]
is_tgt_type = globals()[is_array_functions[target_library]]

a = src_lib.asarray([1, 2, 3])
b = tgt_lib.asarray(a)
Expand All @@ -96,7 +120,7 @@ def test_asarray_copy(library):
# should be able to delete this.
xp = import_(library, wrapper=True)
asarray = xp.asarray
is_lib_func = globals()[is_functions[library]]
is_lib_func = globals()[is_array_functions[library]]
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()

if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
Expand Down
1 change: 1 addition & 0 deletions tests/test_vendoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_vendoring_torch():

uses_torch._test_torch()


def test_vendoring_dask():
from vendor_test import uses_dask
uses_dask._test_dask()
9 changes: 8 additions & 1 deletion vendor_test/uses_cupy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Basic test that vendoring works

from .vendored._compat import cupy as cp_compat
from .vendored._compat import (
cupy as cp_compat,
is_cupy_array,
is_cupy_namespace,
)

import cupy as cp

Expand All @@ -16,3 +20,6 @@ def _test_cupy():
assert isinstance(res, cp.ndarray)

cp.testing.assert_allclose(res, [1., 2., 9.])

assert is_cupy_array(res)
assert is_cupy_namespace(cp) and is_cupy_namespace(cp_compat)
9 changes: 8 additions & 1 deletion vendor_test/uses_dask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Basic test that vendoring works

from .vendored._compat.dask import array as dask_compat
from .vendored._compat import (
array as dask_compat,
is_dask_array,
is_dask_namespace,
)

import dask.array as da
import numpy as np
Expand All @@ -17,3 +21,6 @@ def _test_dask():
assert isinstance(res, da.Array)

np.testing.assert_allclose(res, [1., 2., 9.])

assert is_dask_array(res)
assert is_dask_namespace(da) and is_dask_namespace(dask_compat)
10 changes: 9 additions & 1 deletion vendor_test/uses_numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Basic test that vendoring works

from .vendored._compat import numpy as np_compat
from .vendored._compat import (
is_numpy_array,
is_numpy_namespace,
numpy as np_compat,
)


import numpy as np

Expand All @@ -16,3 +21,6 @@ def _test_numpy():
assert isinstance(res, np.ndarray)

np.testing.assert_allclose(res, [1., 2., 9.])

assert is_numpy_array(res)
assert is_numpy_namespace(np) and is_numpy_namespace(np_compat)
10 changes: 9 additions & 1 deletion vendor_test/uses_torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Basic test that vendoring works

from .vendored._compat import torch as torch_compat
from .vendored._compat import (
is_torch_array,
is_torch_namespace,
torch as torch_compat,
)

import torch

Expand All @@ -20,3 +24,7 @@ def _test_torch():
assert isinstance(res, torch.Tensor)

torch.testing.assert_allclose(res, [[1., 2., 3.]])

assert is_torch_array(res)
assert is_torch_namespace(torch) and is_torch_namespace(torch_compat)

0 comments on commit bdcd14b

Please sign in to comment.