Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to test suite to support CUDA arrays #118

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 37 additions & 31 deletions ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ const TEST_CASES = (
)


function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray;
inplace_plan=false, copy_input=false, test_wrappers=true)
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
_copy = copy_input ? copy : identity
@test size(P) == size(x)
if !inplace_plan
Expand All @@ -61,7 +62,9 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
_x_out = similar(P * _copy(x))
@test mul!(_x_out, P, _copy(x)) ≈ x_transformed
@test _x_out ≈ x_transformed
@test P * view(_copy(x), axes(x)...) ≈ x_transformed # test view input
if test_wrappers
@test P * view(_copy(x), axes(x)...) ≈ x_transformed # test view input
end
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
else
_x = copy(x)
@test P * _copy(_x) ≈ x_transformed
Expand All @@ -71,9 +74,10 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
end
end

function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray;
real_plan=false, copy_input=false, test_wrappers=true)
_copy = copy_input ? copy : identity
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
y = map(a -> rand(typeof(a)), P * _copy(x)) # generically construct rand array
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
# test basic properties
@test eltype(P') === eltype(y)
@test (P')' === P # test adjoint of adjoint
Expand All @@ -87,11 +91,13 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
@test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x)
@test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y)
end
@test P' * view(_copy(y), axes(y)...) ≈ P' * _copy(y) # test view input (AbstractFFTs.jl#112)
if test_wrappers
@test P' * view(_copy(y), axes(y)...) ≈ P' * _copy(y) # test view input (AbstractFFTs.jl#112)
end
@test_throws MethodError mul!(x, P', y)
end

function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true, test_wrappers=true)
@testset "correctness of fft, bfft, ifft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
Expand All @@ -111,18 +117,18 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
for P in (plan_fft(similar(x_complexf), dims),
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_complexf, x_fft)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_complexf, x_fft; test_wrappers=test_wrappers)
if test_adjoint
@test fftdims(P') == fftdims(P)
TestUtils.test_plan_adjoint(P, x_complexf)
TestUtils.test_plan_adjoint(P, x_complexf, test_wrappers=test_wrappers)
end
end
if test_inplace
# test IIP plans
for P in (plan_fft!(similar(x_complexf), dims),
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true, test_wrappers=test_wrappers)
end
end

Expand All @@ -137,17 +143,17 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
for P in (plan_bfft(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_fft, x_scaled; test_wrappers=test_wrappers)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
TestUtils.test_plan_adjoint(P, x_fft, test_wrappers=test_wrappers)
end
end
# test IIP plans
for P in (plan_bfft!(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true, test_wrappers=test_wrappers)
end

# IFFT
Expand All @@ -161,33 +167,33 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
for P in (plan_ifft(similar(x_complexf), dims),
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_fft, x; test_wrappers=test_wrappers)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
TestUtils.test_plan_adjoint(P, x_fft; test_wrappers=test_wrappers)
end
end
# test IIP plans
if test_inplace
for P in (plan_ifft!(similar(x_complexf), dims),
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_fft, x; inplace_plan=true, test_wrappers=test_wrappers)
end
end
end
end
end

function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false, test_wrappers=true)
@testset "correctness of rfft, brfft, irfft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_real = float.(x) # for testing mutating real FFTs
x_fft = convert(ArrayType, _x_fft)
x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1)))
x_rfft = convert(ArrayType, collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))))

if !(eltype(x) <: Real)
continue
Expand All @@ -198,10 +204,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
for P in (plan_rfft(similar(x_real), dims),
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Real
@test fftdims(P) == dims
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input, test_wrappers=test_wrappers)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
end
end

Expand All @@ -210,10 +216,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
@test brfft(x_rfft, size(x, first(dims)), dims) ≈ x_scaled
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input, test_wrappers=test_wrappers)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
end
end

Expand All @@ -222,10 +228,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input, test_wrappers=test_wrappers)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
end
end
end
Expand Down
12 changes: 9 additions & 3 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TestUtils
import ..AbstractFFTs

"""
TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true, test_wrappers=true)

Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.
Expand All @@ -15,11 +15,12 @@ The backend implementation is assumed to be loaded prior to calling this functio
`convert(ArrayType, ...)`.
- `test_inplace=true`: whether to test in-place plans.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
- `test_wrappers=true`: whether to test any wrapper array inputs such as views.
"""
function test_complex_ffts end

"""
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false, test_wrappers=true)

Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.
Expand All @@ -32,18 +33,21 @@ The backend implementation is assumed to be loaded prior to calling this functio
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
- `test_wrappers=true`: whether to test any wrapper array inputs such as views.
"""
function test_real_ffts end

# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
"""
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
inplace_plan=false, copy_input=false)
inplace_plan=false, copy_input=false, test_wrappers=true)

Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.

Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
We also allow specifying `test_wrappers=false` to skip testing wrapper array inputs such as views, which may cause ambiguity
issues for some array types currently.
"""
function test_plan end

Expand All @@ -57,6 +61,8 @@ Real-to-complex and complex-to-real plans require a slightly modified dot test,
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
We also allow specifying `test_wrappers=false` to skip testing wrapper array inputs such as views, which may cause ambiguity
issues for some array types currently.
"""
function test_plan_adjoint end

Expand Down
Loading