diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index a0c8d7f..02fbdd1 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -776,6 +776,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x. import jax.experimental.array_api # noqa: F401 + return x.to_device(device, stream=stream) elif is_pydata_sparse_array(x) and device == _device(x): # Perform trivial check to return the same array if # device is same instead of err-ing.