Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration test with FFTW backend #75

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ jobs:
- windows-latest
arch:
- x64
group:
- TestPlans
- FFTW
exclude:
- version: '1.0'
group: FFTW
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand All @@ -40,7 +46,10 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
with:
file: lcov.info
flag-name: group-${{ matrix.group }} # unique name for coverage report of each group
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ julia = "^1.0"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
test = ["ChainRulesTestUtils", "FFTW", "Random", "Test", "Unitful"]
7 changes: 7 additions & 0 deletions test/testplans.jl → test/TestPlans.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module TestPlans

using AbstractFFTs
using AbstractFFTs: Plan

mutable struct TestPlan{T,N} <: Plan{T}
region
sz::NTuple{N,Int}
Expand Down Expand Up @@ -226,3 +231,5 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray)

return y
end

end
239 changes: 9 additions & 230 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

using AbstractFFTs
using AbstractFFTs: Plan
using ChainRulesTestUtils
Expand All @@ -12,235 +10,16 @@ import Unitful

Random.seed!(1234)

include("testplans.jl")

@testset "rfft sizes" begin
A = rand(11, 10)
@test @inferred(AbstractFFTs.rfft_output_size(A, 1)) == (6, 10)
@test @inferred(AbstractFFTs.rfft_output_size(A, 2)) == (11, 6)
A1 = rand(6, 10); A2 = rand(11, 6)
@test @inferred(AbstractFFTs.brfft_output_size(A1, 11, 1)) == (11, 10)
@test @inferred(AbstractFFTs.brfft_output_size(A2, 10, 2)) == (11, 10)
@test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2)
end

@testset "Custom Plan" begin
# DFT along last dimension, results computed using FFTW
for (x, fftw_fft) in (
(collect(1:7),
[28.0 + 0.0im,
-3.5 + 7.267824888003178im,
-3.5 + 2.7911568610884143im,
-3.5 + 0.7988521603655248im,
-3.5 - 0.7988521603655248im,
-3.5 - 2.7911568610884143im,
-3.5 - 7.267824888003178im]),
(collect(1:8),
[36.0 + 0.0im,
-4.0 + 9.65685424949238im,
-4.0 + 4.0im,
-4.0 + 1.6568542494923806im,
-4.0 + 0.0im,
-4.0 - 1.6568542494923806im,
-4.0 - 4.0im,
-4.0 - 9.65685424949238im]),
(collect(reshape(1:8, 2, 4)),
[16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
(collect(reshape(1:9, 3, 3)),
[12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
)
# FFT
dims = ndims(x)
y = AbstractFFTs.fft(x, dims)
@test y ≈ fftw_fft
P = plan_fft(x, dims)
@test eltype(P) === ComplexF64
@test P * x ≈ fftw_fft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft
P = plan_bfft(x, dims)
@test P * y ≈ fftw_bfft
@test P \ (P * y) ≈ y
@test fftdims(P) == dims

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft
P = plan_ifft(x, dims)
@test P * y ≈ fftw_ifft
@test P \ (P * y) ≈ y
@test fftdims(P) == dims

# real FFT
fftw_rfft = fftw_fft[
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
]
ry = AbstractFFTs.rfft(x, dims)
@test ry ≈ fftw_rfft
P = plan_rfft(x, dims)
@test eltype(P) === Int
@test P * x ≈ fftw_rfft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_brfft
@test P \ (P * ry) ≈ ry
@test fftdims(P) == dims
const GROUP = get(ENV, "GROUP", "All")

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_irfft
@test P \ (P * ry) ≈ ry
@test fftdims(P) == dims
end
end

@testset "Shift functions" begin
@test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2]
@test @inferred(AbstractFFTs.fftshift([1, 2, 3])) == [3, 1, 2]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6])) == [6 4 5; 3 1 2]
a = [0 0 0]
b = [0, 0, 0]
c = [0 0 0; 0 0 0]
@test (AbstractFFTs.fftshift!(a, [1 2 3]); a == [3 1 2])
@test (AbstractFFTs.fftshift!(b, [1, 2, 3]); b == [3, 1, 2])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6]); c == [6 4 5; 3 1 2])

@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2))) == [6 4 5; 3 1 2]
@test @inferred(AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2)) == [6 4 5; 3 1 2]
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [6 4 5; 3 1 2])
@test (AbstractFFTs.fftshift!(c, [1 2 3; 4 5 6], 1:2); c == [6 4 5; 3 1 2])

