Skip to content

Commit

Permalink
Detect if length of array is too much for MKL to handle, switch to FF…
Browse files Browse the repository at this point in the history
…TW if it is
  • Loading branch information
GarethCabournDavies committed Jun 14, 2023
1 parent 1c68f5b commit cf70160
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion pycbc/fft/func_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
implementations within PyCBC.
"""

import logging
from pycbc.types import TimeSeries as _TimeSeries
from pycbc.types import FrequencySeries as _FrequencySeries
from .core import _check_fft_args, _check_fwd_args, _check_inv_args
from .backend_support import get_backend
from .backend_support import get_backend, set_backend

def fft(invec, outvec):
""" Fourier transform from invec to outvec.
Expand All @@ -49,6 +50,11 @@ def fft(invec, outvec):

# The following line is where all the work is done:
backend = get_backend()
if len(invec) > 2 ** 24 and backend.__name__ == 'pycbc.fft.mkl':
logging.warning("MKL cannot handle arrays longer than 2^24 in "
"FFTs. Changing scheme to FFTW.")
set_backend(['fftw'])
backend = get_backend()
backend.fft(invec, outvec, prec, itype, otype)
# For a forward FFT, the length of the *input* vector is the length
# we should divide by, whether C2C or R2HC transform
Expand Down Expand Up @@ -79,6 +85,11 @@ def ifft(invec, outvec):

# The following line is where all the work is done:
backend = get_backend()
if len(invec) > 2 ** 24 and backend.__name__ == 'pycbc.fft.mkl':
logging.warning("MKL cannot handle arrays longer than 2^24 in "
"FFTs. Changing scheme to FFTW.")
set_backend(['fftw'])
backend = get_backend()
backend.ifft(invec, outvec, prec, itype, otype)
# For an inverse FFT, the length of the *output* vector is the length
# we should divide by, whether C2C or HC2R transform
Expand Down

0 comments on commit cf70160

Please sign in to comment.