Skip to content

Commit

Permalink
feat: convolve on irregular grids
Browse files Browse the repository at this point in the history
  • Loading branch information
fjebaker committed Jul 1, 2024
1 parent 88f80aa commit 7648aaa
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 16 deletions.
13 changes: 4 additions & 9 deletions src/meta-models/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,17 @@ function invoke!(output, domain, model::AsConvolution{M,T}) where {M,T}
D = promote_type(eltype(domain), T)
model_output, _ =
_reinterpret_dual(typeof(model), D, model.cache[1], length(model.domain) - 1)
convolution_cache, _ = _reinterpret_dual(
typeof(model),
D,
model.cache[2],
length(output) + length(model_output) - 1,
)
convolution_cache, _ =
_reinterpret_dual(typeof(model), D, model.cache[2], length(output))

# invoke the child model
invoke!(model_output, model.domain, model.model)

# do the convolution
convolve!(convolution_cache, output, model_output)
convolve!(convolution_cache, output, domain, model_output, model.domain)

# overwrite the output
shift = div(length(model_output), 2)
@views output .= convolution_cache[1+shift:length(output)+shift]
@views output .= convolution_cache
end

function Reflection.get_parameter_symbols(
Expand Down
46 changes: 44 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,50 @@ function _convolve_implementation!(
output
end

convolve!(output, A, kernel) = _convolve_implementation!(output, A, kernel)
function convolve!(output, A, kernel)
if length(kernel) <= length(A)
_convolve_implementation!(output, A, kernel)
else
_convolve_implementation!(output, kernel, A)
end
end
function convolve(A, kernel)
output = zeros(eltype(A), length(A) + length(kernel) - 1)
convolve!(output, A, kernel)
convolve!(output, A, A_domain, kernel, kernel_domain)
end

"""
Assumes A is binned in X1 and kernel is binned in X2. Output will also be binned on X1
"""
function _convolve_irregular_grid!(output, A, X1, kernel, X2)
@assert length(X1) == length(A) + 1
@assert length(X2) == length(kernel) + 1

function _kernel_func(x)
i1 = searchsortedfirst(X2, x)
if i1 >= length(X2) || i1 <= 1
return zero(x)
end
w = (x - X2[i1-1]) / (X2[i1] - X2[i1-1])
w * kernel[i1] + (1 - w) * kernel[i1-1]
end

fill!(output, 0)
for i in eachindex(output)
for j in eachindex(output)
x = X1[i] / X1[j]
k = A[j] * _kernel_func(x)
if k > 0
output[i] += k
end
end
end
end

function convolve!(output, A, A_domain, kernel, kernel_domain)
_convolve_irregular_grid!(output, A, A_domain, kernel, kernel_domain)
end
function convolve(A, A_domain, kernel, kernel_domain)
output = zeros(eltype(A), length(A) + length(kernel) - 1)
convolve!(output, A, A_domain, kernel, kernel_domain)
end
12 changes: 7 additions & 5 deletions test/models/test-as-convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ domain = collect(range(0.0, 10.0, 150))

output = invokemodel(domain, model)

@test sum(output) β‰ˆ 3.2570820013702395 atol = 1e-4
@test output[10] β‰ˆ 0.0036345342427057687 atol = 1e-4
@test output[40] β‰ˆ 0.055218163108951814 atol = 1e-4
@test sum(output) β‰ˆ 4.050608829695485 atol = 1e-4
@test output[10] β‰ˆ 0.0022170544439135222 atol = 1e-4
@test output[40] β‰ˆ 0.058630601782812125 atol = 1e-4

# simulate a model spectrum
dummy_data = make_dummy_dataset((E) -> (E^(-3.0)); units = u"counts / (s * keV)")
Expand All @@ -38,7 +38,8 @@ begin
prob = FittingProblem(model => sim)
result = fit(prob, LevenbergMarquadt())
end
@test result.Ο‡2 β‰ˆ 76.15221077389369 atol = 1e-3
@test result.Ο‡2 β‰ˆ 76.71272868245076 atol = 1e-3
@test result.u[1] β‰ˆ 0.3 atol = 1e-2

# put a couple of delta emission lines together
lines =
Expand Down Expand Up @@ -66,4 +67,5 @@ begin
prob = FittingProblem(model => sim)
result = fit(prob, LevenbergMarquadt(); verbose = true)
end
@test result.Ο‡2 β‰ˆ 75.736 atol = 1e-3
@test result.Ο‡2 β‰ˆ 76.66970981760741 atol = 1e-3
@test result.u[1] β‰ˆ 3.0 atol = 1e-2

0 comments on commit 7648aaa

Please sign in to comment.