Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap fft for dask #139

Merged
merged 22 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/array-api-tests-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ jobs:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
package-version: '>= 2024.9.0'
module-name: dask.array
extra-requires: numpy
pytest-extra-args: --disable-deadline --max-examples=5
3 changes: 2 additions & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
# min version of dask we needs drops support for python 3.9
python-version: ${{ inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']') || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's other similar skips in some of the lines below. We should probably consolidate. I don't know if there's a cleaner way to do it but if you are aware of one let me know!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would put these sorts of skips in the individual action files for each package, but I don't know if it's possible to do it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for now, it might be easier to keep it in here (even though it's less clean), since I don't think any of the other libraries need this Python restriction.

Assuming Python 3.9 gets dropped sometime soonish, it'd probably be easier to remove in this state.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably come up with a policy, but I've tried to be very conservative in supporting old versions to make it so this package is as easy for people to adopt as possible. So I don't have immediate plans to drop Python 3.9, although we will obviously want to do so at some point especially as all the wrapped packages stop supporting it.


steps:
- name: Checkout array-api-compat
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
__array_api_version__ = '2022.12'

__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
33 changes: 13 additions & 20 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
iinfo,
finfo,
bool_ as bool,
float32,
float64,
Expand All @@ -29,8 +25,6 @@
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)
Expand Down Expand Up @@ -206,19 +200,18 @@ def _isscalar(a):

return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)

# exclude these from all since
# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool',
'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'e',
'inf', 'nan', 'pi', 'newaxis', 'float32',
'float64', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']

_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
_all_ignore = ["get_xp", "da", "np"]
24 changes: 24 additions & 0 deletions array_api_compat/dask/array/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dask.array.fft import * # noqa: F403
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as linalg_all
lithomas1 marked this conversation as resolved.
Show resolved Hide resolved
_n = {}
exec('from dask.array.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n

from ...common import _fft
from ..._internal import get_xp

import dask.array as da

fftfreq = get_xp(da)(_fft.fftfreq)
rfftfreq = get_xp(da)(_fft.rfftfreq)

__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"]

del get_xp
del da
del fft_all
del _fft
15 changes: 0 additions & 15 deletions dask-skips.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,2 @@
# FFT isn't conformant
array_api_tests/test_fft.py
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.hfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ihfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftfreq]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftfreq]

# slow and not implemented in dask
array_api_tests/test_linalg.py::test_matrix_power
Loading