Skip to content

Commit

Permalink
add is_*_namespace helper functions
Browse files Browse the repository at this point in the history
closes gh-156
  • Loading branch information
lucascolley committed Aug 15, 2024
1 parent f3145b0 commit 733d17c
Showing 1 changed file with 168 additions and 1 deletion.
169 changes: 168 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -255,6 +254,166 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers')

def is_numpy_namespace(xp) -> bool:
"""
Returns True if `xp` is a NumPy namespace.
This includes both NumPy itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'numpy', _compat_module_name + '.numpy'}

def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
This includes both CuPy itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'cupy', _compat_module_name + '.cupy'}

def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
This includes both PyTorch itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'torch', _compat_module_name + '.torch'}


def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):
"""
Returns True if `xp` is a Dask namespace.
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_jax_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'dask.array', _compat_module_name + '.dask.array'}

def is_jax_namespace(xp):
"""
Returns True if `xp` is a JAX namespace.
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
older versions of JAX.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}

def is_pydata_sparse_namespace(xp):
"""
Returns True if `xp` is a pydata/sparse namespace.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_array_api_strict_namespace
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):
"""
Returns True if `xp` is an array-api-strict namespace.
See Also
--------
array_namespace
is_numpy_namespace
is_cupy_namespace
is_torch_namespace
is_ndonnx_namespace
is_dask_namespace
is_jax_namespace
is_pydata_sparse_namespace
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
Expand Down Expand Up @@ -643,13 +802,21 @@ def size(x):
"device",
"get_namespace",
"is_array_api_obj",
"is_array_api_strict_namespace",
"is_cupy_array",
"is_cupy_namespace",
"is_dask_array",
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_numpy_array",
"is_numpy_namespace",
"is_torch_array",
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
"to_device",
]
Expand Down

0 comments on commit 733d17c

Please sign in to comment.