Skip to content

Commit

Permalink
Merge pull request #96 from fjebaker/fergus/goodness
Browse files Browse the repository at this point in the history
Goodness
  • Loading branch information
fjebaker authored May 12, 2024
2 parents 059deb1 + 31916e1 commit e1899a0
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ include("fitting/binding.jl")
include("fitting/multi-cache.jl")
include("fitting/methods.jl")
include("fitting/statistics.jl")
include("fitting/goodness.jl")

include("plotting-recipes.jl")

Expand Down
5 changes: 4 additions & 1 deletion src/datasets/ogip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ function read_rmf(path::String; T::Type = Float64)
_build_reponse_matrix(header, rmf, channels, T)
end

function find_extension(fits, extension::T) where {T <: Union{<:AbstractString, <:AbstractVector}}
function find_extension(
fits,
extension::T,
) where {T<:Union{<:AbstractString,<:AbstractVector}}
# find the correct extensions
i::Int = 1
for hdu in fits
Expand Down
40 changes: 34 additions & 6 deletions src/datasets/spectraldata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,20 @@ end

# constructor

SpectralData(paths::SpectralDataPaths; kwargs...) =
_dataset_from_ogip(paths; kwargs...)
SpectralData(paths::SpectralDataPaths; kwargs...) = _dataset_from_ogip(paths; kwargs...)

