Skip to content

Commit

Permalink
feat(model): AsConvolution wrapper
Browse files Browse the repository at this point in the history
This wrapper can turn any AbstractSpectralModel into a convolutional
model, and will remove additive normalisations if required.

Convolution function is a modification of the Toeplitz matrix method,
which performs the matrix multiplications without allocating the matrix.
The benefit this has over, e.g., the FFT(W) method of taking
convolutions, in that we can propagate gradients through the operation.
  • Loading branch information
fjebaker committed Jun 29, 2024
1 parent e19a8e6 commit dc8d7bd
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ include("meta-models/wrappers.jl")
include("meta-models/table-models.jl")
include("meta-models/surrogate-models.jl")
include("meta-models/caching.jl")
include("meta-models/functions.jl")

include("poisson.jl")

Expand Down
18 changes: 13 additions & 5 deletions src/meta-models/caching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,29 @@ function AutoCache(model::AbstractSpectralModel{T,K}; abstol = 1e-3) where {T,K}
AutoCache(model, cache, abstol)
end

function _reinterpret_dual(::Type, v::AbstractArray, n::Int)
function _reinterpret_dual(
M::Type{<:AbstractSpectralModel},
::Type,
v::AbstractArray,
n::Int,
)
needs_resize = n > length(v)
if needs_resize
@warn "AutoCache: Growing dual buffer..."
@warn "$(Base.typename(M).name): Growing dual buffer..."
resize!(v, n)
end
view(v, 1:n), needs_resize
end
function _reinterpret_dual(
M::Type{<:AbstractSpectralModel},
DualType::Type{<:ForwardDiff.Dual},
v::AbstractArray{T},
n::Int,
) where {T}
n_elems = div(sizeof(DualType), sizeof(T)) * n
needs_resize = n_elems > length(v)
if needs_resize
@warn "AutoCache: Growing dual buffer..."
@warn "$(Base.typename(M).name): Growing dual buffer..."
resize!(v, n_elems)
end
reinterpret(DualType, view(v, 1:n_elems)), needs_resize
Expand All @@ -58,8 +64,10 @@ function invoke!(output, domain, model::AutoCache{M,T,K}) where {M,T,K}
_new_params = parameter_tuple(model.model)
_new_limits = (first(domain), last(domain))

output_cache, out_resized = _reinterpret_dual(D, model.cache.cache, length(output))
param_cache, _ = _reinterpret_dual(D, model.cache.params, length(_new_params))
output_cache, out_resized =
_reinterpret_dual(typeof(model), D, model.cache.cache, length(output))
param_cache, _ =
_reinterpret_dual(typeof(model), D, model.cache.params, length(_new_params))

same_domain = model.cache.domain_limits == _new_limits

Expand Down
77 changes: 77 additions & 0 deletions src/meta-models/functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
struct AsConvolution{M,T,V,P} <: AbstractModelWrapper{M,T,Convolutional}
model::M
# the domain on which we evaluate this model
domain::V
# an additional output cache
cache::NTuple{2,Vector{P}}
function AsConvolution(
model::AbstractSpectralModel{T},
domain::V,
cache::NTuple{2,Vector{P}},
) where {T,V,P}
new{typeof(model),T,V,P}(model, domain, cache)
end
end

function AsConvolution(
model::AbstractSpectralModel{T};
domain = collect(range(0, 2, 100)),
) where {T}
output = invokemodel(domain, model)
AsConvolution(model, domain, (output, deepcopy(output)))
end

function invoke!(output, domain, model::AsConvolution{M,T}) where {M,T}
D = promote_type(eltype(domain), T)
model_output, _ =
_reinterpret_dual(typeof(model), D, model.cache[1], length(model.domain) - 1)
convolution_cache, _ = _reinterpret_dual(
typeof(model),
D,
model.cache[2],
length(output) + length(model_output) - 1,
)

# invoke the child model
invoke!(model_output, model.domain, model.model)

# do the convolution
convolve!(convolution_cache, output, model_output)

# overwrite the output
shift = div(length(model_output), 2)
@views output .= convolution_cache[1+shift:length(output)+shift]
end

function Reflection.get_parameter_symbols(
::Type{<:AsConvolution{M}},
) where {M<:AbstractSpectralModel{T,K}} where {T,K}
syms = Reflection.get_parameter_symbols(M)
if K === Additive
# we need to lose the normalisation parameter
(syms[2:end]...,)
else
syms
end
end

function Reflection.make_constructor(
M::Type{<:AsConvolution{Model}},
closures::Vector,
params::Vector,
T::Type,
) where {Model<:AbstractSpectralModel{Q,K}} where {Q,K}
num_closures = fieldcount(M) - 1 # ignore the `model` field
my_closures = closures[1:num_closures]

model_params = if K === Additive
# insert a dummy normalisation to the constructor
vcat(:(one($T)), params)
else
params
end

model_constructor =
Reflection.make_constructor(Model, closures[num_closures+1:end], model_params, T)
:($(Base.typename(M).name)($(model_constructor), $(my_closures...)))
end
49 changes: 34 additions & 15 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,47 @@
function _convolve_1d_same_domain!(
output::Vector{T},
vec_A::Vector{T},
kernel::Vector{T},
) where {T<:Real}
function _convolve_implementation!(
output::AbstractVector{T},
vec_A::AbstractVector{T},
kernel::AbstractVector{T},
) where {T<:Number}
# Based on https://discourse.julialang.org/t/97658/15
@assert length(output) == length(vec_A)
@assert length(output) == length(kernel)
J = length(vec_A)
K = length(kernel)
@assert length(output) == J + K - 1 "Ouput is $(length(output)); should be $(J + K - 1)"

fill!(output, 0)

@turbo for i in eachindex(output)
# do the kernel's side first
for i = 1:K-1
total = zero(T)
for k = 1:K
ib = (i >= k)
oa = ib ? vec_A[i-k+1] : zero(T)
total += kernel[k] * oa
end
output[i] = total
end
# now the middle
for i = K:J-1
total = zero(T)
for k = 1:K
oa = vec_A[i-k+1]
total += kernel[k] * oa
end
output[i] = total
end
# and finally the end
for i = J:(J+K-1)
total = zero(T)
for k in eachindex(output)
ib0 = (i >= k)
oa = ib0 ? vec_A[i-k+1] : zero(T)
for k = 1:K
ib = (i < J + k)
oa = ib ? vec_A[i-k+1] : zero(T)
total += kernel[k] * oa
end
output[i] = total
end
output
end

convolve!(output, A, kernel) = _convolve_1d_same_domain!(output, A, kernel)
convolve!(output, A, kernel) = _convolve_implementation!(output, A, kernel)
function convolve(A, kernel)
output = similar(A)
output = zeros(eltype(A), length(A) + length(kernel) - 1)
convolve!(output, A, kernel)
end

0 comments on commit dc8d7bd

Please sign in to comment.