diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 1ffba901..fad41102 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -5,6 +5,8 @@ from ..common import _aliases from .._internal import get_xp +from ._info import __array_namespace_info__ + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union @@ -123,9 +125,10 @@ def asarray( else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', + 'acos', 'acosh', 'asin', 'asinh', 'atan', + 'atan2', 'atanh', 'bitwise_left_shift', + 'bitwise_invert', 'bitwise_right_shift', + 'concat', 'pow'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py new file mode 100644 index 00000000..4440807d --- /dev/null +++ b/array_api_compat/cupy/_info.py @@ -0,0 +1,326 @@ +""" +Array API Inspection namespace + +This is the namespace for inspection functions as defined by the array API +standard. See +https://data-apis.org/array-api/latest/API_specification/inspection.html for +more details. + +""" +from cupy import ( + dtype, + cuda, + bool_ as bool, + intp, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + +class __array_namespace_info__: + """ + Get the array API inspection namespace for CuPy. + + The array API inspection namespace defines the following functions: + + - capabilities() + - default_device() + - default_dtypes() + - dtypes() + - devices() + + See + https://data-apis.org/array-api/latest/API_specification/inspection.html + for more details. + + Returns + ------- + info : ModuleType + The array API inspection namespace for CuPy. + + Examples + -------- + >>> info = np.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': cupy.float64, + 'complex floating': cupy.complex128, + 'integral': cupy.int64, + 'indexing': cupy.int64} + + """ + + __module__ = 'cupy' + + def capabilities(self): + """ + Return a dictionary of array API library capabilities. + + The resulting dictionary has the following keys: + + - **"boolean indexing"**: boolean indicating whether an array library + supports boolean indexing. Always ``True`` for CuPy. + + - **"data-dependent shapes"**: boolean indicating whether an array + library supports data-dependent output shapes. Always ``True`` for + CuPy. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html + for more details. + + See Also + -------- + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + capabilities : dict + A dictionary of array API library capabilities. + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.capabilities() + {'boolean indexing': True, + 'data-dependent shapes': True} + + """ + return { + "boolean indexing": True, + "data-dependent shapes": True, + # 'max rank' will be part of the 2024.12 standard + # "max rank": 64, + } + + def default_device(self): + """ + The default device used for new CuPy arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Returns + ------- + device : str + The default device used for new CuPy arrays. + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.default_device() + Device(0) + + """ + return cuda.Device(0) + + def default_dtypes(self, *, device=None): + """ + The default data types used for new CuPy arrays. + + For CuPy, this always returns the following dictionary: + + - **"real floating"**: ``cupy.float64`` + - **"complex floating"**: ``cupy.complex128`` + - **"integral"**: ``cupy.intp`` + - **"indexing"**: ``cupy.intp`` + + Parameters + ---------- + device : str, optional + The device to get the default data types for. + + Returns + ------- + dtypes : dict + A dictionary describing the default data types used for new CuPy + arrays. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.default_dtypes() + {'real floating': cupy.float64, + 'complex floating': cupy.complex128, + 'integral': cupy.int64, + 'indexing': cupy.int64} + + """ + # TODO: Does this depend on device? + return { + "real floating": dtype(float64), + "complex floating": dtype(complex128), + "integral": dtype(intp), + "indexing": dtype(intp), + } + + def dtypes(self, *, device=None, kind=None): + """ + The array API data types supported by CuPy. + + Note that this function only returns data types that are defined by + the array API. + + Parameters + ---------- + device : str, optional + The device to get the data types for. + kind : str or tuple of str, optional + The kind of data types to return. If ``None``, all data types are + returned. If a string, only data types of that kind are returned. + If a tuple, a dictionary containing the union of the given kinds + is returned. The following kinds are supported: + + - ``'bool'``: boolean data types (i.e., ``bool``). + - ``'signed integer'``: signed integer data types (i.e., ``int8``, + ``int16``, ``int32``, ``int64``). + - ``'unsigned integer'``: unsigned integer data types (i.e., + ``uint8``, ``uint16``, ``uint32``, ``uint64``). + - ``'integral'``: integer data types. Shorthand for ``('signed + integer', 'unsigned integer')``. + - ``'real floating'``: real-valued floating-point data types + (i.e., ``float32``, ``float64``). + - ``'complex floating'``: complex floating-point data types (i.e., + ``complex64``, ``complex128``). + - ``'numeric'``: numeric data types. Shorthand for ``('integral', + 'real floating', 'complex floating')``. + + Returns + ------- + dtypes : dict + A dictionary mapping the names of data types to the corresponding + CuPy data types. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.devices + + Examples + -------- + >>> info = xp.__array_namespace_info__() + >>> info.dtypes(kind='signed integer') + {'int8': cupy.int8, + 'int16': cupy.int16, + 'int32': cupy.int32, + 'int64': cupy.int64} + + """ + # TODO: Does this depend on device? + if kind is None: + return { + "bool": dtype(bool), + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + } + if kind == "unsigned integer": + return { + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "integral": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + } + if kind == "real floating": + return { + "float32": dtype(float32), + "float64": dtype(float64), + } + if kind == "complex floating": + return { + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if kind == "numeric": + return { + "int8": dtype(int8), + "int16": dtype(int16), + "int32": dtype(int32), + "int64": dtype(int64), + "uint8": dtype(uint8), + "uint16": dtype(uint16), + "uint32": dtype(uint32), + "uint64": dtype(uint64), + "float32": dtype(float32), + "float64": dtype(float64), + "complex64": dtype(complex64), + "complex128": dtype(complex128), + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(self.dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + def devices(self): + """ + The devices supported by CuPy. + + Returns + ------- + devices : list of str + The devices supported by CuPy. + + See Also + -------- + __array_namespace_info__.capabilities, + __array_namespace_info__.default_device, + __array_namespace_info__.default_dtypes, + __array_namespace_info__.dtypes + + """ + return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]