function SpectralData(
spectrum::Spectrum,
response::ResponseMatrix;
# try to match the domains of the response matrix to the data
match_domains = true,
background = missing,
ancillary = missing,
)
domain = _make_domain_vector(spectrum, response)
energy_low, energy_high = _make_energy_vector(spectrum, response)
data_mask = BitVector(fill(true, size(spectrum.data)))
SpectralData(
data = SpectralData(
spectrum,
response,
background,
Expand All @@ -79,6 +80,14 @@ function SpectralData(
domain,
data_mask,
)
if !check_domains(data)
if match_domains
match_domains!(data)
else
@warn "The spectrum and response domains are unmatched. Use `match_domains!` to remedy. Results will assume you know what you're doing"
end
end
return data
end

"""
    supports_contiguosly_binned(::Type{<:SpectralData})

Trait method: always `true` for `SpectralData` and its parametric variants.
"""
function supports_contiguosly_binned(::Type{<:SpectralData})
    return true
end
Expand Down Expand Up @@ -143,8 +152,8 @@ function objective_transformer(
_transformer!!
end

bin_widths(dataset::SpectralData) =
(dataset.energy_high.-dataset.energy_low)[dataset.data_mask]
unmasked_bin_widths(dataset::SpectralData) = dataset.energy_high .- dataset.energy_low
bin_widths(dataset::SpectralData) = unmasked_bin_widths(dataset)[dataset.data_mask]
"""
    has_background(dataset::SpectralData)

Return `true` when the `background` field of `dataset` is not `missing`.
"""
function has_background(dataset::SpectralData)
    return !ismissing(dataset.background)
end

"""
    has_ancillary(dataset::SpectralData)

Return `true` when the `ancillary` field of `dataset` is not `missing`.
"""
function has_ancillary(dataset::SpectralData)
    return !ismissing(dataset.ancillary)
end

Expand Down Expand Up @@ -305,10 +314,29 @@ function _make_energy_vector(spec::Spectrum, resp::ResponseMatrix{T}) where {T}
resp.channel_bins_high,
)
high = full_domain[2:end]
# full domain becomes the low
resize!(full_domain, length(high))
full_domain, high
end

"""
    check_domains(data::SpectralData)

Return `true` when the spectrum and the response of `data` list the same number
of channels and every spectrum channel also appears among the response
channels; return `false` otherwise.
"""
function check_domains(data::SpectralData)
    spectrum_channels = data.spectrum.channels
    response_channels = data.response.channels
    # a differing channel count already means the domains cannot match
    if length(spectrum_channels) != length(response_channels)
        return false
    end
    return all(in(response_channels), spectrum_channels)
end

"""
    match_domains!(data::SpectralData)

Replace `data.response` with a new `ResponseMatrix` that keeps only the entries
whose channel also appears in `data.spectrum.channels`. The matrix rows, the
channel list and the channel bin edges are all restricted to the same
selection; `bins_low` and `bins_high` are passed through unchanged.

Returns the newly constructed `ResponseMatrix`.
"""
function match_domains!(data::SpectralData)
    # drop parts of the response matrix that aren't in the spectrum.
    # `findall` yields the *positions* of the retained response channels, so the
    # indexing below stays valid whatever numbering the channels use (e.g.
    # starting at 0), rather than treating channel values as array indices.
    I = findall(in(data.spectrum.channels), data.response.channels)
    data.response = ResponseMatrix(
        data.response.matrix[I, :],
        data.response.channels[I],
        data.response.channel_bins_low[I],
        data.response.channel_bins_high[I],
        data.response.bins_low,
        data.response.bins_high,
    )
end

macro _forward_SpectralData_api(args)
if args.head !== :.
error("Bad syntax")
Expand Down Expand Up @@ -384,7 +412,7 @@ function _printinfo(io, data::SpectralData{T}) where {T}
Crayons.Crayon(reset = true),
" with ",
Crayons.Crayon(foreground = :cyan),
length(data.energy_low[data.data_mask]) - 1,
length(data.energy_low[data.data_mask]),
Crayons.Crayon(reset = true),
" active channels:",
)
Expand Down
9 changes: 7 additions & 2 deletions src/fitting/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,15 @@ function _f_objective(config::FittingConfig)
end
end

function finalize(config::FittingConfig, params; statistic = ChiSquared())
function finalize(
config::FittingConfig,
params;
statistic = ChiSquared(),
σparams = nothing,
)
y = _f_objective(config)(config.domain, params)
chi2 = measure(statistic, config.objective, y, config.variance)
FittingResult(chi2, params, config)
FittingResult(chi2, params, σparams, config)
end

supports_autodiff(config::FittingConfig{<:JuliaImplementation}) = true
Expand Down
44 changes: 44 additions & 0 deletions src/fitting/goodness.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

"""
    goodness(result, u, σ; N = 1000, stat = ChiSquared(), distribution = Distributions.Normal, refit = true)

Repeatedly draw parameter values around `u` with spread `σ` and collect the
`stat` measure obtained for every draw.

# Arguments
- `result`: the fit result whose configuration is copied and re-used. Its `χ2`
  field is the reference value for the reported percentage.
- `u`, `σ`: per-parameter arguments handed to `distribution(u[j], σ[j])`.

# Keywords
- `N`: number of draws (length of the returned vector).
- `stat`: statistic passed to `measure`.
- `distribution`: callable building the sampling distribution of one parameter.
  Every distribution is truncated to the lower and upper limit of the
  corresponding parameter.
- `refit`: when `true`, each draw is used as the starting point of a new
  `LevenbergMarquadt` fit and the refitted result is measured. When `false`,
  the drawn values themselves are measured against `result`.

# Returns
The vector of `N` measures. The percentage of measures smaller than
`result.χ2` is additionally logged with `@info`.
"""
function goodness(
    result::AbstractFittingResult,
    u::AbstractVector{T},
    σ::AbstractVector{T};
    N = 1000,
    stat = ChiSquared(),
    distribution = Distributions.Normal,
    refit = true,
) where {T}
    measures = zeros(T, N)
    # sample into a copy so the parameters held by `result` are not modified
    config = deepcopy(result.config)
    for i in eachindex(measures)
        # sample the next parameters; the truncation keeps every draw within
        # the lower and upper limits of its parameter
        for j in eachindex(u)
            distr = Distributions.Truncated(
                distribution(u[j], σ[j]),
                get_lowerlimit(config.parameters[j]),
                get_upperlimit(config.parameters[j]),
            )
            set_value!(config.parameters[j], rand(distr))
        end

        measures[i] = if refit
            measure(stat, fit(config, LevenbergMarquadt()))
        else
            # without a refit there is no new result to inspect, so score the
            # sampled parameter values directly
            measure(stat, result, get_value.(config.parameters))
        end
    end

    perc = 100 * count(<(result.χ2), measures) / N
    @info "% with measure < result = $(perc)"

    measures
end

"""
    goodness(result::AbstractFittingResult, σu = estimated_error(result); kwargs...)

Convenience method: take the sampling centres from `estimated_params(result)`
and forward them, together with `σu` and all keyword arguments, to the
three-argument `goodness` method. `σu` must not be `nothing`.
"""
function goodness(result::AbstractFittingResult, σu = estimated_error(result); kwargs...)
    @assert !isnothing(σu) "σ cannot be nothing, else algorithm has no parameter intervals to sample from."
    best_fit = estimated_params(result)
    return goodness(result, best_fit, σu; kwargs...)
end

export goodness
35 changes: 28 additions & 7 deletions src/fitting/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,26 @@ function _unpack_fitting_configuration(prob; kwargs...)
kwargs, config
end

"""
    configuration(prob::FittingProblem; kwargs...)

Build and return the fitting configuration for `prob`.

All keyword arguments are handed to `_unpack_fitting_configuration`; any
keyword it hands back unconsumed is treated as a mistake.

# Throws
- `ArgumentError` if keyword arguments remain after unpacking.
"""
function configuration(prob::FittingProblem; kwargs...)
    kw, config = _unpack_fitting_configuration(prob; kwargs...)
    # raise a proper exception type rather than a bare string so that callers
    # get a regular error message and can catch it by type
    if !isempty(kw)
        throw(ArgumentError("Unknown keyword arguments: $(kw)"))
    end
    return config
end

"""
    fit(prob::FittingProblem, args...; kwargs...)

Unpack `prob` and `kwargs` with `_unpack_fitting_configuration`, then call
`fit` on the resulting configuration with the remaining positional arguments
and the keyword arguments that were handed back.
"""
function fit(prob::FittingProblem, args...; kwargs...)
    unpacked = _unpack_fitting_configuration(prob; kwargs...)
    remaining_kwargs, config = unpacked
    @inline fit(config, args...; remaining_kwargs...)
end

function fit(
prob::FittingProblem,
config::FittingConfig,
alg::LevenbergMarquadt;
verbose = false,
max_iter = 1000,
kwargs...,
method_kwargs...,
)
method_kwargs, config = _unpack_fitting_configuration(prob; kwargs...)
lsq_result = _lsq_fit(
_f_objective(config),
config.domain,
Expand All @@ -70,18 +82,27 @@ function fit(
method_kwargs...,
)
params = LsqFit.coef(lsq_result)
finalize(config, params)
σ = try
LsqFit.standard_errors(lsq_result)
catch e
if e isa LinearAlgebra.SingularException
@warn "No parameter uncertainty estimation due to error: $e"
nothing
else
throw(e)
end
end
finalize(config, params; σparams = σ)
end

function fit(
prob::FittingProblem,
config::FittingConfig,
statistic::AbstractStatistic,
optim_alg;
verbose = false,
autodiff = nothing,
kwargs...,
method_kwargs...,
)
method_kwargs, config = _unpack_fitting_configuration(prob; kwargs...)
objective = _f_wrap_objective(statistic, config)
u0 = get_value.(config.parameters)
lower = get_lowerlimit.(config.parameters)
Expand Down
18 changes: 16 additions & 2 deletions src/fitting/multi-cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,13 @@ function finalize(
config::FittingConfig{Impl,<:MultiModelCache},
params;
statistic = ChiSquared(),
σparams = nothing,
) where {Impl}
domain = config.domain
cache = config.cache
results = map(enumerate(cache.caches)) do (i, ch)
p = @views params[cache.parameter_mapping[i]]
σp = @views isnothing(σparams) ? nothing : σparams[cache.parameter_mapping[i]]

domain_start, domain_end = _get_range(cache.domain_mapping, i)
objective_start, objective_end = _get_range(cache.objective_mapping, i)
Expand All @@ -121,7 +123,19 @@ function finalize(
output,
config.variance[objective_start:objective_end],
)
(chi2, p)
(; chi2, p, σp)
end
MultiFittingResult(first.(results), last.(results), config)

unc = getindex.(results, :σp)
unc_or_nothing = if any(isnothing, unc)
nothing
else
unc
end
MultiFittingResult(
getindex.(results, :chi2),
getindex.(results, :p),
unc_or_nothing,
config,
)
end
33 changes: 25 additions & 8 deletions src/fitting/result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ export FittingResult,
invoke_result,
update_model!

function _pretty_print_result(model, u, chi2)
function _pretty_print_result(model, u, σ, chi2)
ppx2 = prettyfloat(chi2)
ppu = join((prettyfloat(i) for i in u), ", ")
ppσ = isnothing(σ) ? nothing : join((prettyfloat(i) for i in σ), ", ")
"""
Model: $(model)
. u : [$(ppu)]
. σᵤ : [$(ppσ)]
. χ² : $(ppx2)
"""
end
Expand All @@ -25,9 +27,12 @@ struct FittingResultSlice{C,V,U,T} <: AbstractFittingResult
objective::V
variance::V
u::U
σu::Union{Nothing,U}
χ2::T
end

"""
    estimated_error(r::FittingResultSlice)

Return the `σu` field of the slice, which may be `nothing`.
"""
function estimated_error(r::FittingResultSlice)
    return r.σu
end

"""
    estimated_params(r::FittingResultSlice)

Return the `u` field of the slice.
"""
function estimated_params(r::FittingResultSlice)
    return r.u
end

# Measure the slice at its own stored parameters `u`.
function measure(stat::AbstractStatistic, slice::FittingResultSlice)
    return measure(stat, slice, slice.u)
end

function measure(stat::AbstractStatistic, slice::FittingResultSlice, u)
Expand All @@ -40,19 +45,24 @@ function invoke_result(slice::FittingResultSlice, u)
end

function _pretty_print(slice::FittingResultSlice)
"FittingResultSlice:\n" * _pretty_print_result(slice.cache.model, slice.u, slice.χ2)
"FittingResultSlice:\n" *
_pretty_print_result(slice.cache.model, slice.u, slice.σu, slice.χ2)
end

# Multi-line `text/plain` display: the pretty-printed slice passed through
# `encapsulate` and printed to `io`.
function Base.show(io::IO, ::MIME"text/plain", @nospecialize(slice::FittingResultSlice))
    text = _pretty_print(slice)
    print(io, encapsulate(text))
end

struct FittingResult{T,K,C} <: AbstractFittingResult
struct FittingResult{T,U,C} <: AbstractFittingResult
χ2::T
u::K
u::U
σu::Union{Nothing,U}
config::C
end

"""
    estimated_error(r::FittingResult)

Return the `σu` field of the result, which may be `nothing`.
"""
function estimated_error(r::FittingResult)
    return r.σu
end

"""
    estimated_params(r::FittingResult)

Return the `u` field of the result.
"""
function estimated_params(r::FittingResult)
    return r.u
end

# Delegate to the first (index `1`) slice of the result, passing any extra
# positional arguments through unchanged.
function measure(stat::AbstractStatistic, slice::FittingResult, args...)
    first_slice = slice[1]
    return measure(stat, first_slice, args...)
end

Expand All @@ -69,6 +79,7 @@ function Base.getindex(result::FittingResult, i)
result.config.objective,
result.config.variance,
result.u,
result.σu,
result.χ2,
)
else
Expand All @@ -77,22 +88,27 @@ function Base.getindex(result::FittingResult, i)
end

function _pretty_print(res::FittingResult)
"FittingResult:\n" * _pretty_print_result(res.config.cache.model, res.u, res.χ2)
"FittingResult:\n" * _pretty_print_result(res.config.cache.model, res.u, res.σu, res.χ2)
end

# Multi-line `text/plain` display: the pretty-printed result passed through
# `encapsulate` and printed to `io`.
function Base.show(io::IO, ::MIME"text/plain", @nospecialize(res::FittingResult))
    text = _pretty_print(res)
    print(io, encapsulate(text))
end

struct MultiFittingResult{T,K,C} <: AbstractFittingResult
struct MultiFittingResult{T,U,C} <: AbstractFittingResult
χ2s::Vector{T}
us::K
us::U
σus::Union{Nothing,U}
config::C
end

"""
    estimated_error(r::MultiFittingResult)

Return the `σus` field of the result, which may be `nothing`.
"""
function estimated_error(r::MultiFittingResult)
    return r.σus
end

"""
    estimated_params(r::MultiFittingResult)

Return the `us` field of the result.
"""
function estimated_params(r::MultiFittingResult)
    return r.us
end

function Base.getindex(result::MultiFittingResult, i::Int)
cache = result.config.cache.caches[i]
u = result.us[i]
σu = isnothing(result.σus) ? nothing : result.σus[i]
chi2 = result.χ2s[i]
d_start, d_end = _get_range(result.config.cache.domain_mapping, i)
o_start, o_end = _get_range(result.config.cache.objective_mapping, i)
Expand All @@ -102,6 +118,7 @@ function Base.getindex(result::MultiFittingResult, i::Int)
result.config.objective[o_start:o_end],
result.config.variance[o_start:o_end],
u,
σu,
chi2,
)
end
Expand All @@ -114,7 +131,7 @@ function Base.show(io::IO, ::MIME"text/plain", @nospecialize(res::MultiFittingRe
print(buff, " ")
for i = 1:length(res.us)
slice = res[i]
b = _pretty_print_result(slice.cache.model, slice.u, slice.χ2)
b = _pretty_print_result(slice.cache.model, slice.u, slice.σu, slice.χ2)
r = indent(b, 1)
print(buff, r)
end
Expand Down

0 comments on commit e1899a0

Please sign in to comment.