Skip to content

Commit

Permalink
Add CuPy inspection APIs
Browse files Browse the repository at this point in the history
I'm not sure if all the details here are correct. See
#127 (comment).
  • Loading branch information
asmeurer committed Aug 19, 2024
1 parent 11cb6ef commit 4c9dd0e
Show file tree
Hide file tree
Showing 2 changed files with 333 additions and 4 deletions.
11 changes: 7 additions & 4 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ..common import _aliases
from .._internal import get_xp

from ._info import __array_namespace_info__

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
Expand Down Expand Up @@ -123,9 +125,10 @@ def asarray(
else:
unstack = get_xp(cp)(_aliases.unstack)

__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'concat', 'pow']

_all_ignore = ['cp', 'get_xp']
326 changes: 326 additions & 0 deletions array_api_compat/cupy/_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
"""
Array API Inspection namespace
This is the namespace for inspection functions as defined by the array API
standard. See
https://data-apis.org/array-api/latest/API_specification/inspection.html for
more details.
"""
from cupy import (
dtype,
cuda,
bool_ as bool,
intp,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
complex64,
complex128,
)

class __array_namespace_info__:
"""
Get the array API inspection namespace for CuPy.
The array API inspection namespace defines the following functions:
- capabilities()
- default_device()
- default_dtypes()
- dtypes()
- devices()
See
https://data-apis.org/array-api/latest/API_specification/inspection.html
for more details.
Returns
-------
info : ModuleType
The array API inspection namespace for CuPy.
Examples
--------
>>> info = np.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
'integral': cupy.int64,
'indexing': cupy.int64}
"""

__module__ = 'cupy'

def capabilities(self):
"""
Return a dictionary of array API library capabilities.
The resulting dictionary has the following keys:
- **"boolean indexing"**: boolean indicating whether an array library
supports boolean indexing. Always ``True`` for CuPy.
- **"data-dependent shapes"**: boolean indicating whether an array
library supports data-dependent output shapes. Always ``True`` for
CuPy.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html
for more details.
See Also
--------
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
capabilities : dict
A dictionary of array API library capabilities.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.capabilities()
{'boolean indexing': True,
'data-dependent shapes': True}
"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
# 'max rank' will be part of the 2024.12 standard
# "max rank": 64,
}

def default_device(self):
"""
The default device used for new CuPy arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Returns
-------
device : str
The default device used for new CuPy arrays.
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_device()
Device(0)
"""
return cuda.Device(0)

def default_dtypes(self, *, device=None):
"""
The default data types used for new CuPy arrays.
For CuPy, this always returns the following dictionary:
- **"real floating"**: ``cupy.float64``
- **"complex floating"**: ``cupy.complex128``
- **"integral"**: ``cupy.intp``
- **"indexing"**: ``cupy.intp``
Parameters
----------
device : str, optional
The device to get the default data types for.
Returns
-------
dtypes : dict
A dictionary describing the default data types used for new CuPy
arrays.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.default_dtypes()
{'real floating': cupy.float64,
'complex floating': cupy.complex128,
'integral': cupy.int64,
'indexing': cupy.int64}
"""
# TODO: Does this depend on device?
return {
"real floating": dtype(float64),
"complex floating": dtype(complex128),
"integral": dtype(intp),
"indexing": dtype(intp),
}

def dtypes(self, *, device=None, kind=None):
"""
The array API data types supported by CuPy.
Note that this function only returns data types that are defined by
the array API.
Parameters
----------
device : str, optional
The device to get the data types for.
kind : str or tuple of str, optional
The kind of data types to return. If ``None``, all data types are
returned. If a string, only data types of that kind are returned.
If a tuple, a dictionary containing the union of the given kinds
is returned. The following kinds are supported:
- ``'bool'``: boolean data types (i.e., ``bool``).
- ``'signed integer'``: signed integer data types (i.e., ``int8``,
``int16``, ``int32``, ``int64``).
- ``'unsigned integer'``: unsigned integer data types (i.e.,
``uint8``, ``uint16``, ``uint32``, ``uint64``).
- ``'integral'``: integer data types. Shorthand for ``('signed
integer', 'unsigned integer')``.
- ``'real floating'``: real-valued floating-point data types
(i.e., ``float32``, ``float64``).
- ``'complex floating'``: complex floating-point data types (i.e.,
``complex64``, ``complex128``).
- ``'numeric'``: numeric data types. Shorthand for ``('integral',
'real floating', 'complex floating')``.
Returns
-------
dtypes : dict
A dictionary mapping the names of data types to the corresponding
CuPy data types.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.devices
Examples
--------
>>> info = xp.__array_namespace_info__()
>>> info.dtypes(kind='signed integer')
{'int8': cupy.int8,
'int16': cupy.int16,
'int32': cupy.int32,
'int64': cupy.int64}
"""
# TODO: Does this depend on device?
if kind is None:
return {
"bool": dtype(bool),
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "bool":
return {"bool": bool}
if kind == "signed integer":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
}
if kind == "unsigned integer":
return {
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "integral":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
}
if kind == "real floating":
return {
"float32": dtype(float32),
"float64": dtype(float64),
}
if kind == "complex floating":
return {
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if kind == "numeric":
return {
"int8": dtype(int8),
"int16": dtype(int16),
"int32": dtype(int32),
"int64": dtype(int64),
"uint8": dtype(uint8),
"uint16": dtype(uint16),
"uint32": dtype(uint32),
"uint64": dtype(uint64),
"float32": dtype(float32),
"float64": dtype(float64),
"complex64": dtype(complex64),
"complex128": dtype(complex128),
}
if isinstance(kind, tuple):
res = {}
for k in kind:
res.update(self.dtypes(kind=k))
return res
raise ValueError(f"unsupported kind: {kind!r}")

def devices(self):
"""
The devices supported by CuPy.
Returns
-------
devices : list of str
The devices supported by CuPy.
See Also
--------
__array_namespace_info__.capabilities,
__array_namespace_info__.default_device,
__array_namespace_info__.default_dtypes,
__array_namespace_info__.dtypes
"""
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]

0 comments on commit 4c9dd0e

Please sign in to comment.