@test @inferred(AbstractFFTs.ifftshift([1 2 3])) == [2 3 1]
@test @inferred(AbstractFFTs.ifftshift([1, 2, 3])) == [2, 3, 1]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6])) == [5 6 4; 2 3 1]
@test (AbstractFFTs.ifftshift!(a, [1 2 3]); a == [2 3 1])
@test (AbstractFFTs.ifftshift!(b, [1, 2, 3]); b == [2, 3, 1])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6]); c == [5 6 4; 2 3 1])
include("TestPlans.jl")
include("testfft.jl")

@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1)) == [4 5 6; 1 2 3]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], ())) == [1 2 3; 4 5 6]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2))) == [5 6 4; 2 3 1]
@test @inferred(AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2)) == [5 6 4; 2 3 1]
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1); c == [4 5 6; 1 2 3])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], ()); c == [1 2 3; 4 5 6])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], (1,2)); c == [5 6 4; 2 3 1])
@test (AbstractFFTs.ifftshift!(c, [1 2 3; 4 5 6], 1:2); c == [5 6 4; 2 3 1])
if GROUP == "All" || GROUP == "TestPlans"
using .TestPlans
testfft()
elseif GROUP == "All" || GROUP == "FFTW" # integration test with FFTW
using FFTW
testfft()
end

@testset "FFT Frequencies" begin
@test fftfreq(8) isa Frequencies
@test copy(fftfreq(8)) isa Frequencies

# N even
@test fftfreq(8) == [0.0, 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125]
@test rfftfreq(8) == [0.0, 0.125, 0.25, 0.375, 0.5]
@test fftshift(fftfreq(8)) == -0.5:0.125:0.375

# N odd
@test fftfreq(5) == [0.0, 0.2, 0.4, -0.4, -0.2]
@test rfftfreq(5) == [0.0, 0.2, 0.4]
@test fftshift(fftfreq(5)) == -0.4:0.2:0.4

# Sampling Frequency
@test fftfreq(5, 2) == [0.0, 0.4, 0.8, -0.8, -0.4]
# <:Number type compatibility
@test eltype(fftfreq(5, ComplexF64(2))) == ComplexF64

@test_throws ArgumentError Frequencies(12, 10, 1)

@testset "scaling" begin
@test fftfreq(4, 1) * 2 === fftfreq(4, 2)
@test fftfreq(4, 1) .* 2 === fftfreq(4, 2)
@test 2 * fftfreq(4, 1) === fftfreq(4, 2)
@test 2 .* fftfreq(4, 1) === fftfreq(4, 2)

@test fftfreq(4, 1) / 2 === fftfreq(4, 1/2)
@test fftfreq(4, 1) ./ 2 === fftfreq(4, 1/2)

@test 2 \ fftfreq(4, 1) === fftfreq(4, 1/2)
@test 2 .\ fftfreq(4, 1) === fftfreq(4, 1/2)
end

@testset "extrema" begin
function check_extrema(freqs)
for f in [minimum, maximum, extrema]
@test f(freqs) == f(collect(freqs)) == f(fftshift(freqs))
end
end
for f in (fftfreq, rfftfreq), n in (8, 9), multiplier in (2, 1/3, -1/7, 1.0*Unitful.mm)
freqs = f(n, multiplier)
check_extrema(freqs)
end
end
end

@testset "normalization" begin
# normalization should be inferable even if region is only inferred as ::Any,
# need to wrap in another function to test this (note that p.region::Any for
# p::TestPlan)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
for dims in ((), 1, 2, (1,2), 1:2)
any(d > ndims(x) for d in dims) && continue

# type inference checks of `rrule` fail on old Julia versions
# for higher-dimensional arrays:
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016
check_inferred = ndims(x) < 3 || VERSION >= v"1.6"

test_frule(AbstractFFTs.fftshift, x, dims)
test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred)

test_frule(AbstractFFTs.ifftshift, x, dims)
test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred)
end
end
end

@testset "fft" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
complex_x = complex.(x)
for dims in unique((1, 1:N, N))
for f in (fft, ifft, bfft)
test_frule(f, x, dims)
test_rrule(f, x, dims)
test_frule(f, complex_x, dims)
test_rrule(f, complex_x, dims)
end

test_frule(rfft, x, dims)
test_rrule(rfft, x, dims)

for f in (irfft, brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(f, x, d, dims)
test_rrule(f, x, d, dims)
test_frule(f, complex_x, d, dims)
test_rrule(f, complex_x, d, dims)
end
end
end
end
end
end
Loading