Skip to content

Commit

Permalink
Merge pull request #91 from phajy/residual-plots
Browse files Browse the repository at this point in the history
Residual plots
  • Loading branch information
fjebaker authored May 9, 2024
2 parents 11bfeb5 + 20bbd54 commit 8620a6b
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 2 deletions.
122 changes: 122 additions & 0 deletions src/plotting-recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,125 @@ end
x, y_ratio
end
end

# residual plots
# note: multiple datasets require repeated calls to this function (write a wrapper later)
@userplot ResidualPlot
@recipe function _plotting_fun(
r::ResidualPlot,
datacolor = :auto,
modelcolor = :auto,
residualcolor = :auto,
label = :auto,
)
# check that the function has been passed one dataset and one fit result
if length(r.args) != 2 ||
!(typeof(r.args[1]) <: AbstractDataset) ||
!(typeof(r.args[2]) <: AbstractFittingResult)
error(
"Ratio plots first argument must be `AbstractDataset` and second argument of type `AbstractFittingResult`.",
)
end

println("Debug: Creating a residual plot")

data = r.args[1]
x = plotting_domain(data)
# at the moment I don't understand why the following line is necessary
# I would assume result = r.args[2] which might be of type FittingResultSlice
result = r.args[2] isa FittingResult ? r.args[2][1] : r.args[2]
y = invoke_result(result, result.u)

# residual is the difference between the model and the data in units of "sigma" so the error bars have size 1
# this assumes we have statistics such that sigma = sqrt(variance) - should probably make this more statistically neutral
yerr = sqrt.(result.variance)
y_residual = @. (result.objective - y) / yerr
# is this the best way to ensure y_residual_error has the same type as y_residual, or should it just be fixed at Float64?
y_residual_error = ones(eltype(y_residual), length(y_residual))

minorgrid --> true

if (label == :auto)
label = make_label(data)
end

# layout --> @layout [grid(2, 1, heights=[0.7 ,0.3]), margin=0mm]
layout --> @layout [
top{0.75h}
bottom{0.25h}
]
margins --> (0, :mm)

# logarithmic x-axis (might want to let this be an option)
xscale --> :log10
filtered_array = filter(x -> x != 0, result.objective)
min_data_value = 0.8 * minimum(filtered_array)
max_data_value = 1.2 * maximum(filtered_array)
filtered_array = filter(x -> x != 0, y)
min_model_value = 0.8 * minimum(filtered_array)
max_model_value = 1.2 * maximum(filtered_array)
min_non_zero_value = min(min_data_value, min_model_value)
max_non_zero_value = max(max_data_value, max_model_value)

# plot the data
@series begin
subplot --> 1
yscale --> :log10
yrange --> [min_non_zero_value, max_non_zero_value]
markerstrokecolor --> datacolor
label --> label
seriestype --> :scatter
markershape --> :none
markersize --> 0.5
yerror --> yerr
xerror --> bin_widths(data) ./ 2
x, result.objective
end

# plot the model - need to fix this
@series begin
subplot --> 1
yscale --> :log10
yrange --> [min_non_zero_value, max_non_zero_value]
xticks --> nothing
ylabel --> "Flux (units)"
markerstrokecolor --> modelcolor
label --> :none
seriestype --> :stepmid
markershape --> :none
markersize --> 0.5
xerror --> bin_widths(data) ./ 2
x, y
end

# plot the residuals
@series begin
subplot --> 2
yscale --> :identity
xticks --> true
xlabel --> "Energy (keV)"
ylabel --> "(Data - model)/error"
markerstrokecolor --> residualcolor
label --> :none
seriestype --> :scatter
markershape --> :none
markersize --> 0.5
yerror --> y_residual_error
xerror --> bin_widths(data) ./ 2
x, y_residual
end

# zero line
@series begin
subplot --> 2
yscale --> :identity
xticks --> true
xlabel --> "Energy (keV)"
ylabel --> "(Data - model)/error"
linestyle --> :dash
seriestype --> :hline
label --> false
# not currently specifying a colour
[0.0]
end
end
2 changes: 1 addition & 1 deletion src/print-utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function prettyfloat(f)
"0.0"
elseif f == Inf
"Inf"
elseif ((f 1) && (f < 1e5) && (f - trunc(Int, f) < 1e-5))
elseif ((f 1) && (f < 1e5) && (f - trunc(Int, f) < 1e-5))
Printf.@sprintf("%.1f", f)
else
Printf.@sprintf("%#.5g", f)
Expand Down
7 changes: 6 additions & 1 deletion src/xspec-models/additive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ Flux │ '.. │
"Redshift."
z::T
end
function XS_CutOffPowerLaw(; K = FitParam(1.0), Γ = FitParam(2.0), Ecut = FitParam(15.0), z = FitParam(0.0, frozen=true))
function XS_CutOffPowerLaw(;
K = FitParam(1.0),
Γ = FitParam(2.0),
Ecut = FitParam(15.0),
z = FitParam(0.0, frozen = true),
)
XS_CutOffPowerLaw(K, Γ, Ecut, z)
end

Expand Down

0 comments on commit 8620a6b

Please sign in to comment.