Merge pull request #107 from fjebaker/fergus/mcmc-docs
Add MCMC docs and result type stability
fjebaker authored May 24, 2024
2 parents 8b2c632 + 8f3af5f commit 6ab9ac7
Showing 6 changed files with 102 additions and 30 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/docs.yml
@@ -44,9 +44,11 @@ jobs:
Pkg.add("UnicodePlots")
Pkg.add("BenchmarkTools")
Pkg.add("Surrogates")
Pkg.add("Turing")
Pkg.add("StatsPlots")
Pkg.add(url = "https://github.com/astro-group-bristol/LibXSPEC_jll.jl#master")
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: julia --project=docs docs/make.jl
- run: julia --color=yes -tauto --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
62 changes: 61 additions & 1 deletion docs/src/walkthrough.md
@@ -307,4 +307,64 @@

```@example walk
plot!(data, bbpl_result2)
```
```

## MCMC

We can use libraries like [Pigeons.jl](https://pigeons.run/dev/) or [Turing.jl](https://turinglang.org/) to perform Bayesian inference on our parameters. SpectralFitting.jl is designed with *BYOO* (Bring Your Own Optimizer) in mind, and so makes it relatively easy to access the core fitting functions for use with other packages.

Let's use Turing.jl here, which means we'll also want to use [StatsPlots.jl](https://docs.juliaplots.org/dev/generated/statsplots/) to plot our walker chains.
```@example walk
using StatsPlots
using Turing
```

Turing.jl provides enormous control over the definition of the model, and this is not control SpectralFitting.jl wants to take away from you. Although we provide utility scripts for the basics, here we'll go through everything step by step to give you an overview of what you can do.

Let's go back to our first model:
```@example walk
model
```

This gave a pretty good fit, but the errors on our parameters are not well defined, being estimated only from the covariance matrix in the least-squares solver. MCMC can give us better confidence regions, and even help us uncover dependencies between parameters. Here we'll take all of our parameters and convert them into a Turing.jl model using its `@model` macro:

```@example walk
@model function mcmc_model(domain, objective, variance, f)
K ~ Normal(20.0, 1.0)
a ~ Normal(2.2, 0.3)
ηH ~ truncated(Normal(0.5, 0.1); lower = 0)
pred = f(domain, [K, a, ηH])
return objective ~ MvNormal(pred, sqrt.(variance))
end
```

A few things to note here: we use the Turing.jl sampling syntax `~` to say that a variable is sampled from a certain prior distribution. There are no fixed criteria for what a distribution can be, and we encourage you to consult the Turing.jl documentation to learn how to define your own custom probability distributions. In this case, we use Gaussians for all our parameters, taking the means and standard deviations from the best-fit values and their estimated errors.
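
For illustration, here is a hypothetical variant of the model with different priors swapped in. The name `mcmc_model_alt` and the specific prior choices are ours, purely for demonstration; the distribution constructors are from Distributions.jl, which Turing.jl re-exports:

```julia
@model function mcmc_model_alt(domain, objective, variance, f)
    # a flat prior on the normalisation
    K ~ Uniform(10.0, 30.0)
    # weakly informative, constrained to positive values
    a ~ truncated(Normal(2.2, 0.5); lower = 0)
    # strictly positive by construction
    ηH ~ LogNormal(log(0.5), 0.5)
    pred = f(domain, [K, a, ηH])
    return objective ~ MvNormal(pred, sqrt.(variance))
end
```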

At the moment we haven't explicitly used our model; instead, `f` takes the role of invoking the model and folding it through the instrument response. We call it in much the same way as [`invokemodel`](@ref), though it goes the extra step of folding the model. To instantiate this, we can use the SpectralFitting.jl helper functions:

```@example walk
config = FittingConfig(FittingProblem(model => data))
mm = mcmc_model(
make_model_domain(ContiguouslyBinned(), data),
make_objective(ContiguouslyBinned(), data),
make_objective_variance(ContiguouslyBinned(), data),
# _f_objective returns a function used to evaluate and fold the model through the data
SpectralFitting._f_objective(config),
)
nothing # hide
```
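
As a quick sanity check, we can call the returned function directly. A minimal sketch, assuming the same three parameters in the same order as above (the trial values are illustrative):

```julia
f = SpectralFitting._f_objective(config)
domain = make_model_domain(ContiguouslyBinned(), data)
# evaluate the model at trial parameters and fold it through the response
f(domain, [20.0, 2.2, 0.5])
```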

That's it! We're now ready to sample from our model. Since all our models are implemented in Julia, we can use gradient-based samplers with automatic differentiation, such as NUTS. We'll run 5000 iterations, just as a small example:

```@example walk
chain = sample(mm, NUTS(), 5_000)
```
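
Turing.jl can also sample several chains at once. A sketch using the multi-threaded executor (this assumes Julia was started with more than one thread, as the `-tauto` flag in the docs workflow above provides):

```julia
# four chains of 5000 iterations each, sampled across threads
chains = sample(mm, NUTS(), MCMCThreads(), 5_000, 4)
```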

In the printout we see summary statistics about our model: in this case, that it has converged well (`rhat` close to 1 for all parameters), better estimates of the standard deviations, and various quantiles. We can plot our chains to make sure the caterpillars are healthy and fuzzy, making use of StatsPlots.jl recipes:

```@example walk
plot(chain)
```

Corner plots are currently broken at the time of writing.
4 changes: 2 additions & 2 deletions src/fitparam.jl
@@ -47,8 +47,8 @@ Base.isapprox(f1::FitParam, f2::FitParam; kwargs...) =
Base.:(==)(f1::FitParam, f2::FitParam) = f1.value == f2.value
Base.convert(T::Type{<:Number}, f::FitParam) = convert(T, f.value)

parameter_type(::Type{FitParam{T}}) where {T} = T
parameter_type(::T) where {T<:FitParam} = parameter_type(T)
paramtype(::Type{FitParam{T}}) where {T} = T
paramtype(::T) where {T<:FitParam} = paramtype(T)

function get_info_tuple(f::FitParam)
s1 = Printf.@sprintf "%.3g" get_value(f)
16 changes: 13 additions & 3 deletions src/fitting/config.jl
@@ -143,9 +143,19 @@ function _f_objective(config::FittingConfig)
end
end


supports_autodiff(config::FittingConfig{<:JuliaImplementation}) = true
supports_autodiff(config::FittingConfig) = false
function paramtype(
::FittingConfig{ImplType,CacheType,StatT,ProbT,P},
) where {ImplType,CacheType,StatT,ProbT,P}
T = eltype(P)
K = if T <: FitParam
paramtype(T)
else
T
end
Vector{K}
end
supports_autodiff(::FittingConfig{<:JuliaImplementation}) = true
supports_autodiff(::FittingConfig) = false

function Base.show(io::IO, @nospecialize(config::FittingConfig))
descr = "FittingConfig"
44 changes: 22 additions & 22 deletions src/fitting/methods.jl
@@ -90,30 +90,30 @@ function fit(
config::FittingConfig,
optim_alg;
verbose = false,
autodiff = nothing,
autodiff = _determine_ad_backend(config),
method_kwargs...,
)
objective = _f_wrap_objective(fit_statistic(config), config)
u0 = get_value.(config.parameters)
lower = get_lowerlimit.(config.parameters)
upper = get_upperlimit.(config.parameters)

_autodiff = _determine_ad_backend(config; autodiff = autodiff)
if !(autodiff isa Optimization.SciMLBase.NoAD) && (!supports_autodiff(config))
error("Model does not support automatic differentiation.")
end

lb, ub = _determine_bounds(config, autodiff)

# build problem and solve
opt_f = Optimization.OptimizationFunction(objective, _autodiff)
opt_f = Optimization.OptimizationFunction(objective, autodiff)
# todo: something is broken with passing the boundaries
opt_prob = Optimization.OptimizationProblem(
opt_f,
u0,
config.model_domain;
lb = _autodiff isa Optimization.SciMLBase.NoAD ? nothing : lower,
ub = _autodiff isa Optimization.SciMLBase.NoAD ? nothing : upper,
)
opt_prob =
Optimization.OptimizationProblem(opt_f, u0, config.model_domain; lb = lb, ub = ub)

sol = Optimization.solve(opt_prob, optim_alg; method_kwargs...)

final_stat = objective(sol.u, config.model_domain)
finalize(config, sol.u, final_stat)
# TODO: temporary fix for type instabilities in Optimization.jl
new_pars::paramtype(config) = sol.u
final_stat = objective(new_pars, config.model_domain)
finalize(config, new_pars, final_stat)
end

function fit!(prob::FittingProblem, args...; kwargs...)
@@ -122,17 +122,17 @@ function fit!(prob::FittingProblem, args...; kwargs...)
result
end


function _determine_ad_backend(config; autodiff = nothing)
if !((isnothing(autodiff)) || (autodiff isa Optimization.SciMLBase.NoAD)) &&
!supports_autodiff(config)
error("Model does not support automatic differentiation.")
function _determine_bounds(config, ::A) where {A}
if A <: Optimization.SciMLBase.NoAD
nothing, nothing
else
get_lowerlimit.(config.parameters), get_upperlimit.(config.parameters)
end
end

if supports_autodiff(config) && isnothing(autodiff)
function _determine_ad_backend(config)
if supports_autodiff(config)
Optimization.AutoForwardDiff()
elseif !isnothing(autodiff)
autodiff
else
Optimization.SciMLBase.NoAD()
end
2 changes: 1 addition & 1 deletion src/fitting/result.jl
@@ -46,7 +46,7 @@

function _pretty_print(slice::FittingResultSlice)
"FittingResultSlice:\n" *
_pretty_print_result(get_cache(slice).model, slice.u, slice.σu, slice.χ2)
_pretty_print_result(get_model(slice), slice.u, slice.σu, slice.χ2)
end

function Base.show(io::IO, ::MIME"text/plain", @nospecialize(slice::FittingResultSlice))
