Skip to content

Commit

Permalink
Merge pull request #98 from fjebaker/fergus/destructuring
Browse files Browse the repository at this point in the history
Composite destructing
  • Loading branch information
fjebaker authored May 13, 2024
2 parents 1426cc7 + e9074ab commit 7265e5d
Show file tree
Hide file tree
Showing 13 changed files with 173 additions and 78 deletions.
1 change: 1 addition & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ include("xspec-models/convolutional.jl")
include("julia-models/model-utilities.jl")
include("julia-models/additive.jl")
include("julia-models/multiplicative.jl")
include("julia-models/convolutional.jl")

function __init__()
# check if we have the minimum model data already
Expand Down
5 changes: 0 additions & 5 deletions src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,6 @@ updatemodel(model::AbstractSpectralModel, patch::NamedTuple) =
updatemodel(model::AbstractSpectralModel; kwargs...) =
ConstructionBase.setproperties(model; kwargs...)

@inline function updatefree(model::AbstractSpectralModel, free_params)
patch = free_parameters_to_named_tuple(free_params, model)
updatemodel(model, patch)
end

@inline function updateparameters(model::AbstractSpectralModel, params)
patch = all_parameters_to_named_tuple(params, model)
updatemodel(model, patch)
Expand Down
16 changes: 13 additions & 3 deletions src/composite-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,22 @@ ConstructionBase.constructorof(::Type{<:CompositeModel}) =
throw("Cannot be used with `CompositeModel`.")

function Base.propertynames(model::CompositeModel, private::Bool = false)
all_parameter_symbols(model)
(all_parameter_symbols(model)..., all_model_symbols(model)...)
end

# TODO: really ensure this is type stable as it could be a performance killer
Base.@constprop aggressive function _get_property(model::CompositeModel, symb::Symbol)
if symb in all_model_symbols(model)
lookup = composite_model_map(model)
return lookup[symb]
else
lookup = all_parameters_to_named_tuple(model)
return lookup[symb]
end
end

function Base.getproperty(model::CompositeModel, symb::Symbol)
lookup = all_parameters_to_named_tuple(model)
lookup[symb]
_get_property(model, symb)
end

function Base.setproperty!(model::CompositeModel, symb::Symbol, value::FitParam)
Expand Down
7 changes: 6 additions & 1 deletion src/fitting/goodness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@ function goodness(
stat = ChiSquared(),
distribution = Distributions.Normal,
refit = true,
seed = abs(randint()),
kwargs...,
) where {T}
x = similar(u)
measures = zeros(T, N)
config = deepcopy(result.config)

rng = Random.default_rng()
Random.seed!(rng, seed)

