From d1e7507e985bfdaca2f08e535752b8c91fbb767e Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Mon, 2 Oct 2023 20:16:42 +0100 Subject: [PATCH 1/4] FFTW overrides via extension --- Project.toml | 8 +++ ext/ForwardDiffAbstractFFTsExt.jl | 39 ++++++++++++++ ext/ForwardDiffFFTWExt.jl | 15 ++++++ src/ForwardDiff.jl | 3 ++ src/complex.jl | 14 +++++ test/FFTTest.jl | 89 +++++++++++++++++++++++++++++++ test/runtests.jl | 5 ++ 7 files changed, 173 insertions(+) create mode 100644 ext/ForwardDiffAbstractFFTsExt.jl create mode 100644 ext/ForwardDiffFFTWExt.jl create mode 100644 src/complex.jl create mode 100644 test/FFTTest.jl diff --git a/Project.toml b/Project.toml index 1f12c7ba..b447ee8e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,9 +3,11 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.11-DEV" [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" @@ -16,17 +18,21 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] ForwardDiffStaticArraysExt = "StaticArrays" [compat] +AbstractFFTs = "1" Calculus = "0.5" CommonSubexpressions = "0.3" DiffResults = "1.1" DiffRules = "1.4" DiffTests = "0.1" +FFTW = "1" LogExpFunctions = "0.3" NaNMath = "1" Preferences = "1" @@ -35,8 +41,10 @@ StaticArrays = "1.5" julia = "1.6" [extras] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/ext/ForwardDiffAbstractFFTsExt.jl b/ext/ForwardDiffAbstractFFTsExt.jl new file mode 100644 index 00000000..dcccc835 --- /dev/null +++ b/ext/ForwardDiffAbstractFFTsExt.jl @@ -0,0 +1,39 @@ +module ForwardDiffAbstractFFTsExt + +using ForwardDiff, AbstractFFTs + +import AbstractFFTs: plan_fft, plan_ifft, plan_bfft, plan_rfft, plan_brfft, plan_irfft, Plan +using ForwardDiff: array2dual, dual2array +import LinearAlgebra: mul! + +for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities + @eval begin + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + end +end + +mul!(y::AbstractArray{<:Union{Dual,Complex{<:Dual}}}, p::Plan, x::AbstractArray{<:Union{Dual,Complex{<:Dual}}}) = copyto!(y, p*x) + +AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x) +AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im + +AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) +AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + +for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) + @eval begin + $plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) + $plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) + end +end + +for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? + @eval begin + $plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) + $plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims) + end +end + + +end \ No newline at end of file diff --git a/ext/ForwardDiffFFTWExt.jl b/ext/ForwardDiffFFTWExt.jl new file mode 100644 index 00000000..dd6bc806 --- /dev/null +++ b/ext/ForwardDiffFFTWExt.jl @@ -0,0 +1,15 @@ +module ForwardDiffFFTWExt + +using ForwardDiff, FFTW + +import FFTW: r2r, r2r!, plan_r2r + + +plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) +plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) + +r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x +r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x + + +end \ No newline at end of file diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index fdfcd560..6e842af8 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -21,9 +21,12 @@ include("derivative.jl") include("gradient.jl") include("jacobian.jl") include("hessian.jl") +include("complex.jl") if !isdefined(Base, :get_extension) include("../ext/ForwardDiffStaticArraysExt.jl") + include("../ext/ForwardDiffAbstractFFTsExt.jl") + include("../ext/ForwardDiffFFTWExt.jl") end export DiffResults diff --git a/src/complex.jl b/src/complex.jl new file mode 100644 index 00000000..0fabbc8d --- /dev/null +++ b/src/complex.jl @@ -0,0 +1,14 @@ +@inline tagtype(::Complex{T}) where T = tagtype(T) +@inline tagtype(::Type{Complex{T}}) where T = tagtype(T) + +dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) +dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) +array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) +array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) + +value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value) + +partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n)) + +npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N +npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N diff --git a/test/FFTTest.jl b/test/FFTTest.jl new file mode 100644 index 00000000..441b7577 --- /dev/null +++ b/test/FFTTest.jl @@ -0,0 +1,89 @@ +module FFTTest + +using FastTransformsForwardDiff, FFTW, LinearAlgebra, Test +using ForwardDiff: Dual, valtype, value, partials, derivative +using AbstractFFTs: complexfloat, realfloat + +@testset "complex dual" begin + x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.) + @test value(x) == 1 + 4im + @test partials(x,1) == 2 + 5im + @test partials(x,2) == 3 + 6im +end + +@testset "fft and rfft" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + + @test value.(x1) == 1:4 + @test partials.(x1, 1) == 2:5 + @test partials.(x1, 2) == 3:6 + + @test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im + @test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + + @test fft(x1, 1)[1] isa Complex{<:Dual} + + @testset "$f" for f in (fft, ifft, rfft, bfft) + @test value.(f(x1)) == f(value.(x1)) + @test partials.(f(x1), 1) == f(partials.(x1, 1)) + @test partials.(f(x1), 2) == f(partials.(x1, 2)) + end + + @test ifft(fft(x1)) == x1 + @test irfft(rfft(x1), length(x1)) ≈ x1 + @test brfft(rfft(x1), length(x1)) ≈ 4x1 + + f = x -> real(fft([x; 0; 0])[1]) + @test derivative(f,0.1) ≈ 1 + + r = x -> real(rfft([x; 0; 0])[1]) + @test derivative(r,0.1) ≈ 1 + + + n = 100 + θ = range(0,2π; length=n+1)[1:end-1] + # emperical from Mathematical + @test derivative(ω -> fft(exp.(ω .* cos.(θ)))[1]/n, 1) ≈ 0.565159103992485 + + # c = x -> dct([x; 0; 0])[1] + # @test derivative(c,0.1) ≈ 1 + + @testset "matrix" begin + A = x1 * (1:10)' + @test value.(fft(A)) == fft(value.(A)) + @test partials.(fft(A), 1) == fft(partials.(A, 1)) + @test partials.(fft(A), 2) == fft(partials.(A, 2)) + + @test value.(fft(A, 1)) == fft(value.(A), 1) + @test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) + @test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) + + @test value.(fft(A, 2)) == fft(value.(A), 2) + @test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) + @test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) + end + + c1 = complex.(x1) + @test mul!(similar(c1), FFTW.plan_fft(x1), x1) == fft(x1) + @test mul!(similar(c1), FFTW.plan_fft(c1), c1) == fft(c1) +end + +@testset "r2r" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + t = FFTW.r2r(x1, FFTW.R2HC) + + @test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC) + + t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC) + @test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) + + f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] + @test derivative(f, 0.1) ≡ 1.0 + + @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) +end +end # module \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e440af65..a1fa0360 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,5 +51,10 @@ Random.seed!(SEED) t = @elapsed include("AllocationsTest.jl") println("##### done (took $t seconds).") end + @testset "FFT" begin + println("##### Testing fft...") + t = @elapsed include("FFTTest.jl") + println("##### done (took $t seconds).") + end println("##### Running all ForwardDiff tests took $(time() - t0) seconds.") end From f3e1d7f747653fd64949193a34f5969b6e967f13 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Mon, 2 Oct 2023 20:40:08 +0100 Subject: [PATCH 2/4] fix extension loading --- Project.toml | 4 +++- ext/ForwardDiffAbstractFFTsExt.jl | 2 +- ext/ForwardDiffFFTWExt.jl | 3 +-- test/ComplexTest.jl | 10 ++++++++++ test/FFTTest.jl | 9 +-------- test/runtests.jl | 5 +++++ 6 files changed, 21 insertions(+), 12 deletions(-) create mode 100644 test/ComplexTest.jl diff --git a/Project.toml b/Project.toml index b447ee8e..4ef9bd3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.11-DEV" +version = "0.11.0-DEV" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -23,6 +23,8 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] +ForwardDiffAbstractFFTsExt = "AbstractFFTs" +ForwardDiffFFTWExt = "FFTW" ForwardDiffStaticArraysExt = "StaticArrays" [compat] diff --git a/ext/ForwardDiffAbstractFFTsExt.jl b/ext/ForwardDiffAbstractFFTsExt.jl index dcccc835..da0e035a 100644 --- a/ext/ForwardDiffAbstractFFTsExt.jl +++ b/ext/ForwardDiffAbstractFFTsExt.jl @@ -3,7 +3,7 @@ module ForwardDiffAbstractFFTsExt using ForwardDiff, AbstractFFTs import AbstractFFTs: plan_fft, plan_ifft, plan_bfft, plan_rfft, plan_brfft, plan_irfft, Plan -using ForwardDiff: array2dual, dual2array +using ForwardDiff: array2dual, dual2array, Dual import LinearAlgebra: mul! for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities diff --git a/ext/ForwardDiffFFTWExt.jl b/ext/ForwardDiffFFTWExt.jl index dd6bc806..6e004962 100644 --- a/ext/ForwardDiffFFTWExt.jl +++ b/ext/ForwardDiffFFTWExt.jl @@ -1,7 +1,6 @@ module ForwardDiffFFTWExt -using ForwardDiff, FFTW - +using ForwardDiff: Dual, dual2array import FFTW: r2r, r2r!, plan_r2r diff --git a/test/ComplexTest.jl b/test/ComplexTest.jl new file mode 100644 index 00000000..635d94d2 --- /dev/null +++ b/test/ComplexTest.jl @@ -0,0 +1,10 @@ +module ComplexTest +using ForwardDiff, Test + +@testset "complex dual" begin + x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.) + @test value(x) == 1 + 4im + @test partials(x,1) == 2 + 5im + @test partials(x,2) == 3 + 6im +end +end \ No newline at end of file diff --git a/test/FFTTest.jl b/test/FFTTest.jl index 441b7577..09bd6528 100644 --- a/test/FFTTest.jl +++ b/test/FFTTest.jl @@ -1,16 +1,9 @@ module FFTTest -using FastTransformsForwardDiff, FFTW, LinearAlgebra, Test +using FFTW, LinearAlgebra, Test using ForwardDiff: Dual, valtype, value, partials, derivative using AbstractFFTs: complexfloat, realfloat -@testset "complex dual" begin - x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.) - @test value(x) == 1 + 4im - @test partials(x,1) == 2 + 5im - @test partials(x,2) == 3 + 6im -end - @testset "fft and rfft" begin x1 = Dual.(1:4.0, 2:5, 3:6) diff --git a/test/runtests.jl b/test/runtests.jl index a1fa0360..14a044e6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,11 @@ Random.seed!(SEED) t = @elapsed include("AllocationsTest.jl") println("##### done (took $t seconds).") end + @testset "Complex" begin + println("##### Testing complex...") + t = @elapsed include("ComplexTest.jl") + println("##### done (took $t seconds).") + end @testset "FFT" begin println("##### Testing fft...") t = @elapsed include("FFTTest.jl") From be18d45a095a712e2eb660df29ac36dda709aea3 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Mon, 2 Oct 2023 21:28:17 +0100 Subject: [PATCH 3/4] fix tests --- Project.toml | 2 +- test/ComplexTest.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4ef9bd3c..ac9b4937 100644 --- a/Project.toml +++ b/Project.toml @@ -53,4 +53,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"] +test = ["AbstractFFTs", "Calculus", "DiffTests", "FFTW", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"] diff --git a/test/ComplexTest.jl b/test/ComplexTest.jl index 635d94d2..0cd8eb57 100644 --- a/test/ComplexTest.jl +++ b/test/ComplexTest.jl @@ -1,5 +1,6 @@ module ComplexTest using ForwardDiff, Test +using ForwardDiff: Dual @testset "complex dual" begin x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.) From 1130fa47edfb9ae2dd2e0550054957943da62410 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Mon, 2 Oct 2023 21:41:32 +0100 Subject: [PATCH 4/4] Update ComplexTest.jl --- test/ComplexTest.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ComplexTest.jl b/test/ComplexTest.jl index 0cd8eb57..604c13a5 100644 --- a/test/ComplexTest.jl +++ b/test/ComplexTest.jl @@ -1,6 +1,6 @@ module ComplexTest using ForwardDiff, Test -using ForwardDiff: Dual +using ForwardDiff: Dual, partials, value @testset "complex dual" begin x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.)