diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 57536f62..44ef3522 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -202,7 +202,6 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) - def is_pydata_sparse_array(x) -> bool: """ Return True if `x` is an array from the `sparse` package. @@ -255,6 +254,161 @@ def is_array_api_obj(x): or is_pydata_sparse_array(x) \ or hasattr(x, '__array_namespace__') +def is_numpy_namespace(xp) -> bool: + """ + Returns True if `xp` is a NumPy namespace. + + This includes both NumPy itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'numpy' or 'array_api_compat.numpy' in xp.__name__ + +def is_cupy_namespace(xp) -> bool: + """ + Returns True if `xp` is a CuPy namespace. + + This includes both CuPy itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'cupy' or 'array_api_compat.cupy' in xp.__name__ + +def is_torch_namespace(xp) -> bool: + """ + Returns True if `xp` is a PyTorch namespace. + + This includes both PyTorch itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'torch' or 'array_api_compat.torch' in xp.__name__ + +def is_ndonnx_namespace(xp): + """ + Returns True if `xp` is an NDONNX namespace. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'ndonnx' + +def is_dask_namespace(xp): + """ + Returns True if `xp` is a Dask namespace. + + This includes both ``dask.array`` itself and the version wrapped by array-api-compat. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_jax_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'dask.array' or 'array_api_compat.dask.array' in xp.__name__ + +def is_jax_namespace(xp): + """ + Returns True if `xp` is a JAX namespace. + + This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in + older versions of JAX. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_pydata_sparse_namespace + is_array_api_strict_namespace + """ + return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api') + +def is_pydata_sparse_namespace(xp): + """ + Returns True if `xp` is a pydata/sparse namespace. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_array_api_strict_namespace + """ + return xp.__name__ == 'sparse' + +def is_array_api_strict_namespace(xp): + """ + Returns True if `xp` is an array-api-strict namespace. + + See Also + -------- + + array_namespace + is_numpy_namespace + is_cupy_namespace + is_torch_namespace + is_ndonnx_namespace + is_dask_namespace + is_jax_namespace + is_pydata_sparse_namespace + """ + return xp.__name__ == 'array_api_strict' + def _check_api_version(api_version): if api_version == '2021.12': warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") @@ -643,13 +797,21 @@ def size(x): "device", "get_namespace", "is_array_api_obj", + "is_array_api_strict_namespace", "is_cupy_array", + "is_cupy_namespace", "is_dask_array", + "is_dask_namespace", "is_jax_array", + "is_jax_namespace", "is_numpy_array", + "is_numpy_namespace", "is_torch_array", + "is_torch_namespace", "is_ndonnx_array", + "is_ndonnx_namespace", "is_pydata_sparse_array", + "is_pydata_sparse_namespace", "size", "to_device", ]