Skip to content

Commit

Permalink
Expose is_supported_dtype to the public interface (#150)
Browse files Browse the repository at this point in the history
Also take this opportunity to clean up a naming inconsistency;
NumPy types are "dtypes", core types are "types".
  • Loading branch information
manopapad authored Mar 21, 2024
1 parent 6dd4320 commit 0624001
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 19 deletions.
1 change: 1 addition & 0 deletions cunumeric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ._array.util import maybe_convert_to_np_ndarray
from ._module import *
from ._ufunc import *
from ._utils.array import is_supported_dtype
from ._utils.coverage import clone_module

clone_module(_np, globals(), maybe_convert_to_np_ndarray)
Expand Down
10 changes: 5 additions & 5 deletions cunumeric/_array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

from .. import _ufunc
from .._utils.array import calculate_volume, to_core_dtype
from .._utils.array import calculate_volume, to_core_type
from .._utils.coverage import FALLBACK_WARNING, clone_class, is_implemented
from .._utils.linalg import dot_modes
from .._utils.structure import deep_apply
Expand Down Expand Up @@ -128,7 +128,7 @@ def __init__(
for inp in inputs
if isinstance(inp, ndarray)
]
core_dtype = to_core_dtype(dtype)
core_dtype = to_core_type(dtype)
self._thunk = runtime.create_empty_thunk(
sanitized_shape, core_dtype, inputs
)
Expand Down Expand Up @@ -660,7 +660,7 @@ def __contains__(self, item: Any) -> ndarray:
args = (np.array(item, dtype=self.dtype),)
if args[0].size != 1:
raise ValueError("contains needs scalar item")
core_dtype = to_core_dtype(self.dtype)
core_dtype = to_core_type(self.dtype)
return perform_unary_reduction(
UnaryRedCode.CONTAINS,
self,
Expand Down Expand Up @@ -1975,7 +1975,7 @@ def clip(
return convert_to_cunumeric_ndarray(
self.__array__().clip(args[0], args[1])
)
core_dtype = to_core_dtype(self.dtype)
core_dtype = to_core_type(self.dtype)
extra_args = (Scalar(min, core_dtype), Scalar(max, core_dtype))
return perform_unary_op(
UnaryOpCode.CLIP, self, out=out, extra_args=extra_args
Expand Down Expand Up @@ -2971,7 +2971,7 @@ def var(
# FIXME(wonchanl): the following code blocks on mu to convert
# it to a Scalar object. We need to get rid of this blocking by
# allowing the extra arguments to be Legate stores
args=(Scalar(mu.__array__(), to_core_dtype(self.dtype)),),
args=(Scalar(mu.__array__(), to_core_type(self.dtype)),),
)
else:
# TODO(https://github.com/nv-legate/cunumeric/issues/591)
Expand Down
4 changes: 2 additions & 2 deletions cunumeric/_thunk/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
normalize_axis_tuple,
)

from .._utils.array import is_advanced_indexing, to_core_dtype
from .._utils.array import is_advanced_indexing, to_core_type
from ..config import (
BinaryOpCode,
BitGeneratorDistribution,
Expand Down Expand Up @@ -1701,7 +1701,7 @@ def select(
c_arr = c._broadcast(self.shape)
task.add_input(c_arr)
task.add_alignment(c_arr, out_arr)
task.add_scalar_arg(default, to_core_dtype(default.dtype))
task.add_scalar_arg(default, to_core_type(default.dtype))
task.execute()

# Create or extract a diagonal from a matrix
Expand Down
17 changes: 15 additions & 2 deletions cunumeric/_utils/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,24 @@
}


def is_supported_type(dtype: str | np.dtype[Any]) -> bool:
def is_supported_dtype(dtype: str | np.dtype[Any]) -> bool:
"""
Whether a NumPy dtype is supported by cuNumeric
Parameters
----------
dtype : data-type
The dtype to query
Returns
-------
res : bool
True if `dtype` is a supported dtype
"""
return np.dtype(dtype) in SUPPORTED_DTYPES


def to_core_dtype(dtype: str | np.dtype[Any]) -> ty.Type:
def to_core_type(dtype: str | np.dtype[Any]) -> ty.Type:
core_dtype = SUPPORTED_DTYPES.get(np.dtype(dtype))
if core_dtype is None:
raise TypeError(f"cuNumeric does not support dtype={dtype}")
Expand Down
8 changes: 4 additions & 4 deletions cunumeric/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from legate.core import LEGATE_MAX_DIM, Scalar, TaskTarget, get_legate_runtime
from legate.settings import settings as legate_settings

from ._utils.array import calculate_volume, is_supported_type, to_core_dtype
from ._utils.array import calculate_volume, is_supported_dtype, to_core_type
from ._utils.stack import find_last_user_stacklevel
from .config import (
BitGeneratorOperation,
Expand Down Expand Up @@ -60,7 +60,7 @@ def thunk_from_scalar(
from ._thunk.deferred import DeferredArray

store = legate_runtime.create_store_from_scalar(
Scalar(bytes, to_core_dtype(dtype)),
Scalar(bytes, to_core_type(dtype)),
shape=shape,
)
return DeferredArray(store)
Expand Down Expand Up @@ -377,7 +377,7 @@ def find_or_create_array_thunk(
from ._thunk.deferred import DeferredArray

assert isinstance(array, np.ndarray)
if not is_supported_type(array.dtype):
if not is_supported_dtype(array.dtype):
raise TypeError(f"cuNumeric does not support dtype={array.dtype}")

# We have to be really careful here to handle the case of
Expand Down Expand Up @@ -429,7 +429,7 @@ def find_or_create_array_thunk(
# This is not a scalar so make a field.
# We won't try to cache these bigger arrays.
store = legate_runtime.create_store_from_buffer(
to_core_dtype(array.dtype),
to_core_type(array.dtype),
array.shape,
array.copy() if transfer == TransferType.MAKE_COPY else array,
# This argument should really be called "donate"
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_structured_array_order(self):
# if self.deferred is None:
# if self.parent is None:
#
# > assert self.runtime.is_supported_type(self.array.dtype)
# > assert self.runtime.is_supported_dtype(self.array.dtype)
# E
# AssertionError
#
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_dtype_complex(self, dtype):
# allclose hits assertion error:
# File "/legate/cunumeric/cunumeric/eager.py", line 293,
# in to_deferred_array
# assert self.runtime.is_supported_type(self.array.dtype)
# assert self.runtime.is_supported_dtype(self.array.dtype)
# AssertionError
assert allclose(out_np, out_num)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_val_none(self):
# cuNumeric raises AssertionError
# if self.deferred is None:
# if self.parent is None:
# > assert self.runtime.is_supported_type
# > assert self.runtime.is_supported_dtype
# (self.array.dtype)
# E AssertionError
# cunumeric/cunumeric/eager.py:to_deferred_array()
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/cunumeric/test_utils_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ class Test_is_supported_dtype:
@pytest.mark.parametrize("value", ["foo", 10, 10.2, (), set()])
def test_type_bad(self, value) -> None:
with pytest.raises(TypeError):
m.to_core_dtype(value)
m.to_core_type(value)

@pytest.mark.parametrize("value", EXPECTED_SUPPORTED_DTYPES)
def test_supported(self, value) -> None:
m.to_core_dtype(value)
m.to_core_type(value)

# This is just a representative sample, not exhasutive
@pytest.mark.parametrize("value", [np.float128, np.datetime64, [], {}])
def test_unsupported(self, value) -> None:
with pytest.raises(TypeError):
m.to_core_dtype(value)
m.to_core_type(value)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 0624001

Please sign in to comment.