diff --git a/Project.toml b/Project.toml index b648109..9277108 100644 --- a/Project.toml +++ b/Project.toml @@ -10,9 +10,16 @@ MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +FFTWForwardDiffExt = "ForwardDiff" + [compat] -AbstractFFTs = "1.5" +AbstractFFTs = "1.6" FFTW_jll = "3.3.9" +ForwardDiff = "0.10" LinearAlgebra = "<0.0.1, 1" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023, 2024" Preferences = "1.2" diff --git a/ext/FFTWForwardDiffExt.jl b/ext/FFTWForwardDiffExt.jl new file mode 100644 index 0000000..4e57fb3 --- /dev/null +++ b/ext/FFTWForwardDiffExt.jl @@ -0,0 +1,40 @@ +module FFTWForwardDiffExt +using FFTW +using ForwardDiff +import FFTW: plan_r2r, plan_r2r!, plan_dct, plan_dct!, plan_idct, plan_idct!, r2r, r2r!, dct, dct!, idct, idct!, fftwReal, REDFT10, REDFT01 +import FFTW.AbstractFFTs: dualplan, dual2array +import ForwardDiff: Dual + + +for plan in (:plan_r2r, :plan_r2r!) + @eval begin + $plan(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, $plan(dual2array(x), FLAG, 1 .+ dims)) + $plan(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, $plan(dual2array(x), FLAG, 1 .+ dims)) + end +end + +for f in (:r2r, :r2r!) + pf = Symbol("plan_", f) + @eval begin + $f(x::AbstractArray{<:Dual}, kinds, region...) = $pf(x, kinds, region...) * x + $f(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = $pf(x, kinds, region...) * x + end +end + + +for f in (:dct, :dct!, :idct, :idct!) + pf = Symbol("plan_", f) + @eval begin + $f(x::AbstractArray{<:Dual}) = $pf(x) * x + $f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x + end +end + +for plan in (:plan_dct, :plan_dct!, :plan_idct, :plan_idct!) + @eval begin + $plan(x::AbstractArray{D}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, $plan(dual2array(x), 1 .+ dims; kwds...)) + $plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, $plan(dual2array(x), 1 .+ dims; kwds...)) + end +end + +end #module \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 895c586..f0b0722 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,9 +5,11 @@ [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" +ForwardDiff = "0.10" Test = "<0.0.1, 1" diff --git a/test/fftwforwarddiff.jl b/test/fftwforwarddiff.jl new file mode 100644 index 0000000..5f14906 --- /dev/null +++ b/test/fftwforwarddiff.jl @@ -0,0 +1,38 @@ +using FFTW, ForwardDiff, Test +using ForwardDiff: Dual, value, partials + +@testset "ForwardDiff extension" begin + @testset "r2r" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + t = FFTW.r2r(x1, FFTW.R2HC) + + @test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC) + + t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC) + @test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) + + f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] + @test ForwardDiff.derivative(f, 0.1) ≡ 1.0 + + @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) + + x = [Dual(1.0,2,3), Dual(4,5,6)] + a = FFTW.r2r(x, FFTW.REDFT00) + b = FFTW.r2r!(x, FFTW.REDFT00) + @test a == b == x + end + + @testset "dct" begin + x = [Dual(1.0,2,3), Dual(4,5,6)] + a = dct(x) + b = dct!(x) + @test a == b == x + + c = x -> dct([x; 0; 0])[1] + @test ForwardDiff.derivative(c,0.1) ≈ 1/sqrt(3) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6b158ac..1dda166 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -606,3 +606,5 @@ end AbstractFFTs.TestUtils.test_real_ffts(Array; copy_input=true) end end + +include("fftwforwarddiff.jl") \ No newline at end of file