From fcf6b0a395d9013a7e327c27e79ef5ce437906e6 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 15 Aug 2024 16:35:39 +0000 Subject: [PATCH] add `is_*_namespace` helper functions closes gh-156 --- array_api_compat/common/_helpers.py | 156 +++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 1 deletion(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 57536f62..06f940b9 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -202,7 +202,6 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) - def is_pydata_sparse_array(x) -> bool: """ Return True if `x` is an array from the `sparse` package. @@ -255,6 +254,161 @@ def is_array_api_obj(x): or is_pydata_sparse_array(x) \ or hasattr(x, '__array_namespace__') +def is_numpy_namespace(xp) -> bool: + """ + Returns True if `xp` is a NumPy namespace. + + This includes both NumPy itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'numpy' or 'array_api_compat.numpy' in xp.__name__ + +def is_cupy_namespace(xp) -> bool: + """ + Returns True if `xp` is a CuPy namespace. + + This includes both CuPy itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'cupy' or 'array_api_compat.cupy' in xp.__name__ + +def is_torch_namespace(xp) -> bool: + """ + Returns True if `xp` is a PyTorch namespace. + + This includes both PyTorch itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'torch' or 'array_api_compat.torch' in xp.__name__ + +def is_ndonnx_namespace(xp): + """ + Returns True if `xp` is an NDONNX namespace. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'ndonnx' + +def is_dask_namespace(xp): + """ + Returns True if `xp` is a Dask namespace. + + This includes both ``dask.array`` itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'dask.array' or 'array_api_compat.dask.array' in xp.__name__ + +def is_jax_namespace(xp): + """ + Returns True if `xp` is a JAX namespace. + + This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in + older versions of JAX. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api') + +def is_pydata_sparse_namespace(xp): + """ + Returns True if `xp` is a pydata/sparse namespace. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'sparse' + +def is_array_api_strict_namespace(xp): + """ + Returns True if `xp` is an array-api-strict namespace. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + """ + return xp.__name__ == 'array_api_strict' + def _check_api_version(api_version): if api_version == '2021.12': warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")