for i in eachindex(measures)
# sample the next parameters
for i in eachindex(u)
Expand All @@ -25,7 +30,7 @@ function goodness(
x[i] = rand(distr)
end

simulate!(config, x; kwargs...)
simulate!(config, x; rng = rng, kwargs...)

if refit
new_result = fit(config, LevenbergMarquadt())
Expand Down
2 changes: 1 addition & 1 deletion src/fitting/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function fit(
σ = try
LsqFit.standard_errors(lsq_result)
catch e
if e isa LinearAlgebra.SingularException
if e isa LinearAlgebra.SingularException || e isa LinearAlgebra.LAPACKException
@warn "No parameter uncertainty estimation due to error: $e"
nothing
else
Expand Down
2 changes: 1 addition & 1 deletion src/generation/function-generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ end

function _parameter_lens(info::ModelInfo, symbols)
map(symbols) do s
:(getproperty($(info.lens), $(Meta.quot(s))))
:(getfield($(info.lens), $(Meta.quot(s))))
end
end

Expand Down
25 changes: 20 additions & 5 deletions src/generation/parsing-utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function _addinfoinvoke!(
lens::Lens,
) where {NumType}
# don't increment flux for convolutional models
if !(modelkind(model) === Convolutional)
if !(modelkind(model) === Convolutional())
inc_flux!(ga)
end
flux = get_flux_symbol(ga.objective_cache_count)
Expand Down Expand Up @@ -189,23 +189,38 @@ function _unique_parameter_symbols(infos::Vector{ModelInfo})
(param_names...,)
end

function all_parameters_to_named_tuple(model::Type{<:CompositeModel})
function _all_parameters_to_named_tuple_composite(model::Type{<:CompositeModel})
infos = getinfo(model)
lenses = reduce(vcat, map(i -> _parameter_lens(i, i.symbols), infos))
names = _unique_parameter_symbols(infos)
names, lenses
end

function all_parameters_to_named_tuple(model::Type{<:CompositeModel})
names, lenses = _all_parameters_to_named_tuple_composite(model)
:(NamedTuple{$(names)}(($(lenses...),)))
end

function model_structure(model::Type{<:CompositeModel})
counters = (; a = Ref(0), m = Ref(0), c = Ref(0))
infos = Pair{Symbol,ModelInfo}[]
expr = _addinfosymbol!(infos, counters, model, :model)
expr, infos
end

function all_model_symbols_to_models(model::Type{<:CompositeModel}; kwargs...)
_, infos = model_structure(model; kwargs...)
infos
end

_unique_model_symbol(::SpectralFitting.Additive, counters) = Symbol('a', counters.a[] += 1)
_unique_model_symbol(::SpectralFitting.Multiplicative, counters) =
Symbol('m', counters.m[] += 1)
_unique_model_symbol(::SpectralFitting.Convolutional, counters) =
Symbol('c', counters.c[] += 1)

function _destructure_for_printing(model::Type{<:CompositeModel}; lens = :(model))
counters = (; a = Ref(0), m = Ref(0), c = Ref(0))
infos = Pair{Symbol,ModelInfo}[]
expr = _addinfosymbol!(infos, counters, model, lens)
expr, infos = model_structure(model)
# reorder data structure as NamedTuple of pairs of (model, parameters)
param_names = _unique_parameter_symbols(last.(infos))
offset = 1
Expand Down
20 changes: 20 additions & 0 deletions src/generation/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@ Returns a compile-time known tuple of all models symbols.
:($(params))
end

@inline @generated function all_model_symbols(model::AbstractSpectralModel)
syms = (first.(FunctionGeneration.all_model_symbols_to_models(model))...,)
:($(syms))
end

@inline @generated function composite_model_map(model::CompositeModel)
info = FunctionGeneration.all_model_symbols_to_models(model)
syms = (first.(info)...,)
models = map(i -> i[2].lens, info)
:(NamedTuple{$(syms)}(($(models...),)))
end

@inline @generated function composite_model_parameter_map(model::CompositeModel)
info = FunctionGeneration.all_model_symbols_to_models(model)
param_syms, params = FunctionGeneration._all_parameters_to_named_tuple_composite(model)
syms = (first.(info)..., param_syms...)
models = map(i -> i[2].lens, info)
:(NamedTuple{$(syms)}(($(models...), $(params...))))
end

remake_with_parameters(model::AbstractSpectralModel, cache::ParameterCache) =
_unsafe_remake_with_parameters(model, cache.parameters)
function remake_with_parameters(model::AbstractSpectralModel, params::AbstractArray)
Expand Down
69 changes: 69 additions & 0 deletions src/julia-models/convolutional.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

struct Log10Flux{T} <: AbstractSpectralModel{T,Convolutional}
E_min::T
E_max::T
log10Flux::T
end
function Log10Flux(;
E_min = FitParam(0.2, lower_limit = 0, frozen = true),
E_max = FitParam(2.0, lower_limit = 0, frozen = true),
log10Flux = FitParam(-10.0, lower_limit = -100, upper_limit = 100),
)
Log10Flux(E_min, E_max, log10Flux)
end

function invoke!(flux, energy, model::Log10Flux)
ilow = clamp(
_or_else(findfirst(i -> i > model.E_min, energy), length(energy) - 1),
1,
lastindex(energy) - 1,
)
ihigh = clamp(
_or_else(findfirst(i -> i > model.E_max, energy), length(energy) - 1) - 1,
1,
lastindex(energy) - 1,
)

@show ilow, ihigh

total_e_flux = zero(eltype(flux))

# low bin straddle
if ilow > 1
weight = (energy[ilow]^2 - model.E_min^2) / (energy[ilow] - energy[ilow-1])
total_e_flux += flux[ilow-1] * weight
end

for i = ilow:ihigh
f = flux[i]
e_low = energy[i]
e_high = energy[i+1]

if (e_high > e_low)
total_e_flux += f * (e_high^2 - e_low^2) / (e_high - e_low)
end
end

# high bin straddle
if ihigh > 1
weight = (model.E_max^2 - energy[ihigh+1]^2) / (energy[ihigh+2] - energy[ihigh+1])
total_e_flux += flux[ilow+1] * weight
end

# convert keV to ergs
total_e_flux = total_e_flux * 0.801096e-9

flux_exp = 10^model.log10Flux

if total_e_flux > 0
@. flux = flux * flux_exp / total_e_flux
end
end

function _or_else(value::Union{Nothing,T}, v::T)::T where {T}
if isnothing(value)
v
else
value
end
end
2 changes: 1 addition & 1 deletion src/xspec-models/convolutional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
function XS_CalculateFlux(;
E_min = FitParam(0.2, frozen = true),
E_max = FitParam(2.0, frozen = true),
log10Flux = FitParam(-10.0, lower_limit = -Inf, upper_limit = 0.0),
log10Flux = FitParam(-10.0, lower_limit = -100, upper_limit = 100),
)
XS_CalculateFlux(E_min, E_max, log10Flux)
end
Expand Down
33 changes: 33 additions & 0 deletions test/generation/test-parsing-utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,36 @@ expr, info = eval(t)
m2 = (DummyMultiplicative(), ("a_4", "b_4")),
m3 = (DummyMultiplicative(), ("a_5", "b_5")),
))

