Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cash money #105

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ include("datasets/injectivedata.jl")
include("model-data-io.jl")

# include fitting api
include("fitting/result.jl")
include("fitting/cache.jl")
include("fitting/problem.jl")
include("fitting/cache.jl")
include("fitting/config.jl")
include("fitting/result.jl")
include("fitting/statistics.jl")
include("fitting/binding.jl")
include("fitting/multi-cache.jl")
include("fitting/methods.jl")
include("fitting/statistics.jl")

include("simulate.jl")
include("fitting/goodness.jl")
Expand Down
18 changes: 18 additions & 0 deletions src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export AbstractSpectralModel,
Multiplicative,
Additive,
numbertype,
paramtype,
Convolutional,
modelkind,
AbstractSpectralModelImplementation,
Expand Down Expand Up @@ -103,6 +104,8 @@ supports(::ContiguouslyBinned, ::Type{<:AbstractSpectralModel}) = true
Get the numerical type of the model. This goes through [`FitParam`](@ref), so
that the number type returned is as close to a primative as possible.

See also [`paramtype`](@ref).

## Example

```julia
Expand All @@ -112,6 +115,21 @@ numbertype(PowerLaw()) == Float64
numbertype(::AbstractSpectralModel{T}) where {T<:Number} = T
numbertype(::AbstractSpectralModel{FitParam{T}}) where {T<:Number} = T

"""
paramtype(::AbstractSpectralModel)

Get the parameter type of the model. This, unlike [`numbertype`](@ref) does not
go through [`FitParam`](@ref).

## Example

```julia
paramtype(PowerLaw()) == FitParam{Float64}
```
"""
paramtype(::T) where {T<:AbstractSpectralModel} = paramtype(T)
paramtype(::Type{<:AbstractSpectralModel{T}}) where {T} = T

"""
modelkind(M::Type{<:AbstractSpectralModel})
modelkind(::AbstractSpectralModel)
Expand Down
116 changes: 43 additions & 73 deletions src/fitting/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,89 +80,59 @@ function _invoke_and_transform!(cache::SpectralCache, domain, params)
output_vector
end

struct FittingConfig{ImplType,CacheType,P,D,O}
cache::CacheType
parameters::P
model_domain::D
output_domain::D
objective::O
variance::O
covariance::O
function FittingConfig(
impl::AbstractSpectralModelImplementation,
cache::C,
params::P,
model_domain::D,
output_domain::D,
objective::O,
variance::O;
covariance::O = inv.(variance),
) where {C,P,D,O}
new{typeof(impl),C,P,D,O}(
cache,
params,
model_domain,
output_domain,
objective,
variance,
covariance,
)
end
struct MultiModelCache{K,N,CacheTypes<:Tuple,ParameterMappingType} <: AbstractFittingCache
caches::CacheTypes
all_outputs::K
domain_mapping::NTuple{N,Int}
output_domain_mapping::NTuple{N,Int}
objective_mapping::NTuple{N,Int}
parameter_mapping::ParameterMappingType
end

function FittingConfig(model::AbstractSpectralModel{T}, dataset::AbstractDataset) where {T}
layout = common_support(model, dataset)
model_domain = make_model_domain(layout, dataset)
output_domain = make_output_domain(layout, dataset)
objective = make_objective(layout, dataset)
variance = make_objective_variance(layout, dataset)
params::Vector{T} = collect(filter(isfree, parameter_tuple(model)))
cache = SpectralCache(
layout,
model,
model_domain,
objective,
objective_transformer(layout, dataset),
)
FittingConfig(
implementation(model),
cache,
params,
model_domain,
output_domain,
objective,
variance,
)
function _get_range(mapping::NTuple, i)
m_start = i == 1 ? 1 : mapping[i-1] + 1
m_end = mapping[i]
(m_start, m_end)
end

function _f_objective(config::FittingConfig)
function f!!(domain, parameters)
_invoke_and_transform!(config.cache, domain, parameters)
function _invoke_and_transform!(cache::MultiModelCache, domain, params)
all_outputs = get_tmp(cache.all_outputs, params)

for (i, ch) in enumerate(cache.caches)
p = @views params[cache.parameter_mapping[i]]

domain_start, domain_end = _get_range(cache.domain_mapping, i)
objective_start, objective_end = _get_range(cache.objective_mapping, i)

d = @views domain[domain_start:domain_end]
all_outputs[objective_start:objective_end] .= _invoke_and_transform!(ch, d, p)
end
end

function finalize(
config::FittingConfig,
params;
statistic = ChiSquared(),
σparams = nothing,
)
y = _f_objective(config)(config.model_domain, params)
chi2 = measure(statistic, config.objective, y, config.variance)
FittingResult(chi2, params, σparams, config)
all_outputs
end

supports_autodiff(config::FittingConfig{<:JuliaImplementation}) = true
supports_autodiff(config::FittingConfig) = false
function _build_parameter_mapping(model::FittableMultiModel, bindings)
parameters = map(m -> collect(filter(isfree, parameter_tuple(m))), model.m)
parameters_counts = _accumulated_indices(map(length, parameters))

function Base.show(io::IO, @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
all_parameters = reduce(vcat, parameters)

parameter_mapping, remove = _construct_bound_mapping(bindings, parameters_counts)
# remove duplicate parameters that are bound
deleteat!(all_parameters, remove)

all_parameters, parameter_mapping
end

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
function _build_mapping_length(f, itt::Tuple)
values = map(f, itt)
mapping = _accumulated_indices(map(length, values))
values, mapping
end

export FittingConfig
_build_objective_mapping(layout::AbstractDataLayout, dataset::FittableMultiDataset) =
_build_mapping_length(i -> make_objective(layout, i), dataset.d)
_build_domain_mapping(layout::AbstractDataLayout, dataset::FittableMultiDataset) =
_build_mapping_length(i -> make_model_domain(layout, i), dataset.d)
_build_output_domain_mapping(layout::AbstractDataLayout, dataset::FittableMultiDataset) =
_build_mapping_length(i -> make_output_domain(layout, i), dataset.d)
160 changes: 160 additions & 0 deletions src/fitting/config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
struct FittingConfig{ImplType,CacheType,StatT,ProbT,P,D,O}
cache::CacheType
stat::StatT
prob::ProbT
parameters::P
model_domain::D
output_domain::D
objective::O
variance::O
covariance::O
function FittingConfig(
impl::AbstractSpectralModelImplementation,
cache::C,
stat::AbstractStatistic,
prob::FP,
params::P,
model_domain::D,
output_domain::D,
objective::O,
variance::O;
covariance::O = inv.(variance),
) where {C<:AbstractFittingCache,FP,P,D,O}
new{typeof(impl),C,typeof(stat),FP,P,D,O}(
cache,
stat,
prob,
params,
model_domain,
output_domain,
objective,
variance,
covariance,
)
end
end

fit_statistic(::Type{<:FittingConfig{Impl,Cache,Stat}}) where {Impl,Cache,Stat} = Stat()
fit_statistic(::T) where {T<:FittingConfig} = fit_statistic(T)

function make_single_config(prob::FittingProblem, stat::AbstractStatistic)
model = prob.model.m[1]
dataset = prob.data.d[1]

layout = common_support(model, dataset)
model_domain = make_model_domain(layout, dataset)
output_domain = make_output_domain(layout, dataset)
objective = make_objective(layout, dataset)
variance = make_objective_variance(layout, dataset)
params::Vector{paramtype(model)} = collect(filter(isfree, parameter_tuple(model)))
cache = SpectralCache(
layout,
model,
model_domain,
objective,
objective_transformer(layout, dataset),
)
FittingConfig(
implementation(model),
cache,
stat,
prob,
params,
model_domain,
output_domain,
objective,
variance,
)
end

function make_multi_config(prob::FittingProblem, stat::AbstractStatistic)
impl =
all(model -> implementation(model) isa JuliaImplementation, prob.model.m) ?
JuliaImplementation() : XSPECImplementation()

layout = common_support(prob.model.m..., prob.data.d...)

variances = map(d -> make_objective_variance(layout, d), prob.data.d)
# build index mappings for pulling out the data
domains, domain_mapping = _build_domain_mapping(layout, prob.data)
output_domains, output_domain_mapping = _build_output_domain_mapping(layout, prob.data)
objectives, objective_mapping = _build_objective_mapping(layout, prob.data)
parameters, parameter_mapping = _build_parameter_mapping(prob.model, prob.bindings)

i::Int = 1
caches = map(prob.model.m) do m
c = SpectralCache(
layout,
m,
domains[i],
objectives[i],
objective_transformer(layout, prob.data.d[i]),
param_diff_cache_size = length(parameters),
)
i += 1
c
end

all_objectives = reduce(vcat, objectives)

cache = MultiModelCache(
caches,
DiffCache(similar(all_objectives)),
domain_mapping,
output_domain_mapping,
objective_mapping,
parameter_mapping,
)
FittingConfig(
impl,
cache,
stat,
prob,
parameters,
reduce(vcat, domains),
reduce(vcat, output_domains),
all_objectives,
reduce(vcat, variances),
)
end

function FittingConfig(prob::FittingProblem; stat = ChiSquared())
config = if model_count(prob) == 1 && data_count(prob) == 1
make_single_config(prob, stat)
elseif model_count(prob) == data_count(prob)
make_multi_config(prob, stat)
elseif model_count(prob) < data_count(prob)
error("Single model, many data not yet implemented.")
else
error("Multi model, single data not yet implemented.")
end

return config
end

function _unpack_config(prob::FittingProblem; stat = ChiSquared(), kwargs...)
config = FittingConfig(prob; stat = stat)
kwargs, config
end

function _f_objective(config::FittingConfig)
function f!!(domain, parameters)
_invoke_and_transform!(config.cache, domain, parameters)
end
end


supports_autodiff(config::FittingConfig{<:JuliaImplementation}) = true
supports_autodiff(config::FittingConfig) = false

function Base.show(io::IO, @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
end

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(config::FittingConfig))
descr = "FittingConfig"
print(io, descr)
end

export FittingConfig
Loading
Loading