Skip to content

Commit

Permalink
Merge pull request #97 from fjebaker/fergus/simulate
Browse files Browse the repository at this point in the history
Simulation
  • Loading branch information
fjebaker authored May 12, 2024
2 parents e1899a0 + 7937065 commit 1426cc7
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 11 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
4 changes: 4 additions & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import Pkg.MiniProgressBars: MiniProgressBar, start_progress, end_progress, show
import Distributions
import ConstructionBase

import Random

using FITSIO
using SparseArrays
using Surrogates
Expand Down Expand Up @@ -73,6 +75,8 @@ include("fitting/binding.jl")
include("fitting/multi-cache.jl")
include("fitting/methods.jl")
include("fitting/statistics.jl")

include("simulate.jl")
include("fitting/goodness.jl")

include("plotting-recipes.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/fitting/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct SpectralCache{M,O,T,K,P,TransformerType} <: AbstractFittingCache
calculated_objective::T
output_cache::K
parameter_cache::P
transfomer!!::TransformerType
transformer!!::TransformerType
function SpectralCache(
layout::AbstractDataLayout,
model::M,
Expand Down Expand Up @@ -73,7 +73,7 @@ function _invoke_and_transform!(cache::SpectralCache, domain, params)
parameters = _get_parameters(cache.parameter_cache, params)

output = invokemodel!(model_output, domain, cache.model, parameters)
cache.transfomer!!(calc_obj, domain, output)
cache.transformer!!(calc_obj, domain, output)

output_vector = get_tmp(cache.output_cache, params)
output_vector .= calc_obj
Expand Down
12 changes: 9 additions & 3 deletions src/fitting/goodness.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

function goodness(
result::AbstractFittingResult,
u::AbstractVector{T},
Expand All @@ -7,7 +6,9 @@ function goodness(
stat = ChiSquared(),
distribution = Distributions.Normal,
refit = true,
kwargs...,
) where {T}
x = similar(u)
measures = zeros(T, N)
config = deepcopy(result.config)
for i in eachindex(measures)
Expand All @@ -21,13 +22,18 @@ function goodness(
get_lowerlimit(config.parameters[i]),
get_upperlimit(config.parameters[i]),
)
set_value!(config.parameters[i], rand(distr))
x[i] = rand(distr)
end

simulate!(config, x; kwargs...)

if refit
new_result = fit(config, LevenbergMarquadt())
measures[i] = measure(stat, new_result)
else
measures[i] =
measure(stat, config.objective, invoke_result(result, x), config.variance)
end
measures[i] = measure(stat, new_result)
end

perc = 100 * count(<(result.χ2), measures) / N
Expand Down
4 changes: 2 additions & 2 deletions src/fitting/multi-cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ function finalize(
unc
end
MultiFittingResult(
getindex.(results, :chi2),
getindex.(results, :p),
getindex.(results, :chi2),
getindex.(results, :p),
unc_or_nothing,
config,
)
Expand Down
8 changes: 4 additions & 4 deletions src/plotting-recipes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using RecipesBase

plotting_domain(dataset::AbstractDataset) = spectrum_energy(dataset)
plotting_domain(dataset::InjectiveData) = dataset.domain

@recipe function _plotting_func(dataset::InjectiveData; data_layout = OneToOne())
seriestype --> :scatter
markersize --> 1.0
Expand Down Expand Up @@ -39,15 +42,12 @@ end
ylabel --> objective_units(dataset)
label --> make_label(dataset)
minorgrid --> true
x = spectrum_energy(dataset)
x = plotting_domain(dataset)

I = @. !isinf(x) && !isinf(rate)
@views (x[I], rate[I])
end

plotting_domain(dataset::AbstractDataset) = spectrum_energy(dataset)
plotting_domain(dataset::InjectiveData) = dataset.domain

@recipe function _plotting_func(dataset::AbstractDataset, result::FittingResult)
label --> "fit"
seriestype --> :stepmid
Expand Down
72 changes: 72 additions & 0 deletions src/simulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
mutable struct SimulatedSpectrum{T,F} <: AbstractDataset
domain::Vector{T}
data::Vector{T}
errors::Vector{T}
units::Union{Nothing,SpectralUnits.RateOrCount}
transformer!!::F
seed::Int
end

supports_contiguosly_binned(::Type{<:SimulatedSpectrum}) = true

function make_objective(::ContiguouslyBinned, dataset::SimulatedSpectrum)
check_units_warning(dataset.units)
dataset.data
end

function make_objective_variance(::ContiguouslyBinned, dataset::SimulatedSpectrum)
check_units_warning(dataset.units)
dataset.errors .^ 2
end

function make_model_domain(::ContiguouslyBinned, dataset::SimulatedSpectrum)
dataset.domain
end

bin_widths(dataset::SimulatedSpectrum) = diff(dataset.domain)
plotting_domain(dataset::SimulatedSpectrum) = dataset.domain[1:end-1] .+ bin_widths(dataset)
objective_units(dataset::SimulatedSpectrum) = dataset.units

function _printinfo(io::IO, spectrum::SimulatedSpectrum)
dmin, dmax = prettyfloat.(extrema(spectrum.data))
descr = """SimulatedSpectrum:
Units : $(spectrum.units)
. Data (min/max) : ($dmin, $dmax)
"""
print(io, descr)
end

function simulate!(
config::FittingConfig,
p;
simulate_distribution = Distributions.Normal,
rng = Random.default_rng(),
)
config.objective .= _invoke_and_transform!(config.cache, config.domain, p)
for (i, m) in enumerate(config.objective)
distr = simulate_distribution(m, sqrt(config.variance[i]))
config.objective[i] = rand(rng, distr)
end
end

function simulate(prob::FittingProblem; seed = abs(randint()), kwargs...)
kw, conf = _unpack_fitting_configuration(prob; kwargs...)
rng = Random.default_rng(seed)
Random.seed!(rng, seed)
simulate!(conf, get_value.(conf.parameters); rng = rng, kw...)
SimulatedSpectrum(
conf.domain,
conf.objective,
sqrt.(conf.variance),
nothing,
conf.cache.transformer!!,
seed,
)
end

function simulate(model::AbstractSpectralModel, dataset::AbstractDataset; kwargs...)
simulate(FittingProblem(model => dataset); kwargs...)
end


export simulate
9 changes: 9 additions & 0 deletions test/models/test-simulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Test
using SpectralFitting

include("../dummies.jl")

dummy_data = make_dummy_dataset((E) -> (E^(-3.0)); units = u"counts / (s * keV)")
model = PowerLaw()

sim = simulate(model, dummy_data; seed = 42)

0 comments on commit 1426cc7

Please sign in to comment.