From 8ee16138094d44a65be73af4b88a5b3cc74d2203 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 18 Dec 2023 19:40:29 -0800 Subject: [PATCH 01/22] Add dask to array-api-compat --- .github/workflows/array-api-tests-dask.yml | 9 +++ array_api_compat/common/_aliases.py | 6 +- array_api_compat/common/_helpers.py | 24 ++++++ array_api_compat/dask/__init__.py | 6 ++ array_api_compat/dask/_aliases.py | 88 ++++++++++++++++++++++ dask-skips.txt | 0 dask-xfails.txt | 47 ++++++++++++ tests/test_array_namespace.py | 5 +- tests/test_common.py | 2 +- tests/test_isdtype.py | 4 +- tests/test_vendoring.py | 4 + vendor_test/uses_dask.py | 19 +++++ 12 files changed, 208 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/array-api-tests-dask.yml create mode 100644 array_api_compat/dask/__init__.py create mode 100644 array_api_compat/dask/_aliases.py create mode 100644 dask-skips.txt create mode 100644 dask-xfails.txt create mode 100644 vendor_test/uses_dask.py diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml new file mode 100644 index 00000000..0b516ddb --- /dev/null +++ b/.github/workflows/array-api-tests-dask.yml @@ -0,0 +1,9 @@ +name: Array API Tests (Dask) + +on: [push, pull_request] + +jobs: + array-api-tests-dask: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: dask diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index e5231770..41da4d64 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -303,6 +303,8 @@ def _asarray( import numpy as xp elif namespace == 'cupy': import cupy as xp + elif namespace == 'dask': + import dask.array as xp else: raise ValueError("Unrecognized namespace argument to asarray()") @@ -322,7 +324,9 @@ def _asarray( if copy in COPY_FALSE: # copy=False is not yet implemented in xp.asarray raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, xp.ndarray): + # TODO: This feels wrong (__array__ is not in the standard) + # Dask doesn't support DLPack, though, so, this'll do + if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): if dtype is not None and obj.dtype != dtype: copy = True if copy in COPY_TRUE: diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 61aee4be..94cbb902 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -40,6 +40,16 @@ def _is_torch_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, torch.Tensor) +def _is_dask_array(x): + # Avoid importing dask if it isn't already + if 'dask.array' not in sys.modules: + return False + + import dask.array + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, dask.array.Array) + def is_array_api_obj(x): """ Check if x is an array API compatible array object. @@ -97,6 +107,13 @@ def your_function(x, y): else: import torch namespaces.add(torch) + elif _is_dask_array(x): + _check_api_version(api_version) + if _use_compat: + from .. import dask as dask_namespace + namespaces.add(dask_namespace) + else: + raise TypeError("_use_compat cannot be False if input array is a dask array!") else: # TODO: Support Python scalars? raise TypeError("The input is not a supported array type") @@ -219,6 +236,13 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A return _cupy_to_device(x, device, stream=stream) elif _is_torch_array(x): return _torch_to_device(x, device, stream=stream) + elif _is_dask_array(x): + if stream is not None: + raise ValueError("The stream argument to to_device() is not supported") + # TODO: What if our array is on the GPU already? + if device == 'cpu': + return x + raise ValueError(f"Unsupported device {device!r}") return x.to_device(device, stream=stream) def size(x): diff --git a/array_api_compat/dask/__init__.py b/array_api_compat/dask/__init__.py new file mode 100644 index 00000000..12e3a92d --- /dev/null +++ b/array_api_compat/dask/__init__.py @@ -0,0 +1,6 @@ +from dask.array import * + +# These imports may overwrite names from the import * above. +from ._aliases import * + +__array_api_version__ = '2022.12' diff --git a/array_api_compat/dask/_aliases.py b/array_api_compat/dask/_aliases.py new file mode 100644 index 00000000..93fe4594 --- /dev/null +++ b/array_api_compat/dask/_aliases.py @@ -0,0 +1,88 @@ +from ..common import _aliases + +from .._internal import get_xp + +import numpy as np +from numpy import ( + # Constants + e, + inf, + nan, + pi, + newaxis, + # Dtypes + bool_ as bool, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + complex64, + complex128, + iinfo, + finfo, + can_cast, + result_type, +) + +import dask.array as da + +isdtype = get_xp(np)(_aliases.isdtype) +astype = _aliases.astype + +# Common aliases +arange = get_xp(da)(_aliases.arange) + +from functools import partial +asarray = partial(_aliases._asarray, namespace='dask') +asarray.__doc__ = _aliases._asarray.__doc__ + +linspace = get_xp(da)(_aliases.linspace) +eye = get_xp(da)(_aliases.eye) +UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) +UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) +UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) +unique_all = get_xp(da)(_aliases.unique_all) +unique_counts = get_xp(da)(_aliases.unique_counts) +unique_inverse = get_xp(da)(_aliases.unique_inverse) +unique_values = get_xp(da)(_aliases.unique_values) +permute_dims = get_xp(da)(_aliases.permute_dims) +std = get_xp(da)(_aliases.std) +var = get_xp(da)(_aliases.var) +empty = get_xp(da)(_aliases.empty) +empty_like = get_xp(da)(_aliases.empty_like) +full = get_xp(da)(_aliases.full) +full_like = get_xp(da)(_aliases.full_like) +ones = get_xp(da)(_aliases.ones) +ones_like = get_xp(da)(_aliases.ones_like) +zeros = get_xp(da)(_aliases.zeros) +zeros_like = get_xp(da)(_aliases.zeros_like) +reshape = get_xp(da)(_aliases.reshape) +matrix_transpose = get_xp(da)(_aliases.matrix_transpose) +vecdot = get_xp(da)(_aliases.vecdot) + + + +from dask.array import ( + # Element wise aliases + arccos as acos, + arccosh as acosh, + arcsin as asin, + arcsinh as asinh, + arctan as atan, + arctan2 as atan2, + arctanh as atanh, + left_shift as bitwise_left_shift, + right_shift as bitwise_right_shift, + invert as bitwise_invert, + power as pow, + # Other + concatenate as concat, + +) + diff --git a/dask-skips.txt b/dask-skips.txt new file mode 100644 index 00000000..e69de29b diff --git a/dask-xfails.txt b/dask-xfails.txt new file mode 100644 index 00000000..cc8b39d0 --- /dev/null +++ b/dask-xfails.txt @@ -0,0 +1,47 @@ +# finfo(float32).eps returns float32 but should return float +array_api_tests/test_data_type_functions.py::test_finfo[float32] + +# No sorting in dask +array_api_tests/test_has_names.py::test_has_names[sorting-argsort] +array_api_tests/test_has_names.py::test_has_names[sorting-sort] + +# Array methods and attributes not already on np.ndarray cannot be wrapped +array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] +array_api_tests/test_has_names.py::test_has_names[array_method-to_device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] + +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] - AssertionError: out[0]=0, but should be (x1 + x2[0])=65536 [__add__()] +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_and(ctx=BinaryParamContext(<__and__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the ... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_right_shift(ctx=BinaryParamContext(<__rshift__(x1, x2)>), data=data(...)) produces unreliable results: Falsif... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 994.44ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] - ValueError: Inferred dtype from function 'xor' was 'uint64' but got 'int16', which can't be cast using casting='same_kind' +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_ceil - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_divide(ctx=BinaryParamContext(), data=data(...)) produces unreliable results: Falsified on the first ... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1015.43ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 900.99ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1106.49ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater(ctx=BinaryParamContext(<__gt__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(), data=data(...)) produces unreliable results: Falsified... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<__ge__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the... +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 961.84ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 980.68ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1043.47ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1011.76ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] - AssertionError: out[0]=0, but should be (x1 * x2[0])=256 [multiply()] +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] - AssertionError: out[0]=2, but should be (x1 * x2[0])=258 [__mul__()] +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1034.64ms, which exceeds the deadline of 800.00ms +#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_trunc - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) + +#FAILED array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error - Failed: DID NOT RAISE + +# Fails because shape is NaN since we don't materialize it yet +#FAILED array_api_tests/test_searching_functions.py::test_nonzero - AssertionError: prod(out[0].shape)=nan, but should be prod(out[0].shape)=nan +#FAILED array_api_tests/test_set_functions.py::test_unique_all - AssertionError: out.indices.shape=(nan,), but should be out.values.shape=(nan,) +#FAILED array_api_tests/test_set_functions.py::test_unique_counts - AssertionError: out.counts.shape=(nan,), but should be out.values.shape=(nan,) + +# Needs investigation +#FAILED array_api_tests/test_set_functions.py::test_unique_inverse - TypeError: 'float' object cannot be interpreted as an integer +#FAILED array_api_tests/test_set_functions.py::test_unique_values - TypeError: 'float' object cannot be interpreted as an integer diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 1675377d..4012f212 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -5,8 +5,7 @@ import pytest - -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("api_version", [None, '2021.12']) def test_array_namespace(library, api_version): lib = import_(library) @@ -17,6 +16,8 @@ def test_array_namespace(library, api_version): if 'array_api' in library: assert namespace == lib else: + if library == "dask.array": + library = "dask" assert namespace == getattr(array_api_compat, library) diff --git a/tests/test_common.py b/tests/test_common.py index 86886b7f..45d8030a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import assert_allclose -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index d164699e..6a40140d 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -64,7 +64,7 @@ def isdtype_(dtype_, kind): assert type(res) is bool return res -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) def test_isdtype_spec_dtypes(library): xp = import_('array_api_compat.' + library) @@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library): 'bfloat16', ] -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): xp = import_('array_api_compat.' + library) diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 93a961aa..44b89a74 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -17,3 +17,7 @@ def test_vendoring_cupy(): def test_vendoring_torch(): from vendor_test import uses_torch uses_torch._test_torch() + +def test_vendoring_torch(): + from vendor_test import uses_torch + uses_torch._test_torch() diff --git a/vendor_test/uses_dask.py b/vendor_test/uses_dask.py new file mode 100644 index 00000000..d376bdf1 --- /dev/null +++ b/vendor_test/uses_dask.py @@ -0,0 +1,19 @@ +# Basic test that vendoring works + +from .vendored._compat import dask as dask_compat + +import dask.array as da +import numpy as np + +def _test_numpy(): + a = dask_compat.asarray([1., 2., 3.]) + b = dask_compat.arange(3, dtype=dask_compat.float32) + + # np.pow does not exist. Update this to use something else if it is added + res = dask_compat.pow(a, b) + assert res.dtype == dask_compat.float64 == np.float64 + assert isinstance(a, da.array) + assert isinstance(b, da.array) + assert isinstance(res, da.array) + + np.testing.assert_allclose(res, [1., 2., 9.]) From 3cc285712487cad434ca9ac491dd43e547609496 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 19 Dec 2023 08:00:39 -0800 Subject: [PATCH 02/22] fix workflows? --- .github/workflows/array-api-tests-dask.yml | 1 + .github/workflows/array-api-tests.yml | 5 ++++- .github/workflows/tests.yml | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 0b516ddb..2557abc6 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -7,3 +7,4 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: dask + extra-requires: numpy diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index bdc7e9da..7ab41ec8 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -6,6 +6,9 @@ on: package-name: required: true type: string + package-extras: + required: false + type: string package-version: required: false type: string @@ -54,7 +57,7 @@ jobs: if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip - python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' + python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.package-extras }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt - name: Run the array API testsuite (${{ inputs.package-name }}) if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aec709b4..45239bcd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy torch + python -m pip install pytest numpy torch dask[array] - name: Run Tests run: | From ad6bf5601a52376493908934146f2f9a29f4c6a9 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 19 Dec 2023 08:03:10 -0800 Subject: [PATCH 03/22] typo --- .github/workflows/array-api-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 7ab41ec8..11037c80 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -6,7 +6,7 @@ on: package-name: required: true type: string - package-extras: + extra-requires: required: false type: string package-version: @@ -57,7 +57,7 @@ jobs: if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip - python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.package-extras }} + python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt - name: Run the array API testsuite (${{ inputs.package-name }}) if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" From f178e28a325b4386e8882b55bb544c14331a0bda Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Thu, 21 Dec 2023 21:49:30 -0800 Subject: [PATCH 04/22] green --- .github/workflows/array-api-tests-dask.yml | 1 + array_api_compat/common/_aliases.py | 3 +- array_api_compat/dask/_aliases.py | 4 +- array_api_compat/dask/linalg.py | 34 +++++ dask-skips.txt | 5 + dask-xfails.txt | 154 ++++++++++++++++----- 6 files changed, 166 insertions(+), 35 deletions(-) create mode 100644 array_api_compat/dask/linalg.py diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 2557abc6..78bf72c7 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -8,3 +8,4 @@ jobs: with: package-name: dask extra-requires: numpy + pytest-extra-args: --disable-deadline diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 41da4d64..73362c67 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -329,7 +329,8 @@ def _asarray( if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): if dtype is not None and obj.dtype != dtype: copy = True - if copy in COPY_TRUE: + # Dask arrays are immutable, so copy doesn't do anything + if copy in COPY_TRUE and namespace != "dask": return xp.array(obj, copy=True, dtype=dtype) return obj diff --git a/array_api_compat/dask/_aliases.py b/array_api_compat/dask/_aliases.py index 93fe4594..506eb1f6 100644 --- a/array_api_compat/dask/_aliases.py +++ b/array_api_compat/dask/_aliases.py @@ -37,6 +37,7 @@ # Common aliases arange = get_xp(da)(_aliases.arange) +eye = get_xp(da)(_aliases.eye) from functools import partial asarray = partial(_aliases._asarray, namespace='dask') @@ -66,8 +67,6 @@ matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) - - from dask.array import ( # Element wise aliases arccos as acos, @@ -83,6 +82,5 @@ power as pow, # Other concatenate as concat, - ) diff --git a/array_api_compat/dask/linalg.py b/array_api_compat/dask/linalg.py new file mode 100644 index 00000000..82113ee3 --- /dev/null +++ b/array_api_compat/dask/linalg.py @@ -0,0 +1,34 @@ +from dask.array.linalg import * +from dask.array.linalg import __all__ as linalg_all + +from ..common import _linalg +from .._internal import get_xp +from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) + +import dask.array as da + +cross = get_xp(da)(_linalg.cross) +outer = get_xp(da)(_linalg.outer) +EighResult = _linalg.EighResult +QRResult = _linalg.QRResult +SlogdetResult = _linalg.SlogdetResult +SVDResult = _linalg.SVDResult +eigh = get_xp(da)(_linalg.eigh) +qr = get_xp(da)(_linalg.qr) +slogdet = get_xp(da)(_linalg.slogdet) +svd = get_xp(da)(_linalg.svd) +cholesky = get_xp(da)(_linalg.cholesky) +matrix_rank = get_xp(da)(_linalg.matrix_rank) +pinv = get_xp(da)(_linalg.pinv) +matrix_norm = get_xp(da)(_linalg.matrix_norm) +svdvals = get_xp(da)(_linalg.svdvals) +vector_norm = get_xp(da)(_linalg.vector_norm) +diagonal = get_xp(da)(_linalg.diagonal) +trace = get_xp(da)(_linalg.trace) + +__all__ = linalg_all + _linalg.__all__ + +del get_xp +del da +del linalg_all +del _linalg diff --git a/dask-skips.txt b/dask-skips.txt index e69de29b..901583e3 100644 --- a/dask-skips.txt +++ b/dask-skips.txt @@ -0,0 +1,5 @@ +# FFT isn't conformant +array_api_tests/test_fft.py + +# Errors with dask, also makes Dask go OOM +array_api_tests/test_creation_functions.py::test_arange diff --git a/dask-xfails.txt b/dask-xfails.txt index cc8b39d0..e63d0517 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -1,9 +1,38 @@ +# This fails in dask +# import dask.array as da +# a = da.array([1]).reshape((1,1)) +# key = (0, slice(None, None, -1)) +# a[key] = da.array([1]) + +# Failing hypothesis test case +#x=dask.array +#| Draw 1 (key): (slice(None, None, None), slice(None, None, None)) +#| Draw 2 (value): dask.array + +# TODO: this also skips test_setitem_masking unnecessarily +array_api_tests/test_array_object.py::test_setitem + +# Various indexing errors +array_api_tests/test_array_object.py::test_getitem_masking + + +# asarray(copy=False) is not yet implemented +# copied from numpy xfails, TODO: should this pass with dask? +array_api_tests/test_creation_functions.py::test_asarray_arrays + +# zero division error, and typeerror: tuple indices must be integers or slices not tuple +array_api_tests/test_creation_functions.py::test_eye + # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] # No sorting in dask array_api_tests/test_has_names.py::test_has_names[sorting-argsort] array_api_tests/test_has_names.py::test_has_names[sorting-sort] +array_api_tests/test_sorting_functions.py::test_argsort +array_api_tests/test_sorting_functions.py::test_sort +array_api_tests/test_signatures.py::test_func_signature[argsort] +array_api_tests/test_signatures.py::test_func_signature[sort] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -11,37 +40,100 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] - AssertionError: out[0]=0, but should be (x1 + x2[0])=65536 [__add__()] -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_and(ctx=BinaryParamContext(<__and__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the ... -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_bitwise_right_shift(ctx=BinaryParamContext(<__rshift__(x1, x2)>), data=data(...)) produces unreliable results: Falsif... -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 994.44ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] - ValueError: Inferred dtype from function 'xor' was 'uint64' but got 'int16', which can't be cast using casting='same_kind' -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_ceil - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_divide(ctx=BinaryParamContext(), data=data(...)) produces unreliable results: Falsified on the first ... -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1015.43ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 900.99ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1106.49ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater(ctx=BinaryParamContext(<__gt__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the first... -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(), data=data(...)) produces unreliable results: Falsified... -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] - hypothesis.errors.Flaky: Hypothesis test_greater_equal(ctx=BinaryParamContext(<__ge__(x1, x2)>), data=data(...)) produces unreliable results: Falsified on the... -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 961.84ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 980.68ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1043.47ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1011.76ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)] - AssertionError: out[0]=0, but should be (x1 * x2[0])=256 [multiply()] -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] - AssertionError: out[0]=2, but should be (x1 * x2[0])=258 [__mul__()] -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] - hypothesis.errors.DeadlineExceeded: Test took 1034.64ms, which exceeds the deadline of 800.00ms -#FAILED array_api_tests/test_operators_and_elementwise_functions.py::test_trunc - exceptiongroup.ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions) - -#FAILED array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error - Failed: DID NOT RAISE +# dask doesn't return int when input is already int for ceil/floor/trunc +# Use $ to denote end of regex so we don't xfail other tests accidentally +array_api_tests/test_operators_and_elementwise_functions.py::test_ceil +# TODO: this xfails more than it should ... (e.g. test_floor_divide works) +array_api_tests/test_operators_and_elementwise_functions.py::test_floor +array_api_tests/test_operators_and_elementwise_functions.py::test_trunc + +# Dask doesn't raise an error for this test +array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error # Fails because shape is NaN since we don't materialize it yet -#FAILED array_api_tests/test_searching_functions.py::test_nonzero - AssertionError: prod(out[0].shape)=nan, but should be prod(out[0].shape)=nan -#FAILED array_api_tests/test_set_functions.py::test_unique_all - AssertionError: out.indices.shape=(nan,), but should be out.values.shape=(nan,) -#FAILED array_api_tests/test_set_functions.py::test_unique_counts - AssertionError: out.counts.shape=(nan,), but should be out.values.shape=(nan,) +array_api_tests/test_searching_functions.py::test_nonzero +array_api_tests/test_set_functions.py::test_unique_all +array_api_tests/test_set_functions.py::test_unique_counts + +# Different error but same cause as above, we're just trying to do ndindex on nan shape +array_api_tests/test_set_functions.py::test_unique_inverse +array_api_tests/test_set_functions.py::test_unique_values + +# Linalg failures (signature failures/missing methods) +array_api_tests/test_has_names.py::test_has_names[linalg-cross] +array_api_tests/test_has_names.py::test_has_names[linalg-det] +array_api_tests/test_has_names.py::test_has_names[linalg-diagonal] +array_api_tests/test_has_names.py::test_has_names[linalg-eigh] +array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh] +array_api_tests/test_has_names.py::test_has_names[linalg-matmul] +array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm] +array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power] +array_api_tests/test_has_names.py::test_has_names[linalg-matrix_rank] +array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose] +array_api_tests/test_has_names.py::test_has_names[linalg-outer] +array_api_tests/test_has_names.py::test_has_names[linalg-pinv] +array_api_tests/test_has_names.py::test_has_names[linalg-slogdet] +array_api_tests/test_has_names.py::test_has_names[linalg-svdvals] +array_api_tests/test_has_names.py::test_has_names[linalg-tensordot] +array_api_tests/test_has_names.py::test_has_names[linalg-trace] +array_api_tests/test_has_names.py::test_has_names[linalg-vecdot] +array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm] +array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] +array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] +array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] +array_api_tests/test_linalg.py::test_cross +array_api_tests/test_linalg.py::test_det +array_api_tests/test_linalg.py::test_diagonal +array_api_tests/test_linalg.py::test_eigvalsh +array_api_tests/test_linalg.py::test_matrix_norm +array_api_tests/test_linalg.py::test_matrix_rank +array_api_tests/test_linalg.py::test_outer +array_api_tests/test_linalg.py::test_pinv +array_api_tests/test_linalg.py::test_slogdet +array_api_tests/test_linalg.py::test_svdvals +array_api_tests/test_linalg.py::test_tensordot +array_api_tests/test_linalg.py::test_trace +array_api_tests/test_linalg.py::test_cholesky +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cholesky] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.diagonal] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matmul] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_norm] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_rank] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_transpose] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.outer] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.qr] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svdvals] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.tensordot] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.trace] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vector_norm] +# errors +array_api_tests/test_linalg.py::test_matrix_power +array_api_tests/test_linalg.py::test_qr +array_api_tests/test_linalg.py::test_solve +array_api_tests/test_linalg.py::test_svd + +# Missing dlpack stuff +array_api_tests/test_signatures.py::test_func_signature[from_dlpack] +array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] +array_api_tests/test_signatures.py::test_array_method_signature[to_device] + +# Some cases unsupported by dask +array_api_tests/test_manipulation_functions.py::test_roll + +# Dtype doesn't match (output is float32 but should be float64) +array_api_tests/test_statistical_functions.py::test_prod +array_api_tests/test_statistical_functions.py::test_sum -# Needs investigation -#FAILED array_api_tests/test_set_functions.py::test_unique_inverse - TypeError: 'float' object cannot be interpreted as an integer -#FAILED array_api_tests/test_set_functions.py::test_unique_values - TypeError: 'float' object cannot be interpreted as an integer +# No mT on dask array +array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices From 8c9c784b394ffc5e1ccf7bec444608f7c2d5ccad Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 26 Dec 2023 10:51:12 -0800 Subject: [PATCH 05/22] go for green --- array_api_compat/_dask_ci_shim.py | 13 +++++++++++++ dask-xfails.txt | 11 ++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 array_api_compat/_dask_ci_shim.py diff --git a/array_api_compat/_dask_ci_shim.py b/array_api_compat/_dask_ci_shim.py new file mode 100644 index 00000000..cafbe78b --- /dev/null +++ b/array_api_compat/_dask_ci_shim.py @@ -0,0 +1,13 @@ +""" +A little CI shim for the dask backend that +disables the dask scheduler +""" +import dask +dask.config.set(scheduler='synchronous') + +from dask.distributed import Client +_client = Client() +print(_client.dashboard_link) + +from .dask import * +from .dask import __array_api_version__ diff --git a/dask-xfails.txt b/dask-xfails.txt index e63d0517..0b593ca3 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -15,7 +15,6 @@ array_api_tests/test_array_object.py::test_setitem # Various indexing errors array_api_tests/test_array_object.py::test_getitem_masking - # asarray(copy=False) is not yet implemented # copied from numpy xfails, TODO: should this pass with dask? array_api_tests/test_creation_functions.py::test_asarray_arrays @@ -26,6 +25,16 @@ array_api_tests/test_creation_functions.py::test_eye # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] +# shape mismatch +array_api_tests/test_creation_functions.py::test_linspace + +# out=-0, but should be +0 +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] + +# output is nan but should be infinity +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] + # No sorting in dask array_api_tests/test_has_names.py::test_has_names[sorting-argsort] array_api_tests/test_has_names.py::test_has_names[sorting-sort] From 6305d7e99bdbc55a34243fe91ddc09244053c65d Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 26 Dec 2023 13:45:48 -0800 Subject: [PATCH 06/22] try again --- dask-xfails.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dask-xfails.txt b/dask-xfails.txt index 0b593ca3..0fe23042 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -30,10 +30,11 @@ array_api_tests/test_creation_functions.py::test_linspace # out=-0, but should be +0 array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] # output is nan but should be infinity array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] # No sorting in dask array_api_tests/test_has_names.py::test_has_names[sorting-argsort] From 06afd7521c7cec5981fd546230bc0d0a1a9783e8 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 2 Jan 2024 07:37:33 -0800 Subject: [PATCH 07/22] fix arange test --- array_api_compat/_dask_ci_shim.py | 4 +++- array_api_compat/common/_helpers.py | 1 - array_api_compat/dask/_aliases.py | 36 ++++++++++++++++++++++++++++- dask-skips.txt | 3 --- dask-xfails.txt | 3 ++- tests/test_vendoring.py | 6 ++--- vendor_test/uses_dask.py | 8 +++---- 7 files changed, 47 insertions(+), 14 deletions(-) diff --git a/array_api_compat/_dask_ci_shim.py b/array_api_compat/_dask_ci_shim.py index cafbe78b..82e0955e 100644 --- a/array_api_compat/_dask_ci_shim.py +++ b/array_api_compat/_dask_ci_shim.py @@ -1,13 +1,15 @@ """ A little CI shim for the dask backend that disables the dask scheduler + +It also lets you see the dask dashboard for debugging +at http://127.0.0.1:8787/status """ import dask dask.config.set(scheduler='synchronous') from dask.distributed import Client _client = Client() -print(_client.dashboard_link) from .dask import * from .dask import __array_api_version__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 94cbb902..2a9f75f2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -47,7 +47,6 @@ def _is_dask_array(x): import dask.array - # TODO: Should we reject ndarray subclasses? return isinstance(x, dask.array.Array) def is_array_api_obj(x): diff --git a/array_api_compat/dask/_aliases.py b/array_api_compat/dask/_aliases.py index 506eb1f6..22491b2a 100644 --- a/array_api_compat/dask/_aliases.py +++ b/array_api_compat/dask/_aliases.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from ..common import _aliases +from ..common._helpers import _check_device from .._internal import get_xp @@ -30,13 +33,44 @@ result_type, ) +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Union + from ..common._typing import ndarray, Device, Dtype + import dask.array as da isdtype = get_xp(np)(_aliases.isdtype) astype = _aliases.astype # Common aliases -arange = get_xp(da)(_aliases.arange) + +# This arange func is modified from the common one to +# not pass stop/step as keyword arguments, which will cause +# an error with dask +def dask_arange( + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + xp, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs +) -> ndarray: + _check_device(xp, device) + args = [start] + if stop is not None: + args.append(stop) + else: + # stop is None, so start is actually stop + # prepend the default value for start which is 0 + args.insert(0, 0) + args.append(step) + return xp.arange(*args, dtype=dtype, **kwargs) + +arange = get_xp(da)(dask_arange) eye = get_xp(da)(_aliases.eye) from functools import partial diff --git a/dask-skips.txt b/dask-skips.txt index 901583e3..4a06d23d 100644 --- a/dask-skips.txt +++ b/dask-skips.txt @@ -1,5 +1,2 @@ # FFT isn't conformant array_api_tests/test_fft.py - -# Errors with dask, also makes Dask go OOM -array_api_tests/test_creation_functions.py::test_arange diff --git a/dask-xfails.txt b/dask-xfails.txt index 0fe23042..976de938 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -25,7 +25,8 @@ array_api_tests/test_creation_functions.py::test_eye # finfo(float32).eps returns float32 but should return float array_api_tests/test_data_type_functions.py::test_finfo[float32] -# shape mismatch +# out[-1]=dask.aray but should be some floating number +# (I think the test is not forcing the op to be computed?) array_api_tests/test_creation_functions.py::test_linspace # out=-0, but should be +0 diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 44b89a74..873b233a 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -18,6 +18,6 @@ def test_vendoring_torch(): from vendor_test import uses_torch uses_torch._test_torch() -def test_vendoring_torch(): - from vendor_test import uses_torch - uses_torch._test_torch() +def test_vendoring_dask(): + from vendor_test import uses_dask + uses_dask._test_dask() diff --git a/vendor_test/uses_dask.py b/vendor_test/uses_dask.py index d376bdf1..222dcc3f 100644 --- a/vendor_test/uses_dask.py +++ b/vendor_test/uses_dask.py @@ -5,15 +5,15 @@ import dask.array as da import numpy as np -def _test_numpy(): +def _test_dask(): a = dask_compat.asarray([1., 2., 3.]) b = dask_compat.arange(3, dtype=dask_compat.float32) # np.pow does not exist. Update this to use something else if it is added res = dask_compat.pow(a, b) assert res.dtype == dask_compat.float64 == np.float64 - assert isinstance(a, da.array) - assert isinstance(b, da.array) - assert isinstance(res, da.array) + assert isinstance(a, da.Array) + assert isinstance(b, da.Array) + assert isinstance(res, da.Array) np.testing.assert_allclose(res, [1., 2., 9.]) From 6e0ef29e0d7cde55c7771148c8ef38dfa667f286 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 2 Jan 2024 12:28:49 -0800 Subject: [PATCH 08/22] use max-examples=1 --- .github/workflows/array-api-tests-dask.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 78bf72c7..c05d801b 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -8,4 +8,4 @@ jobs: with: package-name: dask extra-requires: numpy - pytest-extra-args: --disable-deadline + pytest-extra-args: --disable-deadline --max-examples=1 From 9abe56d8f7c444a016799c6ad9fd9f27232ecd8b Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 3 Jan 2024 09:46:05 -0800 Subject: [PATCH 09/22] Update array-api-tests-dask.yml --- .github/workflows/array-api-tests-dask.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index c05d801b..548c9009 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -8,4 +8,4 @@ jobs: with: package-name: dask extra-requires: numpy - pytest-extra-args: --disable-deadline --max-examples=1 + pytest-extra-args: --disable-deadline --max-examples=5 From 20622ba4d8cae826ebe1730b31aacd08e4442663 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 3 Jan 2024 17:36:06 -0800 Subject: [PATCH 10/22] fix missing check --- array_api_compat/common/_helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 2a9f75f2..f513a71e 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -56,6 +56,7 @@ def is_array_api_obj(x): return _is_numpy_array(x) \ or _is_cupy_array(x) \ or _is_torch_array(x) \ + or _is_dask_array(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): From 0d661602331323c130fda4de868f97fd14d6bde5 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 9 Jan 2024 09:49:26 -0800 Subject: [PATCH 11/22] change namespace --- array_api_compat/_dask_ci_shim.py | 15 --------------- array_api_compat/common/_aliases.py | 4 +--- array_api_compat/dask/{ => array}/__init__.py | 0 array_api_compat/dask/{ => array}/_aliases.py | 8 ++++---- array_api_compat/dask/{ => array}/linalg.py | 0 tests/test_common.py | 2 +- tests/test_isdtype.py | 4 ++-- vendor_test/.DS_Store | Bin 0 -> 6148 bytes vendor_test/uses_dask.py | 2 +- 9 files changed, 9 insertions(+), 26 deletions(-) delete mode 100644 array_api_compat/_dask_ci_shim.py rename array_api_compat/dask/{ => array}/__init__.py (100%) rename array_api_compat/dask/{ => array}/_aliases.py (94%) rename array_api_compat/dask/{ => array}/linalg.py (100%) create mode 100644 vendor_test/.DS_Store diff --git a/array_api_compat/_dask_ci_shim.py b/array_api_compat/_dask_ci_shim.py deleted file mode 100644 index 82e0955e..00000000 --- a/array_api_compat/_dask_ci_shim.py +++ /dev/null @@ -1,15 +0,0 @@ -""" -A little CI shim for the dask backend that -disables the dask scheduler - -It also lets you see the dask dashboard for debugging -at http://127.0.0.1:8787/status -""" -import dask -dask.config.set(scheduler='synchronous') - -from dask.distributed import Client -_client = Client() - -from .dask import * -from .dask import __array_api_version__ diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 73362c67..f35160e0 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -324,13 +324,11 @@ def _asarray( if copy in COPY_FALSE: # copy=False is not yet implemented in xp.asarray raise NotImplementedError("copy=False is not yet implemented") - # TODO: This feels wrong (__array__ is not in the standard) - # Dask doesn't support DLPack, though, so, this'll do if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): if dtype is not None and obj.dtype != dtype: copy = True # Dask arrays are immutable, so copy doesn't do anything - if copy in COPY_TRUE and namespace != "dask": + if copy in COPY_TRUE and namespace != "dask.array": return xp.array(obj, copy=True, dtype=dtype) return obj diff --git a/array_api_compat/dask/__init__.py b/array_api_compat/dask/array/__init__.py similarity index 100% rename from array_api_compat/dask/__init__.py rename to array_api_compat/dask/array/__init__.py diff --git a/array_api_compat/dask/_aliases.py b/array_api_compat/dask/array/_aliases.py similarity index 94% rename from array_api_compat/dask/_aliases.py rename to array_api_compat/dask/array/_aliases.py index 22491b2a..c3ea021f 100644 --- a/array_api_compat/dask/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,9 +1,9 @@ from __future__ import annotations -from ..common import _aliases -from ..common._helpers import _check_device +from ...common import _aliases +from ...common._helpers import _check_device -from .._internal import get_xp +from ..._internal import get_xp import numpy as np from numpy import ( @@ -36,7 +36,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional, Union - from ..common._typing import ndarray, Device, Dtype + from ...common._typing import ndarray, Device, Dtype import dask.array as da diff --git a/array_api_compat/dask/linalg.py b/array_api_compat/dask/array/linalg.py similarity index 100% rename from array_api_compat/dask/linalg.py rename to array_api_compat/dask/array/linalg.py diff --git a/tests/test_common.py b/tests/test_common.py index 45d8030a..f98a717a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import assert_allclose -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_to_device_host(library): # different libraries have different semantics # for DtoH transfers; ensure that we support a portable diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 6a40140d..77e7ce72 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -64,7 +64,7 @@ def isdtype_(dtype_, kind): assert type(res) is bool return res -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_isdtype_spec_dtypes(library): xp = import_('array_api_compat.' + library) @@ -98,7 +98,7 @@ def test_isdtype_spec_dtypes(library): 'bfloat16', ] -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): xp = import_('array_api_compat.' + library) diff --git a/vendor_test/.DS_Store b/vendor_test/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c804b5ef87d13f697bfc8a1cea378e7b7b3d830c GIT binary patch literal 6148 zcmeHK!AiqG5S@)hQhF(P@wmUxi|4h(FX#tUn}|I$O%N132>Aq$p8W;!=+Xc1M+D!@ zENw&DLNB5+1G8^4nb}OZ!h-gSeDU>nkquC;SpS2@NwD3(QW8`#2InC*s zrX6oH{6+?N?gp5xDb@6Vxm(ZQba6A9%`5_j#bUh23R_Y^72Tp&v@@l*Z9bbhzJ5J^K0Qm0ow3Slo!vgac3%Abm>2`b zfHClA7(mTtNe7D78Ux0FF|cBQ{|_F@7$P=`^6Nk&M*v_CW)bvdF9h^#0EUQ-A}kQ6 zp+F6F@rdCx9Cjc5Lc~Tv>7 Date: Tue, 9 Jan 2024 09:52:35 -0800 Subject: [PATCH 12/22] changes left behind --- array_api_compat/common/_helpers.py | 2 +- array_api_compat/dask/array/linalg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index f513a71e..f0d1180f 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -110,7 +110,7 @@ def your_function(x, y): elif _is_dask_array(x): _check_api_version(api_version) if _use_compat: - from .. import dask as dask_namespace + from ..dask import array as dask_namespace namespaces.add(dask_namespace) else: raise TypeError("_use_compat cannot be False if input array is a dask array!") diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 82113ee3..e660621a 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,8 +1,8 @@ from dask.array.linalg import * from dask.array.linalg import __all__ as linalg_all -from ..common import _linalg -from .._internal import get_xp +from ...common import _linalg +from ..._internal import get_xp from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) import dask.array as da From 1f9799aa676ad578839f7c6f046a6ce22124aac6 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 10 Jan 2024 07:28:43 -0800 Subject: [PATCH 13/22] finish rename and fix linalg tests --- .github/workflows/array-api-tests-dask.yml | 2 +- array_api_compat/common/_linalg.py | 3 +- array_api_compat/dask/array/__init__.py | 2 + array_api_compat/dask/array/_aliases.py | 3 + array_api_compat/dask/array/linalg.py | 34 +++++---- dask-skips.txt | 2 - dask.array-skips.txt | 5 ++ dask-xfails.txt => dask.array-xfails.txt | 83 ++++++++++------------ 8 files changed, 70 insertions(+), 64 deletions(-) delete mode 100644 dask-skips.txt create mode 100644 dask.array-skips.txt rename dask-xfails.txt => dask.array-xfails.txt (76%) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 548c9009..2d60ca63 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -6,6 +6,6 @@ jobs: array-api-tests-dask: uses: ./.github/workflows/array-api-tests.yml with: - package-name: dask + package-name: dask.array extra-requires: numpy pytest-extra-args: --disable-deadline --max-examples=5 diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 97c584be..ccabd82c 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -74,7 +74,8 @@ def matrix_rank(x: ndarray, # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = xp.linalg.svd(x, compute_uv=False, **kwargs) + S = xp.linalg.svdvals(x, **kwargs) + #S = xp.linalg.svd(x, compute_uv=False, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 12e3a92d..a7c0b22e 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -4,3 +4,5 @@ from ._aliases import * __array_api_version__ = '2022.12' + +__import__(__package__ + '.linalg') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index c3ea021f..8be3a566 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -48,6 +48,8 @@ # This arange func is modified from the common one to # not pass stop/step as keyword arguments, which will cause # an error with dask + +# TODO: delete the xp stuff, it shouldn't be necessary def dask_arange( start: Union[int, float], /, @@ -118,3 +120,4 @@ def dask_arange( concatenate as concat, ) +del da diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index e660621a..a4407cd1 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,34 +1,42 @@ -from dask.array.linalg import * -from dask.array.linalg import __all__ as linalg_all +from __future__ import annotations +from dask.array.linalg import * from ...common import _linalg from ..._internal import get_xp -from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) +from dask.array import matmul, tensordot, trace, outer +from ._aliases import matrix_transpose, vecdot import dask.array as da -cross = get_xp(da)(_linalg.cross) -outer = get_xp(da)(_linalg.outer) +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Union, Tuple + from ...common._typing import ndarray, Device, Dtype + +#cross = get_xp(da)(_linalg.cross) +#outer = get_xp(da)(_linalg.outer) EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult -eigh = get_xp(da)(_linalg.eigh) qr = get_xp(da)(_linalg.qr) -slogdet = get_xp(da)(_linalg.slogdet) -svd = get_xp(da)(_linalg.svd) +#svd = get_xp(da)(_linalg.svd) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) -pinv = get_xp(da)(_linalg.pinv) +#pinv = get_xp(da)(_linalg.pinv) matrix_norm = get_xp(da)(_linalg.matrix_norm) -svdvals = get_xp(da)(_linalg.svdvals) + +def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: + # TODO: can't avoid computing U or V for dask + _, s, _ = svd(x) + return s + vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -trace = get_xp(da)(_linalg.trace) -__all__ = linalg_all + _linalg.__all__ +#__all__ = linalg_all + _linalg.__all__ del get_xp del da -del linalg_all +#del linalg_all del _linalg diff --git a/dask-skips.txt b/dask-skips.txt deleted file mode 100644 index 4a06d23d..00000000 --- a/dask-skips.txt +++ /dev/null @@ -1,2 +0,0 @@ -# FFT isn't conformant -array_api_tests/test_fft.py diff --git a/dask.array-skips.txt b/dask.array-skips.txt new file mode 100644 index 00000000..8e884ac0 --- /dev/null +++ b/dask.array-skips.txt @@ -0,0 +1,5 @@ +# FFT isn't conformant +array_api_tests/test_fft.py + +# slow and not implemented in dask +array_api_tests/test_linalg.py::test_matrix_power diff --git a/dask-xfails.txt b/dask.array-xfails.txt similarity index 76% rename from dask-xfails.txt rename to dask.array-xfails.txt index 976de938..2a5330f6 100644 --- a/dask-xfails.txt +++ b/dask.array-xfails.txt @@ -71,65 +71,51 @@ array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_set_functions.py::test_unique_values # Linalg failures (signature failures/missing methods) -array_api_tests/test_has_names.py::test_has_names[linalg-cross] -array_api_tests/test_has_names.py::test_has_names[linalg-det] -array_api_tests/test_has_names.py::test_has_names[linalg-diagonal] -array_api_tests/test_has_names.py::test_has_names[linalg-eigh] -array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh] -array_api_tests/test_has_names.py::test_has_names[linalg-matmul] -array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm] -array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power] -array_api_tests/test_has_names.py::test_has_names[linalg-matrix_rank] -array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose] -array_api_tests/test_has_names.py::test_has_names[linalg-outer] -array_api_tests/test_has_names.py::test_has_names[linalg-pinv] -array_api_tests/test_has_names.py::test_has_names[linalg-slogdet] -array_api_tests/test_has_names.py::test_has_names[linalg-svdvals] -array_api_tests/test_has_names.py::test_has_names[linalg-tensordot] -array_api_tests/test_has_names.py::test_has_names[linalg-trace] -array_api_tests/test_has_names.py::test_has_names[linalg-vecdot] -array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm] -array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -array_api_tests/test_linalg.py::test_cross -array_api_tests/test_linalg.py::test_det -array_api_tests/test_linalg.py::test_diagonal -array_api_tests/test_linalg.py::test_eigvalsh -array_api_tests/test_linalg.py::test_matrix_norm -array_api_tests/test_linalg.py::test_matrix_rank -array_api_tests/test_linalg.py::test_outer -array_api_tests/test_linalg.py::test_pinv -array_api_tests/test_linalg.py::test_slogdet + + +# fails for ndim > 2 array_api_tests/test_linalg.py::test_svdvals +array_api_tests/test_linalg.py::test_cholesky +# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :( array_api_tests/test_linalg.py::test_tensordot +# probably same reason for failing as numpy array_api_tests/test_linalg.py::test_trace -array_api_tests/test_linalg.py::test_cholesky -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cholesky] + +# Linalg - these don't exist in dask array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.diagonal] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matmul] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_norm] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_rank] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_transpose] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.outer] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.qr] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svdvals] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.tensordot] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.trace] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vector_norm] -# errors -array_api_tests/test_linalg.py::test_matrix_power +array_api_tests/test_linalg.py::test_cross +array_api_tests/test_linalg.py::test_det +array_api_tests/test_linalg.py::test_eigvalsh +array_api_tests/test_linalg.py::test_pinv +array_api_tests/test_linalg.py::test_slogdet +array_api_tests/test_has_names.py::test_has_names[linalg-cross] +array_api_tests/test_has_names.py::test_has_names[linalg-det] +array_api_tests/test_has_names.py::test_has_names[linalg-eigh] +array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh] +array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power] +array_api_tests/test_has_names.py::test_has_names[linalg-pinv] +array_api_tests/test_has_names.py::test_has_names[linalg-slogdet] + +array_api_tests/test_linalg.py::test_matrix_norm +array_api_tests/test_linalg.py::test_matrix_rank + +# missing mode kw +# https://github.com/dask/dask/issues/10388 array_api_tests/test_linalg.py::test_qr + +# Constructing the input arrays fails to a weird shape error... array_api_tests/test_linalg.py::test_solve + +# missing full_matrics kw +# https://github.com/dask/dask/issues/10389 +# also only supports 2-d inputs +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.svd] array_api_tests/test_linalg.py::test_svd # Missing dlpack stuff @@ -138,6 +124,9 @@ array_api_tests/test_signatures.py::test_array_method_signature[__array_namespac array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] array_api_tests/test_signatures.py::test_array_method_signature[to_device] +array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] +array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] +array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] # Some cases unsupported by dask array_api_tests/test_manipulation_functions.py::test_roll From 424f25d852e40a67bad70ffd6bbbaccb9ef7a221 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 10 Jan 2024 15:05:34 -0800 Subject: [PATCH 14/22] fix ci? --- .github/workflows/array-api-tests-dask.yml | 3 ++- .github/workflows/array-api-tests.yml | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 2d60ca63..b0ce007e 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -6,6 +6,7 @@ jobs: array-api-tests-dask: uses: ./.github/workflows/array-api-tests.yml with: - package-name: dask.array + package-name: dask + module-name: dask.array extra-requires: numpy pytest-extra-args: --disable-deadline --max-examples=5 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 11037c80..d3ce2739 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -6,6 +6,9 @@ on: package-name: required: true type: string + module-name: + required: false + type: string extra-requires: required: false type: string @@ -62,7 +65,7 @@ jobs: - name: Run the array API testsuite (${{ inputs.package-name }}) if: "! (matrix.python-version == '3.11' && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: - ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }} + ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak From 64d20c42526ec6dd36ed4e1d8e5ee82fbb027d8f Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 10 Jan 2024 15:07:58 -0800 Subject: [PATCH 15/22] rename files again --- dask.array-skips.txt => dask-skips.txt | 0 dask.array-xfails.txt => dask-xfails.txt | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename dask.array-skips.txt => dask-skips.txt (100%) rename dask.array-xfails.txt => dask-xfails.txt (100%) diff --git a/dask.array-skips.txt b/dask-skips.txt similarity index 100% rename from dask.array-skips.txt rename to dask-skips.txt diff --git a/dask.array-xfails.txt b/dask-xfails.txt similarity index 100% rename from dask.array-xfails.txt rename to dask-xfails.txt From 762a03c0f7179d79ed3f451ead8a50d4a37ca10c Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 10 Jan 2024 17:30:15 -0800 Subject: [PATCH 16/22] some fixes --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/dask/array/_aliases.py | 2 +- tests/test_array_namespace.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index f35160e0..cc09a0a6 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -303,7 +303,7 @@ def _asarray( import numpy as xp elif namespace == 'cupy': import cupy as xp - elif namespace == 'dask': + elif namespace == 'dask.array': import dask.array as xp else: raise ValueError("Unrecognized namespace argument to asarray()") diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 8be3a566..bea488cc 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -76,7 +76,7 @@ def dask_arange( eye = get_xp(da)(_aliases.eye) from functools import partial -asarray = partial(_aliases._asarray, namespace='dask') +asarray = partial(_aliases._asarray, namespace='dask.array') asarray.__doc__ = _aliases._asarray.__doc__ linspace = get_xp(da)(_aliases.linspace) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 4012f212..0becfc3d 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -17,8 +17,9 @@ def test_array_namespace(library, api_version): assert namespace == lib else: if library == "dask.array": - library = "dask" - assert namespace == getattr(array_api_compat, library) + assert namespace == array_api_compat.dask.array + else: + assert namespace == getattr(array_api_compat, library) def test_array_namespace_errors(): From 69cc93b55ad92f07dc1abcc83ecf4b34c071f076 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 17 Jan 2024 21:27:47 -0800 Subject: [PATCH 17/22] fix astype bug --- array_api_compat/common/_aliases.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index cc09a0a6..db0fa623 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -327,9 +327,16 @@ def _asarray( if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): if dtype is not None and obj.dtype != dtype: copy = True - # Dask arrays are immutable, so copy doesn't do anything - if copy in COPY_TRUE and namespace != "dask.array": - return xp.array(obj, copy=True, dtype=dtype) + if copy in COPY_TRUE: + copy_kwargs = {} + if namespace != "dask.array": + copy_kwargs["copy"] = True + else: + # No copy kw in dask.asarray so we go thorugh np.asarray first + # (like dask also does) but copy after + import numpy as np + obj = np.asarray(obj).copy() + return xp.array(obj, dtype=dtype, **copy_kwargs) return obj return xp.asarray(obj, dtype=dtype, **kwargs) From 5edc5ec43d65532f49e720bba793f47743b59175 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Wed, 24 Jan 2024 11:14:01 -0500 Subject: [PATCH 18/22] remove file --- vendor_test/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 vendor_test/.DS_Store diff --git a/vendor_test/.DS_Store b/vendor_test/.DS_Store deleted file mode 100644 index c804b5ef87d13f697bfc8a1cea378e7b7b3d830c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK!AiqG5S@)hQhF(P@wmUxi|4h(FX#tUn}|I$O%N132>Aq$p8W;!=+Xc1M+D!@ zENw&DLNB5+1G8^4nb}OZ!h-gSeDU>nkquC;SpS2@NwD3(QW8`#2InC*s zrX6oH{6+?N?gp5xDb@6Vxm(ZQba6A9%`5_j#bUh23R_Y^72Tp&v@@l*Z9bbhzJ5J^K0Qm0ow3Slo!vgac3%Abm>2`b zfHClA7(mTtNe7D78Ux0FF|cBQ{|_F@7$P=`^6Nk&M*v_CW)bvdF9h^#0EUQ-A}kQ6 zp+F6F@rdCx9Cjc5Lc~Tv>7 Date: Sun, 28 Jan 2024 11:15:54 -0500 Subject: [PATCH 19/22] address more comments --- array_api_compat/common/_aliases.py | 8 ++++++++ array_api_compat/common/_linalg.py | 6 ++++-- array_api_compat/dask/array/_aliases.py | 7 ++++++- array_api_compat/dask/array/linalg.py | 7 ------- dask-xfails.txt | 10 +++++++++- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 425f3e46..7713213e 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -325,8 +325,10 @@ def _asarray( # copy=False is not yet implemented in xp.asarray raise NotImplementedError("copy=False is not yet implemented") if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"): + #print('hit me') if dtype is not None and obj.dtype != dtype: copy = True + #print(copy) if copy in COPY_TRUE: copy_kwargs = {} if namespace != "dask.array": @@ -334,8 +336,14 @@ def _asarray( else: # No copy kw in dask.asarray so we go thorugh np.asarray first # (like dask also does) but copy after + if dtype is None: + # Same dtype copy is no-op in dask + #print("in here?") + return obj.copy() import numpy as np + #print(obj) obj = np.asarray(obj).copy() + #print(obj) return xp.array(obj, dtype=dtype, **copy_kwargs) return obj diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index d8b3e208..22f9c484 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -77,8 +77,10 @@ def matrix_rank(x: ndarray, # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = xp.linalg.svdvals(x, **kwargs) - #S = xp.linalg.svd(x, compute_uv=False, **kwargs) + if hasattr(xp.linalg, "svdvals"): + S = xp.linalg.svdvals(x, **kwargs) + else: + S = xp.linalg.svd(x, compute_uv=False, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index bea488cc..687d3694 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -120,4 +120,9 @@ def dask_arange( concatenate as concat, ) -del da +del da, partial + +__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow'] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index a4407cd1..baa0ae7f 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -13,17 +13,13 @@ from typing import Optional, Union, Tuple from ...common._typing import ndarray, Device, Dtype -#cross = get_xp(da)(_linalg.cross) -#outer = get_xp(da)(_linalg.outer) EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult qr = get_xp(da)(_linalg.qr) -#svd = get_xp(da)(_linalg.svd) cholesky = get_xp(da)(_linalg.cholesky) matrix_rank = get_xp(da)(_linalg.matrix_rank) -#pinv = get_xp(da)(_linalg.pinv) matrix_norm = get_xp(da)(_linalg.matrix_norm) def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: @@ -34,9 +30,6 @@ def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -#__all__ = linalg_all + _linalg.__all__ - del get_xp del da -#del linalg_all del _linalg diff --git a/dask-xfails.txt b/dask-xfails.txt index 2a5330f6..d3bcd885 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -9,9 +9,17 @@ #| Draw 1 (key): (slice(None, None, None), slice(None, None, None)) #| Draw 2 (value): dask.array -# TODO: this also skips test_setitem_masking unnecessarily +# Various shape mismatches e.g. +ValueError: shape mismatch: value array of shape (0, 2) could not be broadcast to indexing result of shape (0, 2) array_api_tests/test_array_object.py::test_setitem +# Fails since bad upcast from uint8 -> int64 +# MRE: +# a = da.array(0, dtype="uint8") +# b = da.array(False) +# a[b] = 0 +array_api_tests/test_array_object.py::test_setitem_masking + # Various indexing errors array_api_tests/test_array_object.py::test_getitem_masking From 4be5517374425ea48cd5f7c62a2d299bae7e4e22 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Sun, 28 Jan 2024 12:45:58 -0500 Subject: [PATCH 20/22] fix more tests --- array_api_compat/dask/array/_aliases.py | 27 ++++++++++++++++++++----- dask-xfails.txt | 15 -------------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 687d3694..ef9ea356 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -103,6 +103,15 @@ def dask_arange( matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) +nonzero = get_xp(da)(_aliases.nonzero) +sum = get_xp(np)(_aliases.sum) +prod = get_xp(np)(_aliases.prod) +ceil = get_xp(np)(_aliases.ceil) +floor = get_xp(np)(_aliases.floor) +trunc = get_xp(np)(_aliases.trunc) +matmul = get_xp(np)(_aliases.matmul) +tensordot = get_xp(np)(_aliases.tensordot) + from dask.array import ( # Element wise aliases arccos as acos, @@ -120,9 +129,17 @@ def dask_arange( concatenate as concat, ) -del da, partial +# exclude these from all since +_da_unsupported = ['sort', 'argsort'] + +common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] + +__all__ = common_aliases + ['asarray', 'bool', 'acos', + 'acosh', 'asin', 'asinh', 'atan', 'atan2', + 'atanh', 'bitwise_left_shift', 'bitwise_invert', + 'bitwise_right_shift', 'concat', 'pow', + 'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8', + 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', + 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type'] -__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] +del da, partial, common_aliases, _da_unsupported, diff --git a/dask-xfails.txt b/dask-xfails.txt index d3bcd885..39a1dd8a 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -59,16 +59,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# dask doesn't return int when input is already int for ceil/floor/trunc -# Use $ to denote end of regex so we don't xfail other tests accidentally -array_api_tests/test_operators_and_elementwise_functions.py::test_ceil -# TODO: this xfails more than it should ... (e.g. test_floor_divide works) -array_api_tests/test_operators_and_elementwise_functions.py::test_floor -array_api_tests/test_operators_and_elementwise_functions.py::test_trunc - -# Dask doesn't raise an error for this test -array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error - # Fails because shape is NaN since we don't materialize it yet array_api_tests/test_searching_functions.py::test_nonzero array_api_tests/test_set_functions.py::test_unique_all @@ -80,7 +70,6 @@ array_api_tests/test_set_functions.py::test_unique_values # Linalg failures (signature failures/missing methods) - # fails for ndim > 2 array_api_tests/test_linalg.py::test_svdvals array_api_tests/test_linalg.py::test_cholesky @@ -139,9 +128,5 @@ array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__ # Some cases unsupported by dask array_api_tests/test_manipulation_functions.py::test_roll -# Dtype doesn't match (output is float32 but should be float64) -array_api_tests/test_statistical_functions.py::test_prod -array_api_tests/test_statistical_functions.py::test_sum - # No mT on dask array array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices From cd381a016e4726305596b445852292d681144959 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Mon, 29 Jan 2024 21:07:19 -0500 Subject: [PATCH 21/22] address more feedback --- array_api_compat/common/_linalg.py | 5 +---- array_api_compat/dask/array/linalg.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 22f9c484..e5bc17d4 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -77,10 +77,7 @@ def matrix_rank(x: ndarray, # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - if hasattr(xp.linalg, "svdvals"): - S = xp.linalg.svdvals(x, **kwargs) - else: - S = xp.linalg.svd(x, compute_uv=False, **kwargs) + S = get_xp(xp).linalg.svdvals(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index baa0ae7f..c8aa7c9f 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -10,8 +10,17 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Union, Tuple - from ...common._typing import ndarray, Device, Dtype + from typing import Union, Tuple + from ...common._typing import ndarray + +# cupy.linalg doesn't have __all__. If it is added, replace this with +# +# from cupy.linalg import __all__ as linalg_all +_n = {} +exec('from dask.array.linalg import *', _n) +del _n['__builtins__'] +linalg_all = list(_n) +del _n EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -30,6 +39,10 @@ def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) +__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult", + "SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm", + "svdvals", "vector_norm", "diagonal"] + del get_xp del da del _linalg From 54f4838fd63ef4e48a811f32325eb522a7e9d014 Mon Sep 17 00:00:00 2001 From: Thomas Li <47963215+lithomas1@users.noreply.github.com> Date: Tue, 30 Jan 2024 10:33:47 -0500 Subject: [PATCH 22/22] fix? --- array_api_compat/common/_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index e5bc17d4..9f0c993f 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -77,7 +77,7 @@ def matrix_rank(x: ndarray, # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp).linalg.svdvals(x, **kwargs) + S = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: