diff --git a/docs/src/api.md b/docs/src/api.md index bb3b849..c95525b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft AbstractFFTs.fftdims Base.adjoint +AbstractFFTs.FFTAdjointStyle +AbstractFFTs.RFFTAdjointStyle +AbstractFFTs.BRFFTAdjointStyle +AbstractFFTs.UnitaryAdjointStyle AbstractFFTs.fftshift AbstractFFTs.fftshift! AbstractFFTs.ifftshift diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 7367fd4..c22eccc 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -32,10 +32,9 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. -* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return: - * `AbstractFFTs.NoProjectionStyle()`, - * `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref), - * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. +* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref), +(which `AbstractFFTs` uses to implement reverse-mode differentiation rules). To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::AdjointPlan, ::AbstractArray, ::AS)`. `AbstractFFTs` pre-implements [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), +[`AbstractFFTs.BRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref). The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. diff --git a/src/definitions.jl b/src/definitions.jl index 604329f..b7996f8 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -583,35 +583,57 @@ plan_brfft ############################################################################## -abstract type ProjectionStyle end +abstract type AdjointStyle end """ - NoProjectionStyle() + FFTAdjointStyle() -Projection style for complex to complex discrete Fourier transform +Projection style for complex to complex discrete Fourier transforms. + +Since the Fourier transform is unitary up to a scaling, the adjoint simply applies +the transform's inverse with an appropriate scaling. """ -struct NoProjectionStyle <: ProjectionStyle end +struct FFTAdjointStyle <: AdjointStyle end """ - RealProjectionStyle() + RFFTAdjointStyle() -Projection style for complex to real discrete Fourier transform +Projection style for real to complex discrete Fourier transforms, for plans that +halve one of the output's dimensions analogously to [`rfft`](@ref). + +Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's +inverse, but with additional logic to handle the fact that the output is projected +to exploit its conjugate symmetry (see [`rfft`](@ref)). """ -struct RealProjectionStyle <: ProjectionStyle end +struct RFFTAdjointStyle <: AdjointStyle end """ - RealInverseProjectionStyle() + BRFFTAdjointStyle(d::Dim) -Projection style for inverse of complex to real discrete Fourier transform +Projection style for complex to real discrete Fourier transforms, for plans that +expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` +is the original length of the dimension. + +Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's +inverse, but with additional logic to handle the fact that the input is projected +to exploit its conjugate symmetry (see [`irfft`](@ref)). """ -struct RealInverseProjectionStyle <: ProjectionStyle +struct BRFFTAdjointStyle <: AdjointStyle dim::Int end -output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) -_output_size(p::Plan, ::NoProjectionStyle) = size(p) -_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) -_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) +""" + UnitaryAdjointStyle() + +Projection style for unitary transforms, whose adjoint equals their inverse. +""" +struct UnitaryAdjointStyle <: AdjointStyle end + +output_size(p::Plan) = _output_size(p, AdjointStyle(p)) +_output_size(p::Plan, ::FFTAdjointStyle) = size(p) +_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p)) +_output_size(p::Plan, s::BRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) +_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p) struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P @@ -638,15 +660,15 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) size(p::AdjointPlan) = output_size(p.p) output_size(p::AdjointPlan) = size(p.p) -Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) +Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p.p)) -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} +function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} dims = fftdims(p.p) N = normalization(T, size(p.p), dims) return (p.p \ x) / N end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real} +function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} dims = fftdims(p.p) N = normalization(T, size(p.p), dims) halfdim = first(dims) @@ -659,7 +681,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where return p.p \ (x ./ convert(typeof(x), scale)) end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} +function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::BRFFTAdjointStyle) where {T} dims = fftdims(p.p) N = normalization(real(T), output_size(p.p), dims) halfdim = first(dims) @@ -672,6 +694,8 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) return (convert(typeof(x), scale) ./ N) .* (p.p \ x) end +adjoint_mul(p::AdjointPlan, x::AbstractArray, ::UnitaryAdjointStyle) = p.p \ x + # Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) inv(p::AdjointPlan) = adjoint(inv(p.p)) diff --git a/test/testplans.jl b/test/testplans.jl index 09b3f67..623a550 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N Base.size(p::InverseTestPlan) = p.sz Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N -AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle() -AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle() +AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle() function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T} return TestPlan{T}(region, size(x)) @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}} end end -AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() -AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d) +AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.BRFFTAdjointStyle(p.d) function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real} return TestRPlan{T}(region, size(x)) @@ -241,7 +241,7 @@ end Base.size(p::InplaceTestPlan) = size(p.plan) Base.ndims(p::InplaceTestPlan) = ndims(p.plan) -AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan) +AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan) function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) return InplaceTestPlan(plan_fft(x, region; kwargs...))