diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index a46c7a8..7924a85 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -67,6 +67,8 @@ def asarray( if dtype is not None: _np_dtype = dtype._np_dtype _check_device(device) + if isinstance(obj, Array) and device is None: + device = obj.device if np.__version__[0] < '2': if copy is False: @@ -158,6 +160,8 @@ def empty_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype @@ -260,6 +264,8 @@ def full_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 819afad..71fd76b 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -23,7 +23,7 @@ zeros_like, ) from .._dtypes import float32, float64 -from .._array_object import Array, CPU_DEVICE +from .._array_object import Array, CPU_DEVICE, Device from .._flags import set_array_api_strict_flags def test_asarray_errors(): @@ -97,6 +97,17 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) + +def test_asarray_device_inference(): + assert asarray([1, 2, 3]).device == CPU_DEVICE + + x = asarray([1, 2, 3]) + assert asarray(x).device == CPU_DEVICE + + device1 = Device("device1") + x = asarray([1, 2, 3], device=device1) + assert asarray(x).device == device1 + def test_arange_errors(): arange(1, device=CPU_DEVICE) # Doesn't error assert_raises(ValueError, lambda: arange(1, device="cpu"))