Skip to content

Commit

Permalink
Merge branch 'master' into dct
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 authored Sep 26, 2023
2 parents 608aa02 + ef8fc5b commit 26df888
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FFTWChainRulesCoreExt = "ChainRulesCore"

[compat]
AbstractFFTs = "1.0"
ChainRulesCore = "1"
AbstractFFTs = "1.5"
FFTW_jll = "3.3.9"
MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023"
Preferences = "1.2"
Expand Down
2 changes: 2 additions & 0 deletions src/dct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,5 @@ end
mul!(Array{T}(undef, p.plan.osz), p, copy(x)) # need copy to preserve input

*(p::DCTPlan{T,K,true}, x::StridedArray{T}) where {T,K} = mul!(x, p, x)

AbstractFFTs.AdjointStyle(::DCTPlan) = AbstractFFTs.UnitaryAdjointStyle()
6 changes: 6 additions & 0 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1049,3 +1049,9 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K}
unsafe_execute!(p, x, x)
return x
end

#######################################################################

AbstractFFTs.AdjointStyle(::cFFTWPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::rFFTWPlan{T, FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::rFFTWPlan{T, BACKWARD}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)])
58 changes: 19 additions & 39 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,46 +578,26 @@ end
end
end

@testset "ChainRules" begin

if isdefined(Base, :get_extension)
CRCEXT = Base.get_extension(FFTW, :FFTWChainRulesCoreExt)
@test isnothing(CRCEXT)
end

using ChainRulesTestUtils

if isdefined(Base, :get_extension)
CRCEXT = Base.get_extension(FFTW, :FFTWChainRulesCoreExt)
@test !isnothing(CRCEXT)
@testset "DCT adjoints" begin
# only test on FFTW because MKL is missing functionality
if FFTW.get_provider() == "fftw"
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5))
y = randn(size(x))
N = ndims(x)
for dims in unique((1, 1:N, N))
for P in (plan_dct(x, dims), plan_idct(x, dims))
AbstractFFTs.TestUtils.test_plan_adjoint(P, x)
end
end
end
end
end

@testset "DCT" begin
for f in (dct, idct)
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
test_frule(f, x)
test_rrule(f, x)

N = ndims(x)
for region in unique((1, 1:N, N))
test_frule(f, x, region)
test_rrule(f, x, region)
end # for region
end # for x
end # for f
end

@testset "r2r" begin
for k in 0:10
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
test_frule(r2r, x, k)

N = ndims(x)
for region in unique((1, 1:N, N))
test_frule(r2r, x, k, region)
end # for region
end # for x
end # for f
@testset "AbstractFFTs FFT backend tests" begin
# note this also tests adjoint functionality for FFT plans
# only test on FFTW because MKL is missing functionality
if FFTW.get_provider() == "fftw"
AbstractFFTs.TestUtils.test_complex_ffts(Array)
AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true)
end

end

0 comments on commit 26df888

Please sign in to comment.