Skip to content

Commit

Permalink
Rename ProjectionStyle's -> AdjointStyles and improve docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Jul 18, 2023
1 parent d53f57d commit bfd3133
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 27 deletions.
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.BRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
Expand Down
7 changes: 3 additions & 4 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref),
(which `AbstractFFTs` uses to implement reverse-mode differentiation rules). To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::AdjointPlan, ::AbstractArray, ::AS)`. `AbstractFFTs` pre-implements [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref),
[`AbstractFFTs.BRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
60 changes: 42 additions & 18 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,35 +583,57 @@ plan_brfft

##############################################################################

abstract type ProjectionStyle end
abstract type AdjointStyle end

"""
NoProjectionStyle()
FFTAdjointStyle()
Projection style for complex to complex discrete Fourier transform
Projection style for complex to complex discrete Fourier transforms.
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct NoProjectionStyle <: ProjectionStyle end
struct FFTAdjointStyle <: AdjointStyle end

"""
RealProjectionStyle()
RFFTAdjointStyle()
Projection style for complex to real discrete Fourier transform
Projection style for real to complex discrete Fourier transforms, for plans that
halve one of the output's dimensions analogously to [`rfft`](@ref).
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the output is projected
to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RealProjectionStyle <: ProjectionStyle end
struct RFFTAdjointStyle <: AdjointStyle end

"""
RealInverseProjectionStyle()
BRFFTAdjointStyle(d::Dim)
Projection style for inverse of complex to real discrete Fourier transform
Projection style for complex to real discrete Fourier transforms, for plans that
expect an input with a halved dimension analogously to [`irfft`](@ref), where `d`
is the original length of the dimension.
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the input is projected
to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct RealInverseProjectionStyle <: ProjectionStyle
struct BRFFTAdjointStyle <: AdjointStyle
dim::Int
end

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()
Projection style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

output_size(p::Plan) = _output_size(p, AdjointStyle(p))
_output_size(p::Plan, ::FFTAdjointStyle) = size(p)
_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::BRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -638,15 +660,15 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p.p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
halfdim = first(dims)
Expand All @@ -659,7 +681,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where
return p.p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
function adjoint_mul(p::AdjointPlan{T}, x::AbstractArray, ::BRFFTAdjointStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
halfdim = first(dims)
Expand All @@ -672,6 +694,8 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
end

adjoint_mul(p::AdjointPlan, x::AbstractArray, ::UnitaryAdjointStyle) = p.p \ x

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.BRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -241,7 +241,7 @@ end

Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down

0 comments on commit bfd3133

Please sign in to comment.