Skip to content

Commit

Permalink
add is_*_namespace helper functions
Browse files Browse the repository at this point in the history
closes gh-156
  • Loading branch information
lucascolley committed Aug 15, 2024
1 parent f3145b0 commit 67f3c9f
Showing 1 changed file with 163 additions and 1 deletion.
164 changes: 163 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -255,6 +254,161 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def is_numpy_namespace(xp) -> bool:
"""
Returns True if `xp` is a NumPy namespace.
This includes both NumPy itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'numpy' or 'array_api_compat.numpy' in xp.__name__

def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
This includes both CuPy itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'cupy' or 'array_api_compat.cupy' in xp.__name__

def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
This includes both PyTorch itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'torch' or 'array_api_compat.torch' in xp.__name__

def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):
"""
Returns True if `xp` is a Dask namespace.
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'dask.array' or 'array_api_compat.dask.array' in xp.__name__

def is_jax_namespace(xp):
"""
Returns True if `xp` is a JAX namespace.
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
older versions of JAX.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api')

def is_pydata_sparse_namespace(xp):
"""
Returns True if `xp` is a pydata/sparse namespace.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):
"""
Returns True if `xp` is an array-api-strict namespace.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
Expand Down Expand Up @@ -643,13 +797,21 @@ def size(x):
"device",
"get_namespace",
"is_array_api_obj",
"is_array_api_strict_namespace",
"is_cupy_array",
"is_cupy_namespace",
"is_dask_array",
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_numpy_array",
"is_numpy_namespace",
"is_torch_array",
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
"to_device",
]
Expand Down

0 comments on commit 67f3c9f

Please sign in to comment.