From 4347fa9db34cc3ee663e4e76805edc92d44e4777 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 17 Aug 2022 17:38:53 -0400 Subject: [PATCH] Add option to test with FFTW backend --- .github/workflows/CI.yml | 9 ++ Project.toml | 3 +- test/{testplans.jl => TestPlans.jl} | 7 + test/runtests.jl | 239 ++-------------------------- test/testfft.jl | 237 +++++++++++++++++++++++++++ 5 files changed, 264 insertions(+), 231 deletions(-) rename test/{testplans.jl => TestPlans.jl} (99%) create mode 100644 test/testfft.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8d43117..1c43e97 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,6 +22,12 @@ jobs: - windows-latest arch: - x64 + group: + - TestPlans + - FFTW + exclude: + - version: '1.0' + group: FFTW steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -40,7 +46,10 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v1 with: file: lcov.info + flag-name: group-${{ matrix.group }} # unique name for coverage report of each group diff --git a/Project.toml b/Project.toml index a639c5d..1f7c02d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,10 @@ julia = "^1.0" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesTestUtils", "FFTW", "Random", "Test", "Unitful"] diff --git a/test/testplans.jl b/test/TestPlans.jl similarity index 99% rename from test/testplans.jl rename to test/TestPlans.jl index 7abecfe..9658230 100644 --- a/test/testplans.jl +++ b/test/TestPlans.jl @@ -1,3 +1,8 @@ +module TestPlans + +using AbstractFFTs +using AbstractFFTs: Plan + mutable struct TestPlan{T,N} <: Plan{T} region sz::NTuple{N,Int} @@ -226,3 +231,5 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray) return y end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 623d625..870646a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,3 @@ -# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license - using AbstractFFTs using AbstractFFTs: Plan using ChainRulesTestUtils @@ -12,235 +10,16 @@ import Unitful Random.seed!(1234) -include("testplans.jl") - -@testset "rfft sizes" begin - A = rand(11, 10) - @test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10) - @test @inferred(AbstractFFTs.rfft_output_size(A, 2)) == (11, 6) - A1 = rand(6, 10); A2 = rand(11, 6) - @test @inferred(AbstractFFTs.brfft_output_size(A1, 11, 1)) == (11, 10) - @test @inferred(AbstractFFTs.brfft_output_size(A2, 10, 2)) == (11, 10) - @test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2) -end - -@testset "Custom Plan" begin - # DFT along last dimension, results computed using FFTW - for (x, fftw_fft) in ( - (collect(1:7), - [28.0 + 0.0im, - -3.5 + 7.267824888003178im, - -3.5 + 2.7911568610884143im, - -3.5 + 0.7988521603655248im, - -3.5 - 0.7988521603655248im, - -3.5 - 2.7911568610884143im, - -3.5 - 7.267824888003178im]), - (collect(1:8), - [36.0 + 0.0im, - -4.0 + 9.65685424949238im, - -4.0 + 4.0im, - -4.0 + 1.6568542494923806im, - -4.0 + 0.0im, - -4.0 - 1.6568542494923806im, - -4.0 - 4.0im, - -4.0 - 9.65685424949238im]), - (collect(reshape(1:8, 2, 4)), - [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im; - 20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]), - (collect(reshape(1:9, 3, 3)), - [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; - 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; - 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), - ) - # FFT - dims = ndims(x) - y = AbstractFFTs.fft(x, dims) - @test y ≈ fftw_fft - P = plan_fft(x, dims) - @test eltype(P) === ComplexF64 - @test P * x ≈ fftw_fft - @test P \ (P * x) ≈ x - @test fftdims(P) == dims - - fftw_bfft = complex.(size(x, dims) .* x) - @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft - P = plan_bfft(x, dims) - @test P * y ≈ fftw_bfft - @test P \ (P * y) ≈ y - @test fftdims(P) == dims - - fftw_ifft = complex.(x) - @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft - P = plan_ifft(x, dims) - @test P * y ≈ fftw_ifft - @test P \ (P * y) ≈ y - @test fftdims(P) == dims - - # real FFT - fftw_rfft = fftw_fft[ - (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., - 1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1) - ] - ry = AbstractFFTs.rfft(x, dims) - @test ry ≈ fftw_rfft - P = plan_rfft(x, dims) - @test eltype(P) === Int - @test P * x ≈ fftw_rfft - @test P \ (P * x) ≈ x - @test fftdims(P) == dims - - fftw_brfft = complex.(size(x, dims) .* x) - @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft - P = plan_brfft(ry, size(x, dims), dims) - @test P * ry ≈ fftw_brfft - @test P \ (P * ry) ≈ ry - @test fftdims(P) == dims +const GROUP = get(ENV, "GROUP", "All") - fftw_irfft = complex.(x) - @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft - P = plan_irfft(ry, size(x, dims), dims) - @test P * ry ≈ fftw_irfft - @test P \ (P * ry) ≈ ry - @test fftdims(P) == dims - end -end - -@testset "Shift functions" begin - @test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2] - @test @inferred(AbstractFFTs.fftshift([1, 2, 3])) == [3, 1, 2] - @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6])) == [6 4 5; 3 1 2] - a = [0 0 0] - b = [0, 0, 0] - c = [0 0 0; 0 0 0] - @test (AbstractFFTs.fftshift!(a, [1 2 3]); a == [3 1 2]) - @test (AbstractFFTs.fftshift!(b, [1, 2, 3]); b == [3, 1, 2]) - @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6]); c == [6 4 5; 3 1 2]) - - @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3] - @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6] - @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2))) == [6 4 5; 3 1 2] - @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2)) == [6 4 5; 3 1 2] - @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3]) - @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6]) - @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [6 4 5; 3 1 2]) - @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1:2); c == [6 4 5; 3 1 2]) - - @test @inferred(AbstractFFTs.ifftshift([1 2 3])) == [2 3 1] - @test @inferred(AbstractFFTs.ifftshift([1, 2, 3])) == [2, 3, 1] - @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6])) == [5 6 4; 2 3 1] - @test (AbstractFFTs.ifftshift!(a, [1 2 3]); a == [2 3 1]) - @test (AbstractFFTs.ifftshift!(b, [1, 2, 3]); b == [2, 3, 1]) - @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6]); c == [5 6 4; 2 3 1]) +include("TestPlans.jl") +include("testfft.jl") - @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3] - @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6] - @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2))) == [5 6 4; 2 3 1] - @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2)) == [5 6 4; 2 3 1] - @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3]) - @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6]) - @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [5 6 4; 2 3 1]) - @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1:2); c == [5 6 4; 2 3 1]) +if GROUP == "All" || GROUP == "TestPlans" + using .TestPlans + testfft() +elseif GROUP == "All" || GROUP == "FFTW" # integration test with FFTW + using FFTW + testfft() end -@testset "FFT Frequencies" begin - @test fftfreq(8) isa Frequencies - @test copy(fftfreq(8)) isa Frequencies - - # N even - @test fftfreq(8) == [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125] - @test rfftfreq(8) == [0.0, 0.125, 0.25, 0.375, 0.5] - @test fftshift(fftfreq(8)) == -0.5:0.125:0.375 - - # N odd - @test fftfreq(5) == [0.0, 0.2, 0.4, -0.4, -0.2] - @test rfftfreq(5) == [0.0, 0.2, 0.4] - @test fftshift(fftfreq(5)) == -0.4:0.2:0.4 - - # Sampling Frequency - @test fftfreq(5, 2) == [0.0, 0.4, 0.8, -0.8, -0.4] - # <:Number type compatibility - @test eltype(fftfreq(5, ComplexF64(2))) == ComplexF64 - - @test_throws ArgumentError Frequencies(12, 10, 1) - - @testset "scaling" begin - @test fftfreq(4, 1) * 2 === fftfreq(4, 2) - @test fftfreq(4, 1) .* 2 === fftfreq(4, 2) - @test 2 * fftfreq(4, 1) === fftfreq(4, 2) - @test 2 .* fftfreq(4, 1) === fftfreq(4, 2) - - @test fftfreq(4, 1) / 2 === fftfreq(4, 1/2) - @test fftfreq(4, 1) ./ 2 === fftfreq(4, 1/2) - - @test 2 \ fftfreq(4, 1) === fftfreq(4, 1/2) - @test 2 .\ fftfreq(4, 1) === fftfreq(4, 1/2) - end - - @testset "extrema" begin - function check_extrema(freqs) - for f in [minimum, maximum, extrema] - @test f(freqs) == f(collect(freqs)) == f(fftshift(freqs)) - end - end - for f in (fftfreq, rfftfreq), n in (8, 9), multiplier in (2, 1/3, -1/7, 1.0*Unitful.mm) - freqs = f(n, multiplier) - check_extrema(freqs) - end - end -end - -@testset "normalization" begin - # normalization should be inferable even if region is only inferred as ::Any, - # need to wrap in another function to test this (note that p.region::Any for - # p::TestPlan) - f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) - @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 -end - -@testset "ChainRules" begin - @testset "shift functions" begin - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) - for dims in ((), 1, 2, (1,2), 1:2) - any(d > ndims(x) for d in dims) && continue - - # type inference checks of `rrule` fail on old Julia versions - # for higher-dimensional arrays: - # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 - check_inferred = ndims(x) < 3 || VERSION >= v"1.6" - - test_frule(AbstractFFTs.fftshift, x, dims) - test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred) - - test_frule(AbstractFFTs.ifftshift, x, dims) - test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred) - end - end - end - - @testset "fft" begin - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) - N = ndims(x) - complex_x = complex.(x) - for dims in unique((1, 1:N, N)) - for f in (fft, ifft, bfft) - test_frule(f, x, dims) - test_rrule(f, x, dims) - test_frule(f, complex_x, dims) - test_rrule(f, complex_x, dims) - end - - test_frule(rfft, x, dims) - test_rrule(rfft, x, dims) - - for f in (irfft, brfft) - for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) - test_frule(f, x, d, dims) - test_rrule(f, x, d, dims) - test_frule(f, complex_x, d, dims) - test_rrule(f, complex_x, d, dims) - end - end - end - end - end -end diff --git a/test/testfft.jl b/test/testfft.jl new file mode 100644 index 0000000..a460c46 --- /dev/null +++ b/test/testfft.jl @@ -0,0 +1,237 @@ +# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license + +# Runs AbstractFFTs tests, relying on an FFT backend present in the environment. +# Ideally, the backend would be passed as an argument to make this pure, but this isn't possible +# since downstream implementations perform type piracy on AbstractFFTs.plan_fft in the current design. +function testfft() + @testset "rfft sizes" begin + A = rand(11, 10) + @test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10) + @test @inferred(AbstractFFTs.rfft_output_size(A, 2)) == (11, 6) + A1 = rand(6, 10); A2 = rand(11, 6) + @test @inferred(AbstractFFTs.brfft_output_size(A1, 11, 1)) == (11, 10) + @test @inferred(AbstractFFTs.brfft_output_size(A2, 10, 2)) == (11, 10) + @test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2) + end + + @testset "Custom Plan" begin + # DFT along last dimension, results computed using FFTW + for (x, fftw_fft) in ( + (collect(1:7), + [28.0 + 0.0im, + -3.5 + 7.267824888003178im, + -3.5 + 2.7911568610884143im, + -3.5 + 0.7988521603655248im, + -3.5 - 0.7988521603655248im, + -3.5 - 2.7911568610884143im, + -3.5 - 7.267824888003178im]), + (collect(1:8), + [36.0 + 0.0im, + -4.0 + 9.65685424949238im, + -4.0 + 4.0im, + -4.0 + 1.6568542494923806im, + -4.0 + 0.0im, + -4.0 - 1.6568542494923806im, + -4.0 - 4.0im, + -4.0 - 9.65685424949238im]), + (collect(reshape(1:8, 2, 4)), + [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im; + 20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]), + (collect(reshape(1:9, 3, 3)), + [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; + 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; + 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), + ) + # FFT + dims = ndims(x) + y = AbstractFFTs.fft(x, dims) + @test y ≈ fftw_fft + P = plan_fft(x, dims) + @test eltype(P) === ComplexF64 + @test P * x ≈ fftw_fft + @test P \ (P * x) ≈ x + @test fftdims(P) == dims + + fftw_bfft = complex.(size(x, dims) .* x) + @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft + P = plan_bfft(x, dims) + @test P * y ≈ fftw_bfft + @test P \ (P * y) ≈ y + @test fftdims(P) == dims + + fftw_ifft = complex.(x) + @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft + P = plan_ifft(x, dims) + @test P * y ≈ fftw_ifft + @test P \ (P * y) ≈ y + @test fftdims(P) == dims + + # real FFT + fftw_rfft = fftw_fft[ + (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., + 1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1) + ] + ry = AbstractFFTs.rfft(x, dims) + @test ry ≈ fftw_rfft + P = plan_rfft(x, dims) + @test eltype(P) <: Real + @test P * x ≈ fftw_rfft + @test P \ (P * x) ≈ x + @test fftdims(P) == dims + + fftw_brfft = complex.(size(x, dims) .* x) + @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft + P = plan_brfft(ry, size(x, dims), dims) + @test P * ry ≈ fftw_brfft + @test P \ (P * ry) ≈ ry + @test fftdims(P) == dims + + fftw_irfft = complex.(x) + @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft + P = plan_irfft(ry, size(x, dims), dims) + @test P * ry ≈ fftw_irfft + @test P \ (P * ry) ≈ ry + @test fftdims(P) == dims + end + end + + @testset "Shift functions" begin + @test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2] + @test @inferred(AbstractFFTs.fftshift([1, 2, 3])) == [3, 1, 2] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6])) == [6 4 5; 3 1 2] + a = [0 0 0] + b = [0, 0, 0] + c = [0 0 0; 0 0 0] + @test (AbstractFFTs.fftshift!(a, [1 2 3]); a == [3 1 2]) + @test (AbstractFFTs.fftshift!(b, [1, 2, 3]); b == [3, 1, 2]) + @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6]); c == [6 4 5; 3 1 2]) + + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2))) == [6 4 5; 3 1 2] + @test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2)) == [6 4 5; 3 1 2] + @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3]) + @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6]) + @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [6 4 5; 3 1 2]) + @test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1:2); c == [6 4 5; 3 1 2]) + + @test @inferred(AbstractFFTs.ifftshift([1 2 3])) == [2 3 1] + @test @inferred(AbstractFFTs.ifftshift([1, 2, 3])) == [2, 3, 1] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6])) == [5 6 4; 2 3 1] + @test (AbstractFFTs.ifftshift!(a, [1 2 3]); a == [2 3 1]) + @test (AbstractFFTs.ifftshift!(b, [1, 2, 3]); b == [2, 3, 1]) + @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6]); c == [5 6 4; 2 3 1]) + + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2))) == [5 6 4; 2 3 1] + @test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2)) == [5 6 4; 2 3 1] + @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3]) + @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6]) + @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [5 6 4; 2 3 1]) + @test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1:2); c == [5 6 4; 2 3 1]) + end + + @testset "FFT Frequencies" begin + @test fftfreq(8) isa Frequencies + @test copy(fftfreq(8)) isa Frequencies + + # N even + @test fftfreq(8) == [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125] + @test rfftfreq(8) == [0.0, 0.125, 0.25, 0.375, 0.5] + @test fftshift(fftfreq(8)) == -0.5:0.125:0.375 + + # N odd + @test fftfreq(5) == [0.0, 0.2, 0.4, -0.4, -0.2] + @test rfftfreq(5) == [0.0, 0.2, 0.4] + @test fftshift(fftfreq(5)) == -0.4:0.2:0.4 + + # Sampling Frequency + @test fftfreq(5, 2) == [0.0, 0.4, 0.8, -0.8, -0.4] + # <:Number type compatibility + @test eltype(fftfreq(5, ComplexF64(2))) == ComplexF64 + + @test_throws ArgumentError Frequencies(12, 10, 1) + + @testset "scaling" begin + @test fftfreq(4, 1) * 2 === fftfreq(4, 2) + @test fftfreq(4, 1) .* 2 === fftfreq(4, 2) + @test 2 * fftfreq(4, 1) === fftfreq(4, 2) + @test 2 .* fftfreq(4, 1) === fftfreq(4, 2) + + @test fftfreq(4, 1) / 2 === fftfreq(4, 1/2) + @test fftfreq(4, 1) ./ 2 === fftfreq(4, 1/2) + + @test 2 \ fftfreq(4, 1) === fftfreq(4, 1/2) + @test 2 .\ fftfreq(4, 1) === fftfreq(4, 1/2) + end + + @testset "extrema" begin + function check_extrema(freqs) + for f in [minimum, maximum, extrema] + @test f(freqs) == f(collect(freqs)) == f(fftshift(freqs)) + end + end + for f in (fftfreq, rfftfreq), n in (8, 9), multiplier in (2, 1/3, -1/7, 1.0*Unitful.mm) + freqs = f(n, multiplier) + check_extrema(freqs) + end + end + end + + @testset "normalization" begin + # normalization should be inferable even if region is only inferred as ::Any, + # need to wrap in another function to test this (note that p.region::Any for + # p::TestPlan) + f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) + @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 + end + + @testset "ChainRules" begin + @testset "shift functions" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + for dims in ((), 1, 2, (1,2), 1:2) + any(d > ndims(x) for d in dims) && continue + + # type inference checks of `rrule` fail on old Julia versions + # for higher-dimensional arrays: + # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016 + check_inferred = ndims(x) < 3 || VERSION >= v"1.6" + + test_frule(AbstractFFTs.fftshift, x, dims) + test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred) + + test_frule(AbstractFFTs.ifftshift, x, dims) + test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred) + end + end + end + + @testset "fft" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + complex_x = complex.(x) + for dims in unique((1, 1:N, N)) + for f in (fft, ifft, bfft) + test_frule(f, x, dims) + test_rrule(f, x, dims) + test_frule(f, complex_x, dims) + test_rrule(f, complex_x, dims) + end + + test_frule(rfft, x, dims) + test_rrule(rfft, x, dims) + + for f in (irfft, brfft) + for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) + test_frule(f, x, d, dims) + test_rrule(f, x, d, dims) + test_frule(f, complex_x, d, dims) + test_rrule(f, complex_x, d, dims) + end + end + end + end + end + end +end