diff --git a/Project.toml b/Project.toml index 822429e8..247daa7a 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,8 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" -FITSIO = "525bcba6-941b-5504-bd06-fd0dc1a4d2eb" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +FITSIO = "525bcba6-941b-5504-bd06-fd0dc1a4d2eb" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" diff --git a/docs/src/examples/sherpa-example.md b/docs/src/examples/sherpa-example.md index 74badc75..de92809a 100644 --- a/docs/src/examples/sherpa-example.md +++ b/docs/src/examples/sherpa-example.md @@ -69,7 +69,7 @@ The result card tells us a little bit about how successful the fit was. We furth ```@example sherpa plot(data, markersize = 3) -plot!(data, result) +plot!(result) ``` We can create a contour plot of the fit statistic by evaluating the result everywhere on the grid and measuring the statistic: diff --git a/docs/src/walkthrough.md b/docs/src/walkthrough.md index 690095f6..6a851412 100644 --- a/docs/src/walkthrough.md +++ b/docs/src/walkthrough.md @@ -99,7 +99,7 @@ plot(data, xscale = :log10, yscale = :log10 ) -plot!(data, result) +plot!(result) ``` Our model does not account for the high energy range well. We can ignore that range for now, and select everything from 0 to 15 keV and refit: @@ -115,7 +115,7 @@ plot(data, xscale = :log10, yscale = :log10 ) -plot!(data, result, label = "PowerLaw") +plot!(result, label = "PowerLaw") ``` The result is not yet baked into our model, and represents just the outcome of the fit. To update the parameters and errors in the model, we can use [`update_model!`](@ref) @@ -179,7 +179,7 @@ plot(data, xscale = :log10, yscale = :log10 ) -plot!(data, flux_result) +plot!(flux_result) vspan!([flux_model.c1.E_min.value, flux_model.c1.E_max.value], alpha = 0.5) ``` @@ -205,8 +205,8 @@ dp = plot(data, yscale = :log10, legend = :bottomleft, ) -plot!(dp, data, result, label = "PowerLaw $(round(result.χ2))") -plot!(dp, data, result2, label = "BlackBody $(round(result2.χ2))") +plot!(dp, result, label = "PowerLaw $(round(result.χ2))") +plot!(dp, result2, label = "BlackBody $(round(result2.χ2))") ``` Or a bremsstrahlung model: @@ -218,13 +218,13 @@ result3 = fit(prob3, LevenbergMarquadt()) ``` ```@example walk -plot!(dp, data, result3, label = "Brems $(round(result3.χ2))") +plot!(dp, result3, label = "Brems $(round(result3.χ2))") ``` Let's take a look at the residuals of these three models. There are utility methods for this in SpectralFitting.jl, but we can easily just interact with the result directly: ```@example walk -function residuals(result) +function calc_residuals(result) # select which result we want (only have one, but for generalisation to multi-model fits) r = result[1] y = invoke_result(r) @@ -234,9 +234,9 @@ end domain = SpectralFitting.plotting_domain(data) rp = hline([0], linestyle = :dash, legend = false) -plot!(rp,domain, residuals(result), seriestype = :stepmid) -plot!(rp, domain, residuals(result2), seriestype = :stepmid) -plot!(rp, domain, residuals(result3), seriestype = :stepmid) +plot!(rp,domain, calc_residuals(result), seriestype = :stepmid) +plot!(rp, domain, calc_residuals(result2), seriestype = :stepmid) +plot!(rp, domain, calc_residuals(result3), seriestype = :stepmid) rp ``` @@ -246,6 +246,19 @@ We can compose this figure with our previous one, and change to a linear x scale plot(dp, rp, layout = grid(2, 1, heights = [0.7, 0.3]), link = :x, xscale = :linear) ``` +We can do all that plotting work in one go with the [`plotresult`](@ref) recipe: + +```@example walk +plotresult( + data, + [result, result2, result3], + ylims = (0.001, 2.0), + xscale = :log10, + yscale = :log10, + legend = :bottomleft, +) +``` + Let's modify the black body model with a continuum component ```@example walk @@ -281,7 +294,7 @@ plot(data, yscale = :log10, legend = :bottomleft, ) -plot!(data, bbpl_result) +plot!(bbpl_result) ``` Update the model and fix the black body temperature to 2 keV: @@ -306,7 +319,7 @@ bbpl_result2 = fit( Overplotting this new result: ```@example walk -plot!(data, bbpl_result2) +plot!(bbpl_result2) ``` ## MCMC diff --git a/src/SpectralFitting.jl b/src/SpectralFitting.jl index c795457a..f450b1b8 100644 --- a/src/SpectralFitting.jl +++ b/src/SpectralFitting.jl @@ -81,8 +81,6 @@ include("fitting/methods.jl") include("simulate.jl") include("fitting/goodness.jl") -include("plotting-recipes.jl") - # include xspec models include("xspec-models/additive.jl") include("xspec-models/multiplicative.jl") @@ -94,11 +92,17 @@ include("julia-models/additive.jl") include("julia-models/multiplicative.jl") include("julia-models/convolutional.jl") +include("plots-recipes.jl") + function __init__() # check if we have the minimum model data already _check_model_directory_present() # init HEASOFT - ccall((:FNINIT, libXSFunctions), Cvoid, ()) + if get(ENV, "SPECTRAL_FITTING_XSPEC_INIT", "") == "" + ccall((:FNINIT, libXSFunctions), Cvoid, ()) + # set an environment variable so we don't accidentally init again + ENV["SPECTRAL_FITTING_XSPEC_INIT"] = "true" + end end end # module diff --git a/src/datasets/spectraldata.jl b/src/datasets/spectraldata.jl index 0eac96bf..c41dfdc0 100644 --- a/src/datasets/spectraldata.jl +++ b/src/datasets/spectraldata.jl @@ -142,6 +142,19 @@ function SpectralData( return data end +function Base.copy(data::SpectralData) + SpectralData( + data.spectrum, + data.response, + data.background, + data.ancillary, + copy(data.energy_low), + copy(data.energy_high), + copy(data.domain), + copy(data.data_mask), + ) +end + supports(::Type{<:SpectralData}) = (ContiguouslyBinned(),) function _objective_to_units(dataset::SpectralData, obj, units) @@ -346,7 +359,7 @@ function _adjust_by_unit_difference!( end function adjust_to_units!(data::SpectralData, s::Spectrum, x, units) - ΔE = bin_widths(data) + ΔE = unmasked_bin_widths(data) exposure_time = s.exposure_time _adjust_by_unit_difference!(ΔE, exposure_time, x, units / s.units) x @@ -454,6 +467,47 @@ function match_domains!(data::SpectralData) ) end +function background_dataset(data::SpectralData) + new_data = copy(data) + new_data.spectrum = new_data.background + new_data.background = nothing + new_data +end + +function rescale!(data::SpectralData) + @. data.spectrum.data = data.spectrum.data / data.spectrum.area_scale + @. data.spectrum.errors = data.spectrum.errors / data.spectrum.area_scale + data.spectrum.area_scale = 1 + if has_background(data) + rescale_background!(data) + end + data +end + +function rescale_background!(data::SpectralData) + if has_background(data) + data.background.data = _scaled_background( + data.background.data, + data.background.area_scale, + data.spectrum.background_scale, + data.background.background_scale, + ) + data.background.errors = _scaled_background( + data.background.errors, + data.background.area_scale, + data.spectrum.background_scale, + data.background.background_scale, + ) + + data.spectrum.background_scale = 1 + data.background.background_scale = 1 + data.background.area_scale = 1 + else + throw("No background to subtract") + end + data +end + macro _forward_SpectralData_api(args) if args.head !== :. error("Bad syntax") @@ -511,6 +565,16 @@ macro _forward_SpectralData_api(args) SpectralFitting.error_statistic(getfield(t, $(field))) SpectralFitting.set_units!(t::$(T), args...) = SpectralFitting.set_units!(getfield(t, $(field)), args...) + SpectralFitting.background_dataset(t::$(T), args...; kwargs...) = + SpectralFitting.background_dataset(getfield(t, $(field)), args...; kwargs...) + SpectralFitting.rescale_background!(t::$(T), args...; kwargs...) = + SpectralFitting.rescale_background!( + getfield(t, $(field)), + args...; + kwargs..., + ) + SpectralFitting.rescale!(t::$(T), args...; kwargs...) = + SpectralFitting.rescale!(getfield(t, $(field)), args...; kwargs...) end |> esc end @@ -609,4 +673,7 @@ export SpectralData, normalize!, subtract_background!, set_domain!, - set_units! + set_units!, + background_dataset, + rescale!, + rescale_background! diff --git a/src/datasets/spectrum.jl b/src/datasets/spectrum.jl index cfaee641..69b30136 100644 --- a/src/datasets/spectrum.jl +++ b/src/datasets/spectrum.jl @@ -150,7 +150,7 @@ end error_statistic(spec::Spectrum) = spec.error_statistics function subtract_background!(spectrum::Spectrum, background::Spectrum) - # should all already be rates + @assert spectrum.units == u"counts" # errors added in quadrature # TODO: this needs fixing to propagate errors properly data_variance = spectrum.errors .^ 2 @@ -163,6 +163,8 @@ function subtract_background!(spectrum::Spectrum, background::Spectrum) background.area_scale, spectrum.background_scale, background.background_scale, + spectrum.exposure_time, + background.exposure_time, ) @. spectrum.errors = √abs(spectrum.errors) _subtract_background!( @@ -173,12 +175,21 @@ function subtract_background!(spectrum::Spectrum, background::Spectrum) background.area_scale, spectrum.background_scale, background.background_scale, + spectrum.exposure_time, + background.exposure_time, ) spectrum end -_subtract_background!(output, spec, back, aD, aB, bD, bB) = - @. output = (spec / aD) - (bD / bB) * (back / aB) +""" +Does the background subtraction and returns units of counts. That means we have +multiplied through by a factor ``t_D`` relative to the reference equation (2.3) +in the XSPEC manual. +""" +_subtract_background!(output, spec, back, aD, aB, bD, bB, tD, tB) = + @. output = (spec / (aD)) - (tD / tB) * _scaled_background(back, aB, bD, bB) + +_scaled_background(back, aB, bD, bB) = (bD / bB) * (back / aB) export Spectrum diff --git a/src/fitting/result.jl b/src/fitting/result.jl index 49ac4d8d..3a7de68a 100644 --- a/src/fitting/result.jl +++ b/src/fitting/result.jl @@ -35,13 +35,19 @@ end get_cache(f::FittingResultSlice) = f.parent.config.cache get_model(f::FittingResultSlice) = f.parent.config.prob.model.m[f.index] get_dataset(f::FittingResultSlice) = f.parent.config.prob.data.d[f.index] +fit_statistic(f::FittingResultSlice) = fit_statistic(f.parent.config) estimated_error(r::FittingResultSlice) = r.σu estimated_params(r::FittingResultSlice) = r.u -function invoke_result(slice::FittingResultSlice, u) +function invoke_result(slice::FittingResultSlice{P}, u) where {P} @assert length(u) == length(slice.u) - _invoke_and_transform!(get_cache(slice), slice.domain, u) + cache = if P <: MultiFittingResult + get_cache(slice).caches[slice.index] + else + get_cache(slice) + end + _invoke_and_transform!(cache, slice.domain, u) end function _pretty_print(slice::FittingResultSlice) @@ -77,7 +83,7 @@ function Base.getindex(result::FittingResult, i) result.config.objective[:], result.config.variance[:], result.u[:], - result.σu[:], + isnothing(result.σu) ? nothing : result.σu[:], result.χ2, ) else @@ -208,3 +214,18 @@ function finalize( config, ) end + +function determine_layout(result::FittingResultSlice) + dataset = get_dataset(result) + with_units( + common_support(get_model(result), dataset), + preferred_units(dataset, fit_statistic(result)), + ) +end + +function residuals(result::FittingResultSlice) + y = invoke_result(result, result.u) + y_residual = @. (result.objective - y) / sqrt(result.variance) + y_residual +end +residuals(result::FittingResult; kwargs...) = residuals(result[1]; kwargs...) diff --git a/src/plots-recipes.jl b/src/plots-recipes.jl new file mode 100644 index 00000000..a636b5d8 --- /dev/null +++ b/src/plots-recipes.jl @@ -0,0 +1,271 @@ +using Printf +using RecipesBase + +plotting_domain(dataset::AbstractDataset) = SpectralFitting.spectrum_energy(dataset) +plotting_domain(dataset::InjectiveData) = dataset.domain + +@recipe function _plotting_func(dataset::InjectiveData; data_layout = OneToOne()) + seriestype --> :scatter + markersize --> 1.0 + markershape --> :none + markerstrokecolor --> :auto + yerr -> dataset.codomain_variance + xerr -> dataset.domain_variance + label --> make_label(dataset) + minorgrid --> true + dataset.domain, dataset.codomain +end + +@recipe function _plotting_func( + dataset::AbstractDataset; + data_layout = ContiguouslyBinned(), + xscale = :linear, +) + seriestype --> :scatter + markersize --> 0.5 + markershape --> :none + (rate, rateerror) = ( + make_objective(data_layout, dataset), + make_objective_variance(data_layout, dataset), + ) + _yerr = sqrt.(rateerror) + yerr --> _yerr + _xerr = SpectralFitting.bin_widths(dataset) ./ 2 + xerr --> _xerr + markerstrokecolor --> :auto + xlabel --> "Energy (keV)" + ylabel --> SpectralFitting.objective_units(dataset) + label --> SpectralFitting.make_label(dataset) + minorgrid --> true + x = plotting_domain(dataset) + + if xscale == :log10 + x = plotting_domain(dataset) + _xerr = SpectralFitting.bin_widths(dataset) ./ 2 + min_x = x[1] - _xerr[1] + max_x = x[end] + _xerr[end] + xticks --> get_tickslogscale((min_x, max_x)) + end + + I = @. !isinf(x) && !isinf(rate) + @views (x[I], rate[I]) +end + +# ratio plots +@userplot plotbackground +@recipe function _plotting_func(p::plotbackground) + data = p.args[1] + background_dataset(data) +end + + +@recipe _plotting_func(::Type{<:FittingResult}, result::FittingResult) = result[1] + +@recipe function _plotting_func(result::FittingResultSlice) + label --> Printf.@sprintf("χ2=%.2f", result.χ2) + seriestype --> :stepmid + dataset = SpectralFitting.get_dataset(result) + y = invoke_result(result, result.u) + x = plotting_domain(dataset) + x, y +end + +# ratio plots +@userplot RatioPlot +@recipe function _plotting_func( + r::RatioPlot; + datacolor = :auto, + modelcolor = :auto, + label = :auto, +) + if length(r.args) != 1 || !(typeof(r.args[1]) <: AbstractFittingResult) + error( + "Ratio plots first argument must be `AbstractDataset` and second argument of type `AbstractFittingResult`.", + ) + end + + result = r.args[1] isa FittingResult ? r.args[1][1] : r.args[1] + data = get_dataset(result) + x = plotting_domain(data) + y = invoke_result(result, result.u) + + y_ratio = @. result.objective / y + + ylabel --> "Ratio [data / model]" + xlabel --> "Energy (keV)" + minorgrid --> true + + if (label == :auto) + label = make_label(data) + end + + @series begin + linestyle --> :dash + seriestype --> :hline + label --> false + color --> modelcolor + [1.0] + end + + @series begin + markerstrokecolor --> datacolor + label --> label + seriestype --> :scatter + markershape --> :none + markersize --> 0.5 + yerror --> sqrt.(result.variance) ./ y + xerror --> SpectralFitting.bin_widths(data) ./ 2 + x, y_ratio + end +end + +# residual plots +# TODO: multiple datasets require repeated calls to this function (write a wrapper later) +@userplot ResidualPlot +@recipe function _plotting_fun(r::ResidualPlot) + # check that the function has been passed one dataset and one fit result + if length(r.args) != 1 || !(typeof(r.args[1]) <: AbstractFittingResult) + error( + "Ratio plots first argument must be `AbstractDataset` and second argument of type `AbstractFittingResult`.", + ) + end + result = r.args[1] isa FittingResult ? r.args[1][1] : r.args[1] + data = SpectralFitting.get_dataset(result) + + @series begin + linestyle --> :dash + seriestype --> :hline + label --> false + [0] + end + + seriestype --> :stepmid + fill --> (0, 0.3, :auto) + y_residuals = residuals(result) + x = plotting_domain(data) + (x, y_residuals) +end + +@userplot PlotResult +@recipe function _plotting_fun(r::PlotResult; xscale = :identity) + if length(r.args) != 2 || + !(typeof(r.args[1]) <: AbstractDataset) || + !( + (typeof(r.args[2]) <: AbstractFittingResult) || + (eltype(r.args[2]) <: AbstractFittingResult) + ) + error( + "First argument must be `AbstractDataset` and second argument of (el)type `AbstractFittingResult` (got $(typeof(r.args[1])) and $(typeof(r.args[2])))", + ) + end + layout --> @layout [ + top{0.75h} + bottom{0.25h} + ] + + data = r.args[1] + results = r.args[2] isa Base.AbstractVecOrTuple ? r.args[2] : (r.args[2],) + + if xscale == :log10 + _x = plotting_domain(data) + _xerr = SpectralFitting.bin_widths(data) ./ 2 + min_x = _x[1] - _xerr[1] + max_x = _x[end] + _xerr[end] + xticks --> get_tickslogscale((min_x, max_x)) + end + + ylabel --> SpectralFitting.objective_units(data) + @series begin + subplot := 1 + xlabel := "" + data + end + + for (i, res) in enumerate(results) + color := i + 1 + r = res isa FittingResultSlice ? res : res[1] + x = plotting_domain(SpectralFitting.get_dataset(r)) + @series begin + subplot := 1 + r + end + @series begin + xlabel --> "Energy (keV)" + subplot := 2 + link := :x + seriestype --> :stepmid + yscale := :identity + ylabel := "Residuals" + ylims := :auto + label := false + fill --> (0, 0.3, :auto) + y_residuals = residuals(r) + (x, y_residuals) + end + end +end + +""" + get_tickslogscale(lims; skiplog=false) + +Return a tuple (ticks, ticklabels) for the axis limit `lims` +where multiples of 10 are major ticks with label and minor ticks have no label +skiplog argument should be set to true if `lims` is already in log scale. + +Modified from [https://github.com/JuliaPlots/Plots.jl/issues/3318](Plots.jl/#3318). +""" +function get_tickslogscale(lims::Tuple{T,T}; skiplog::Bool = false) where {T<:AbstractFloat} + mags = if skiplog + # if the limits are already in log scale + floor.(lims) + else + floor.(log10.(lims)) + end + rlims = if skiplog + 10 .^ (lims) + else + lims + end + + total_tickvalues = [] + total_ticknames = [] + + rgs = range(mags..., step = 1) + for (i, m) in enumerate(rgs) + if m >= 0 + tickvalues = range(Int(10^m), Int(10^(m + 1)); step = Int(10^m)) + ticknames = vcat( + [string(round(Int, 10^(m)))], + ["" for i = 2:9], + [string(round(Int, 10^(m + 1)))], + ) + else + tickvalues = range(10^m, 10^(m + 1); step = 10^m) + ticknames = vcat([string(10^(m))], ["" for i = 2:9], [string(10^(m + 1))]) + end + + if i == 1 + # lower bound + indexlb = findlast(x -> x < rlims[1], tickvalues) + if isnothing(indexlb) + indexlb = 1 + end + else + indexlb = 1 + end + if i == length(rgs) + # higher bound + indexhb = findfirst(x -> x > rlims[2], tickvalues) + if isnothing(indexhb) + indexhb = 10 + end + else + # do not take the last index if not the last magnitude + indexhb = 9 + end + + total_tickvalues = vcat(total_tickvalues, tickvalues[indexlb:indexhb]) + total_ticknames = vcat(total_ticknames, ticknames[indexlb:indexhb]) + end + return (total_tickvalues, total_ticknames) +end diff --git a/src/plotting-recipes.jl b/src/plotting-recipes.jl deleted file mode 100644 index 0bf7a306..00000000 --- a/src/plotting-recipes.jl +++ /dev/null @@ -1,241 +0,0 @@ -using RecipesBase - -plotting_domain(dataset::AbstractDataset) = spectrum_energy(dataset) -plotting_domain(dataset::InjectiveData) = dataset.domain - -@recipe function _plotting_func(dataset::InjectiveData; data_layout = OneToOne()) - seriestype --> :scatter - markersize --> 1.0 - markershape --> :none - markerstrokecolor --> :auto - yerr -> dataset.codomain_variance - xerr -> dataset.domain_variance - label --> make_label(dataset) - minorgrid --> true - dataset.domain, dataset.codomain -end - -@recipe function _plotting_func( - dataset::AbstractDataset; - data_layout = ContiguouslyBinned(), -) - seriestype --> :scatter - markersize --> 0.5 - markershape --> :none - (rate, rateerror) = ( - make_objective(data_layout, dataset), - make_objective_variance(data_layout, dataset), - ) - _yerr = sqrt.(rateerror) - yerr --> _yerr - xerr --> bin_widths(dataset) ./ 2 - markerstrokecolor --> :auto - if all(>(0), rate) - yticks --> ([0.01, 0.1, 1, 10, 100], [0.01, 0.1, 1, 10, 100]) - yscale --> :log10 - end - if all(>(0), rate) - xticks --> ([1e-1, 1, 2, 5, 10, 20, 50, 100], [1e-1, 1, 2, 5, 10, 20, 50, 100]) - xscale --> :log10 - end - xlabel --> "Energy (keV)" - ylabel --> objective_units(dataset) - label --> make_label(dataset) - minorgrid --> true - x = plotting_domain(dataset) - - I = @. !isinf(x) && !isinf(rate) - @views (x[I], rate[I]) -end - -@recipe function _plotting_func(dataset::AbstractDataset, result::FittingResult) - label --> "fit" - seriestype --> :stepmid - y = _f_objective(result.config)(result.config.model_domain, result.u) - x = plotting_domain(dataset) - if length(y) != length(x) - error( - "Domain mismatch. Are you sure you're plotting the result with the right dataset?", - ) - end - x, y -end - -@recipe function _plotting_func(dataset::AbstractDataset, result::FittingResultSlice) - label --> "fit" - seriestype --> :stepmid - y = invoke_result(result, result.u) - x = plotting_domain(dataset) - x, y -end - -# ratio plots -@userplot RatioPlot -@recipe function _plotting_func( - r::RatioPlot; - datacolor = :auto, - modelcolor = :auto, - label = :auto, -) - if length(r.args) != 2 || - !(typeof(r.args[1]) <: AbstractDataset) || - !(typeof(r.args[2]) <: AbstractFittingResult) - error( - "Ratio plots first argument must be `AbstractDataset` and second argument of type `AbstractFittingResult`.", - ) - end - - data = r.args[1] - x = plotting_domain(data) - result = r.args[2] isa FittingResult ? r.args[2][1] : r.args[2] - y = invoke_result(result, result.u) - - y_ratio = @. result.objective / y - - ylabel --> "Ratio [data / model]" - xlabel --> "Energy (keV)" - minorgrid --> true - - if (label == :auto) - label = make_label(data) - end - - @series begin - linestyle --> :dash - seriestype --> :hline - label --> false - color --> modelcolor - [1.0] - end - - @series begin - markerstrokecolor --> datacolor - label --> label - seriestype --> :scatter - markershape --> :none - markersize --> 0.5 - yerror --> sqrt.(result.variance) ./ y - xerror --> bin_widths(data) ./ 2 - x, y_ratio - end -end - -# residual plots -# note: multiple datasets require repeated calls to this function (write a wrapper later) -@userplot ResidualPlot -@recipe function _plotting_fun( - r::ResidualPlot, - datacolor = :auto, - modelcolor = :auto, - residualcolor = :auto, - label = :auto, -) - # check that the function has been passed one dataset and one fit result - if length(r.args) != 2 || - !(typeof(r.args[1]) <: AbstractDataset) || - !(typeof(r.args[2]) <: AbstractFittingResult) - error( - "Ratio plots first argument must be `AbstractDataset` and second argument of type `AbstractFittingResult`.", - ) - end - - data = r.args[1] - x = plotting_domain(data) - # at the moment I don't understand why the following line is necessary - # I would assume result = r.args[2] which might be of type FittingResultSlice - result = r.args[2] isa FittingResult ? r.args[2][1] : r.args[2] - y = invoke_result(result, result.u) - - # residual is the difference between the model and the data in units of "sigma" so the error bars have size 1 - # this assumes we have statistics such that sigma = sqrt(variance) - should probably make this more statistically neutral - yerr = sqrt.(result.variance) - y_residual = @. (result.objective - y) / yerr - # is this the best way to ensure y_residual_error has the same type as y_residual, or should it just be fixed at Float64? - y_residual_error = ones(eltype(y_residual), length(y_residual)) - - minorgrid --> true - - if (label == :auto) - label = make_label(data) - end - - # layout --> @layout [grid(2, 1, heights=[0.7 ,0.3]), margin=0mm] - layout --> @layout [ - top{0.75h} - bottom{0.25h} - ] - margins --> (0, :mm) - - # logarithmic x-axis (might want to let this be an option) - xscale --> :log10 - filtered_array = filter(x -> x != 0, result.objective) - min_data_value = 0.8 * minimum(filtered_array) - max_data_value = 1.2 * maximum(filtered_array) - filtered_array = filter(x -> x != 0, y) - min_model_value = 0.8 * minimum(filtered_array) - max_model_value = 1.2 * maximum(filtered_array) - min_non_zero_value = min(min_data_value, min_model_value) - max_non_zero_value = max(max_data_value, max_model_value) - - # plot the data - @series begin - subplot --> 1 - yscale --> :log10 - yrange --> [min_non_zero_value, max_non_zero_value] - markerstrokecolor --> datacolor - label --> label - seriestype --> :scatter - markershape --> :none - markersize --> 0.5 - yerror --> yerr - xerror --> bin_widths(data) ./ 2 - x, result.objective - end - - # plot the model - need to fix this - @series begin - subplot --> 1 - yscale --> :log10 - yrange --> [min_non_zero_value, max_non_zero_value] - xticks --> nothing - ylabel --> "Flux (units)" - markerstrokecolor --> modelcolor - label --> :none - seriestype --> :stepmid - markershape --> :none - markersize --> 0.5 - xerror --> bin_widths(data) ./ 2 - x, y - end - - # plot the residuals - @series begin - subplot --> 2 - yscale --> :identity - xticks --> true - xlabel --> "Energy (keV)" - ylabel --> "(Data - model)/error" - markerstrokecolor --> residualcolor - label --> :none - seriestype --> :scatter - markershape --> :none - markersize --> 0.5 - yerror --> y_residual_error - xerror --> bin_widths(data) ./ 2 - x, y_residual - end - - # zero line - @series begin - subplot --> 2 - yscale --> :identity - xticks --> true - xlabel --> "Energy (keV)" - ylabel --> "(Data - model)/error" - linestyle --> :dash - seriestype --> :hline - label --> false - # not currently specifying a colour - [0.0] - end -end diff --git a/test/fitting/test-sample-data.jl b/test/fitting/test-sample-data.jl index 6e845ad6..3831cd19 100644 --- a/test/fitting/test-sample-data.jl +++ b/test/fitting/test-sample-data.jl @@ -64,7 +64,9 @@ result = fit(prob, LevenbergMarquadt()) # todo: with background subtraction data1_nobkg = deepcopy(data1) +set_units!(data1_nobkg, u"counts") subtract_background!(data1_nobkg) +set_units!(data1_nobkg, u"counts / (s * keV)") prob = FittingProblem(model, data1_nobkg) result = fit(prob, LevenbergMarquadt())