info = SpectralFitting.all_model_symbols(model)
@test info == (:a1, :m1, :a2, :m2, :m3)

info = SpectralFitting.composite_model_map(model)
@test info == (;
a1 = DummyAdditive(),
m1 = DummyMultiplicative(),
a2 = DummyAdditive(),
m2 = DummyMultiplicative(),
m3 = DummyMultiplicative(),
)

info = SpectralFitting.composite_model_parameter_map(model)
@test info == (;
a1 = DummyAdditive(),
m1 = DummyMultiplicative(),
a2 = DummyAdditive(),
m2 = DummyMultiplicative(),
m3 = DummyMultiplicative(),
K_1 = FitParam(1.0),
a_1 = FitParam(1.0),
b_1 = FitParam(5.0),
a_2 = FitParam(1.0),
b_2 = FitParam(5.0),
K_2 = FitParam(1.0),
a_3 = FitParam(1.0),
b_3 = FitParam(5.0),
a_4 = FitParam(1.0),
b_4 = FitParam(5.0),
a_5 = FitParam(1.0),
b_5 = FitParam(5.0),
)
8 changes: 8 additions & 0 deletions test/models/test-model-consistency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ end
@test isapprox.(phabs_f1 .- phabs_f2, 0.0, atol = 1e-1) |> all
end
end



y = ones(Float64, 10)
x = collect(range(0.0, 5.0, length(y) + 1))
output_xs = invokemodel!(y, x, XS_CalculateFlux()) |> copy
output_jl = invokemodel!(y, x, SpectralFitting.Log10Flux()) |> copy
@test output_xs output_jl atol = 1e-8
61 changes: 0 additions & 61 deletions test/models/test-mutations.jl

This file was deleted.

0 comments on commit 7265e5d

Please sign in to comment.