Skip to content

Commit

Permalink
Propagate input array's device in asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
betatim committed Oct 21, 2024
1 parent b625bbe commit 349c4ff
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
6 changes: 6 additions & 0 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def asarray(
if dtype is not None:
_np_dtype = dtype._np_dtype
_check_device(device)
if isinstance(obj, Array) and device is None:
device = obj.device

if np.__version__[0] < '2':
if copy is False:
Expand Down Expand Up @@ -158,6 +160,8 @@ def empty_like(

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device

if dtype is not None:
dtype = dtype._np_dtype
Expand Down Expand Up @@ -260,6 +264,8 @@ def full_like(

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device

if dtype is not None:
dtype = dtype._np_dtype
Expand Down
13 changes: 12 additions & 1 deletion array_api_strict/tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
zeros_like,
)
from .._dtypes import float32, float64
from .._array_object import Array, CPU_DEVICE
from .._array_object import Array, CPU_DEVICE, Device
from .._flags import set_array_api_strict_flags

def test_asarray_errors():
Expand Down Expand Up @@ -97,6 +97,17 @@ def test_asarray_copy():
a[0] = 0
assert all(b[0] == 0)


def test_asarray_device_inference():
assert asarray([1, 2, 3]).device == CPU_DEVICE

x = asarray([1, 2, 3])
assert asarray(x).device == CPU_DEVICE

device1 = Device("device1")
x = asarray([1, 2, 3], device=device1)
assert asarray(x).device == device1

def test_arange_errors():
arange(1, device=CPU_DEVICE) # Doesn't error
assert_raises(ValueError, lambda: arange(1, device="cpu"))
Expand Down

0 comments on commit 349c4ff

Please sign in to comment.