Skip to content

Commit

Permalink
Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, and impro…
Browse files Browse the repository at this point in the history
…ve docs (#109)

* make ProjectionStyle abstract type so we can subtype in downstream packages. add a few lines of docs

* Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, improve docs

* Clarify normalization

* Clarify documentation, rename _output_size -> output_size

* Remove unnecessary def

* Remove confusing commas

* Tweak docstring wording

* Reposition and improve size/output_size docstrings

* Note that size needs to be implemented in docs

---------

Co-authored-by: Gaurav Arya <[email protected]>
  • Loading branch information
vpuri3 and gaurav-arya authored Jul 27, 2023
1 parent 1cc9ca0 commit 5c23f4b
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 42 deletions.
19 changes: 18 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Public Interface

## FFT and FFT planning functions

```@docs
AbstractFFTs.fft
AbstractFFTs.fft!
Expand All @@ -20,11 +22,26 @@ AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
AbstractFFTs.ifftshift!
AbstractFFTs.fftfreq
AbstractFFTs.rfftfreq
Base.size
```

## Adjoint functionality

The following API is supported by plans that support adjoint functionality.
It is also relevant to implementers of FFT plans that wish to support adjoints.
```@docs
Base.adjoint
AbstractFFTs.AdjointStyle
AbstractFFTs.output_size
AbstractFFTs.adjoint_mul
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.IRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
```
10 changes: 5 additions & 5 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ To define a new FFT implementation in your own module, you should
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)`
(which defaults to `p.region`), and the input size `size(x)` should be accessible via `size(p::MyPlan)`.

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.

Expand All @@ -32,10 +33,9 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref).
`AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref).

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
131 changes: 100 additions & 31 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ abstract type Plan{T} end

eltype(::Type{<:Plan{T}}) where {T} = T

# size(p) should return the size of the input array for p
size(p::Plan, d) = size(p)[d]
output_size(p::Plan, d) = output_size(p)[d]
"""
size(p::Plan, [dim])
Return the size of the input of a plan `p`, optionally at a specified dimenion `dim`.
"""
size(p::Plan, dim) = size(p)[dim]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

Expand Down Expand Up @@ -583,17 +586,73 @@ plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
"""
AbstractFFTs.AdjointStyle(::Plan)
Return the adjoint style of a plan, enabling automatic computation of adjoint plans via
[`Base.adjoint`](@ref). Instructions for supporting adjoint styles are provided in the
[implementation instructions](implementations.md#Defining-a-new-implementation).
"""
abstract type AdjointStyle end

"""
FFTAdjointStyle()
Adjoint style for complex to complex discrete Fourier transforms that normalize
the output analogously to [`fft`](@ref).
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct FFTAdjointStyle <: AdjointStyle end

"""
RFFTAdjointStyle()
Adjoint style for real to complex discrete Fourier transforms that halve one of
the output's dimensions and normalize the output analogously to [`rfft`](@ref).
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with appropriate scaling and additional logic to handle the fact that the
output is projected to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RFFTAdjointStyle <: AdjointStyle end

"""
IRFFTAdjointStyle(d::Dim)
Adjoint style for complex to real discrete Fourier transforms that expect an input
with a halved dimension and normalize the output analogously to [`irfft`](@ref),
where `d` is the original length of the dimension.
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with appropriate scaling and additional logic to handle the fact that the
input is projected to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct IRFFTAdjointStyle <: AdjointStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()
Adjoint style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

"""
output_size(p::Plan, [dim])
Return the size of the output of a plan `p`, optionally at a specified dimension `dim`.
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`.
"""
output_size(p::Plan) = output_size(p, AdjointStyle(p))
output_size(p::Plan, dim) = output_size(p)[dim]
output_size(p::Plan, ::FFTAdjointStyle) = size(p)
output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -604,9 +663,7 @@ end
(p::Plan)'
adjoint(p::Plan)
Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
the original plan. Note that this differs from the corresponding backwards plan in the case of real
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).
Return a plan that performs the adjoint operation of the original plan.
!!! note
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
Expand All @@ -620,40 +677,52 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)

"""
adjoint_mul(p::Plan, x::AbstractArray)
Multiply an array `x` by the adjoint of a plan `p`. This is equivalent to `p' * x`.
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define
`adjoint_mul(::Plan, ::AbstractArray, ::AS)`.
"""
adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
return p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -241,7 +241,7 @@ end

Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down

0 comments on commit 5c23f4b

Please sign in to comment.