Skip to content

Commit

Permalink
Merge pull request #105 from fjebaker/fergus/cash
Browse files Browse the repository at this point in the history
Cash money
  • Loading branch information
fjebaker authored May 21, 2024
2 parents 9686072 + 9621b43 commit 8b2c632
Show file tree
Hide file tree
Showing 15 changed files with 429 additions and 321 deletions.
8 changes: 4 additions & 4 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ include("datasets/injectivedata.jl")
include("model-data-io.jl")

# include fitting api
include("fitting/result.jl")
include("fitting/cache.jl")
include("fitting/problem.jl")
include("fitting/cache.jl")
include("fitting/config.jl")
include("fitting/result.jl")
include("fitting/statistics.jl")
include("fitting/binding.jl")
include("fitting/multi-cache.jl")
include("fitting/methods.jl")
include("fitting/statistics.jl")

include("simulate.jl")
include("fitting/goodness.jl")
Expand Down
18 changes: 18 additions & 0 deletions src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export AbstractSpectralModel,
Multiplicative,
Additive,
numbertype,
paramtype,
Convolutional,
modelkind,
AbstractSpectralModelImplementation,
Expand Down Expand Up @@ -103,6 +104,8 @@ supports(::ContiguouslyBinned, ::Type{<:AbstractSpectralModel}) = true
Get the numerical type of the model. This goes through [`FitParam`](@ref), so
that the number type returned is as close to a primative as possible.
See also [`paramtype`](@ref).
## Example
```julia
Expand All @@ -112,6 +115,21 @@ numbertype(PowerLaw()) == Float64
numbertype(::AbstractSpectralModel{T}) where {T<:Number} = T
numbertype(::AbstractSpectralModel{FitParam{T}}) where {T<:Number} = T

"""
paramtype(::AbstractSpectralModel)
Get the parameter type of the model. This, unlike [`numbertype`](@ref) does not
go through [`FitParam`](@ref).
## Example
```julia
paramtype(PowerLaw()) == FitParam{Float64}
```
"""
paramtype(::T) where {T<:AbstractSpectralModel} = paramtype(T)
paramtype(::Type{<:AbstractSpectralModel{T}}) where {T} = T

"""
modelkind(M::Type{<:AbstractSpectralModel})
modelkind(::AbstractSpectralModel)
Expand Down
116 changes: 43 additions & 73 deletions src/fitting/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,89 +80,59 @@ function _invoke_and_transform!(cache::SpectralCache, domain, params)
output_vector
end

struct FittingConfig{ImplType,CacheType,P,D,O}
cache::CacheType
parameters::P
model_domain::D
output_domain::D
objective::O
variance::O
covariance::O
function FittingConfig(
impl::AbstractSpectralModelImplementation,
cache::C,
params::P,
model_domain::D,
output_domain::D,
objective::O,
variance::O;
covariance::O = inv.(variance),
) where {C,P,D,O}
new{typeof(impl),C,P,D,O}(
cache,
params,
model_domain,
output_domain,
objective,
variance,
covariance,
)
end
struct MultiModelCache{K,N,CacheTypes<:Tuple,ParameterMappingType} <: AbstractFittingCache
caches::CacheTypes
all_outputs::K
domain_mapping::NTuple{N,Int}
output_domain_mapping::NTuple{N,Int}
objective_mapping::NTuple{N,Int}
parameter_mapping::ParameterMappingType
end

function FittingConfig(model::AbstractSpectralModel{T}, dataset::AbstractDataset) where {T}
layout = common_support(model, dataset)
model_domain = make_model_domain(layout, dataset)
output_domain = make_output_domain(layout, dataset)
objective = make_objective(layout, dataset)
variance = make_objective_variance(layout, dataset)
params::Vector{T} = collect(filter(isfree, parameter_tuple(model)))
cache = SpectralCache(
layout,
model,
model_domain,
objective,
objective_transformer(layout, dataset),
)
FittingConfig(
implementation(model),
cache,
params,
model_domain,
output_domain,
objective,
variance,
)
function _get_range(mapping::NTuple, i)
m_start = i == 1 ? 1 : mapping[i-1] + 1
m_end = mapping[i]
(m_start, m_end)
end

function _f_objective(config::FittingConfig)
function f!!(domain, parameters)
_invoke_and_transform!(config.cache, domain, parameters)
function _invoke_and_transform!(cache::MultiModelCache, domain, params)
all_outputs = get_tmp(cache.all_outputs, params)

for (i, ch) in enumerate(cache.caches)
p = @views params[cache.parameter_mapping[i]]

domain_start, domain_end = _get_range(cache.domain_mapping, i)
objective_start, objective_end = _get_range(cache.objective_mapping, i)

d = @views domain[domain_start:domain_end]
all_outputs[objective_start:objective_end] .= _invoke_and_transform!(ch, d, p)
end
end

function finalize(
config::FittingConfig,
params;
statistic = ChiSquared(),
σparams = nothing,
)
y = _f_objective(config)(config.model_domain, params)
chi2 = measure(statistic, config.objective, y, config.variance)
FittingResult(chi2, params, σparams, config)
all_outputs
end

supports_autodiff(config::FittingConfig{<:JuliaImplementation}) = true
supports_autodiff(config::FittingConfig) = false
function _build_parameter_mapping(model::FittableMultiModel, bindings)
parameters = map(m -> collect(filter(isfree, parameter_tuple(m))), model.m)
parameters_counts = _accumulated_indices(map(length, parameters))

function Base.show(io::IO, @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
all_parameters = reduce(vcat, parameters)

parameter_mapping, remove = _construct_bound_mapping(bindings, parameters_counts)
# remove duplicate parameters that are bound
deleteat!(all_parameters, remove)

all_parameters, parameter_mapping
end

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
function _build_mapping_length(f, itt::Tuple)
values = map(f, itt)
mapping = _accumulated_indices(map(length, values))
values, mapping
end

export FittingConfig
_build_objective_mapping(layout::AbstractDataLayout, dataset::FittableMultiDataset) =
_build_mapping_length(i -> make_objective(layout, i), dataset.d)
_build_domain_mapping(layout::AbstractDataLayout, dataset::FittableMultiDataset) =
_build_mapping_length(i -> make_model_domain(layout, i), dataset.d)
_build_output_domain_mapping(layout::AbstractDataLayout, dataset::FittableMultiDataset) =
_build_mapping_length(i -> make_output_domain(layout, i), dataset.d)
160 changes: 160 additions & 0 deletions src/fitting/config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
struct FittingConfig{ImplType,CacheType,StatT,ProbT,P,D,O}
cache::CacheType
stat::StatT
prob::ProbT
parameters::P
model_domain::D
output_domain::D
objective::O
variance::O
covariance::O
function FittingConfig(
impl::AbstractSpectralModelImplementation,
cache::C,
stat::AbstractStatistic,
prob::FP,
params::P,
model_domain::D,
output_domain::D,
objective::O,
variance::O;
covariance::O = inv.(variance),
) where {C<:AbstractFittingCache,FP,P,D,O}
new{typeof(impl),C,typeof(stat),FP,P,D,O}(
cache,
stat,
prob,
params,
model_domain,
output_domain,
objective,
variance,
covariance,
)
end
end

fit_statistic(::Type{<:FittingConfig{Impl,Cache,Stat}}) where {Impl,Cache,Stat} = Stat()
fit_statistic(::T) where {T<:FittingConfig} = fit_statistic(T)

function make_single_config(prob::FittingProblem, stat::AbstractStatistic)
model = prob.model.m[1]
dataset = prob.data.d[1]

layout = common_support(model, dataset)
model_domain = make_model_domain(layout, dataset)
output_domain = make_output_domain(layout, dataset)
objective = make_objective(layout, dataset)
variance = make_objective_variance(layout, dataset)
params::Vector{paramtype(model)} = collect(filter(isfree, parameter_tuple(model)))
cache = SpectralCache(
layout,
model,
model_domain,
objective,
objective_transformer(layout, dataset),
)
FittingConfig(
implementation(model),
cache,
stat,
prob,
params,
model_domain,
output_domain,
objective,
variance,
)
end

function make_multi_config(prob::FittingProblem, stat::AbstractStatistic)
impl =
all(model -> implementation(model) isa JuliaImplementation, prob.model.m) ?
JuliaImplementation() : XSPECImplementation()

layout = common_support(prob.model.m..., prob.data.d...)

variances = map(d -> make_objective_variance(layout, d), prob.data.d)
# build index mappings for pulling out the data
domains, domain_mapping = _build_domain_mapping(layout, prob.data)
output_domains, output_domain_mapping = _build_output_domain_mapping(layout, prob.data)
objectives, objective_mapping = _build_objective_mapping(layout, prob.data)
parameters, parameter_mapping = _build_parameter_mapping(prob.model, prob.bindings)

i::Int = 1
caches = map(prob.model.m) do m
c = SpectralCache(
layout,
m,
domains[i],
objectives[i],
objective_transformer(layout, prob.data.d[i]),
param_diff_cache_size = length(parameters),
)
i += 1
c
end

all_objectives = reduce(vcat, objectives)

cache = MultiModelCache(
caches,
DiffCache(similar(all_objectives)),
domain_mapping,
output_domain_mapping,
objective_mapping,
parameter_mapping,
)
FittingConfig(
impl,
cache,
stat,
prob,
parameters,
reduce(vcat, domains),
reduce(vcat, output_domains),
all_objectives,
reduce(vcat, variances),
)
end

function FittingConfig(prob::FittingProblem; stat = ChiSquared())
config = if model_count(prob) == 1 && data_count(prob) == 1
make_single_config(prob, stat)
elseif model_count(prob) == data_count(prob)
make_multi_config(prob, stat)
elseif model_count(prob) < data_count(prob)
error("Single model, many data not yet implemented.")
else
error("Multi model, single data not yet implemented.")
end

return config
end

function _unpack_config(prob::FittingProblem; stat = ChiSquared(), kwargs...)
config = FittingConfig(prob; stat = stat)
kwargs, config
end

function _f_objective(config::FittingConfig)
function f!!(domain, parameters)
_invoke_and_transform!(config.cache, domain, parameters)
end
end


supports_autodiff(config::FittingConfig{<:JuliaImplementation}) = true
supports_autodiff(config::FittingConfig) = false

function Base.show(io::IO, @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
end

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
end

export FittingConfig
Loading

0 comments on commit 8b2c632

Please sign in to comment.