Skip to content

Commit

Permalink
first version aoe sf joint fitter
Browse files Browse the repository at this point in the history
  • Loading branch information
ggmarshall committed Oct 9, 2024
1 parent 7ad4631 commit 2087fde
Showing 1 changed file with 191 additions and 53 deletions.
244 changes: 191 additions & 53 deletions src/pygama/pargen/AoE_cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import copy
import logging
import re
from datetime import datetime
Expand Down Expand Up @@ -607,24 +608,109 @@ def get_peak_label(peak: float) -> str:
return "Tl FEP @"


def pass_pdf_hpge(
x,
x_lo,
x_hi,
n_sig1,
n_sig2,
mu,
sigma,
htail,
tau,
n_bkg1,
n_bkg2,
hstep1,
hstep2,
hstep3,
):
return hpge_peak.pdf_ext(
x, x_lo, x_hi, n_sig1, mu, sigma, htail, tau, n_bkg1, hstep1
)


def fail_pdf_hpge(
x,
x_lo,
x_hi,
n_sig1,
n_sig2,
mu,
sigma,
htail,
tau,
n_bkg1,
n_bkg2,
hstep1,
hstep2,
hstep3,
):
return hpge_peak.pdf_ext(
x, x_lo, x_hi, n_sig2, mu, sigma, htail, tau, n_bkg2, hstep2
)


def tot_pdf_hpge(
x,
x_lo,
x_hi,
n_sig1,
n_sig2,
mu,
sigma,
htail,
tau,
n_bkg1,
n_bkg2,
hstep1,
hstep2,
hstep3,
):
return hpge_peak.pdf_ext(
x, x_lo, x_hi, n_sig1 + n_sig2, mu, sigma, htail, tau, n_bkg1 + n_bkg2, hstep3
)


def pass_pdf_gos(
x, x_lo, x_hi, n_sig1, n_sig2, mu, sigma, n_bkg1, n_bkg2, hstep1, hstep2, hstep3
):
return gauss_on_step.pdf_ext(x, x_lo, x_hi, n_sig1, mu, sigma, n_bkg1, hstep1)


def fail_pdf_gos(
x, x_lo, x_hi, n_sig1, n_sig2, mu, sigma, n_bkg1, n_bkg2, hstep1, hstep2, hstep3
):
return gauss_on_step.pdf_ext(x, x_lo, x_hi, n_sig2, mu, sigma, n_bkg2, hstep2)


def tot_pdf_gos(
x, x_lo, x_hi, n_sig1, n_sig2, mu, sigma, n_bkg1, n_bkg2, hstep1, hstep2, hstep3
):
return gauss_on_step.pdf_ext(
x, x_lo, x_hi, n_sig1 + n_sig2, mu, sigma, n_bkg1 + n_bkg2, hstep3
)


def update_guess(func, parguess, energies):
if func == gauss_on_step:
if func == gauss_on_step or func == hpge_peak:

total_events = len(energies)
parguess["n_sig"] = len(
energies[
(energies > parguess["mu"] - 2 * parguess["sigma"])
& (energies < parguess["mu"] + 2 * parguess["sigma"])
]
)
parguess["n_bkg"] = total_events - parguess["n_sig"]
return parguess

if func == hpge_peak:
total_events = len(energies)
parguess["n_sig"] = len(
parguess["n_sig"] -= len(
energies[
(energies > parguess["mu"] - 2 * parguess["sigma"])
& (energies < parguess["mu"] + 2 * parguess["sigma"])
(energies > parguess["x_lo"])
& (energies < parguess["x_lo"] + 2 * parguess["sigma"])
]
)
parguess["n_sig"] -= len(
energies[
(energies > parguess["x_hi"] - 2 * parguess["sigma"])
& (energies < parguess["x_hi"])
]
)
parguess["n_bkg"] = total_events - parguess["n_sig"]
Expand All @@ -643,11 +729,11 @@ def get_survival_fraction(
eres_pars,
fit_range=None,
high_cut=None,
guess_pars_cut=None,
guess_pars_surv=None,
pars=None,
dt_mask=None,
mode="greater",
func=hpge_peak,
fix_step=False,
display=0,
):
if dt_mask is None:
Expand All @@ -672,8 +758,8 @@ def get_survival_fraction(
else:
raise ValueError("mode not recognised")

if guess_pars_cut is None or guess_pars_surv is None:
(pars, errs, cov, _, func, _, _, _) = pgc.unbinned_staged_energy_fit(
if pars is None:
(pars, _, _, _, func, _, _, _) = pgc.unbinned_staged_energy_fit(
energy,
func,
guess_func=energy_guess,
Expand All @@ -682,48 +768,104 @@ def get_survival_fraction(
fit_range=fit_range,
)

guess_pars_cut = pars
guess_pars_surv = pars
guess_pars_cut = copy.deepcopy(pars)
guess_pars_surv = copy.deepcopy(pars)

# add update guess here for n_sig and n_bkg
guess_pars_cut = update_guess(func, guess_pars_cut, energy[(~nan_idxs) & (~idxs)])
(cut_pars, cut_errs, cut_cov, _, _, _, _, _) = pgc.unbinned_staged_energy_fit(
energy[(~nan_idxs) & (~idxs)],
func,
guess=guess_pars_cut,
guess_func=energy_guess,
bounds_func=get_bounds,
fixed_func=fix_all_but_nevents,
guess_kwargs={"peak": peak, "eres": eres_pars},
lock_guess=True,
allow_tail_drop=False,
fit_range=fit_range,
)
guess_pars_surv = update_guess(func, guess_pars_cut, energy[(~nan_idxs) & (idxs)])
(surv_pars, surv_errs, surv_cov, _, _, _, _, _) = pgc.unbinned_staged_energy_fit(
energy[(~nan_idxs) & (idxs)],
func,
guess=guess_pars_surv,
guess_func=energy_guess,
bounds_func=get_bounds,
fixed_func=fix_all_but_nevents,
guess_kwargs={"peak": peak, "eres": eres_pars},
lock_guess=True,
allow_tail_drop=False,
fit_range=fit_range,
)
guess_pars_surv = update_guess(func, guess_pars_surv, energy[(~nan_idxs) & (idxs)])

parguess = {
"x_lo": pars["x_lo"],
"x_hi": pars["x_hi"],
"mu": pars["mu"],
"sigma": pars["sigma"],
"n_sig1": guess_pars_surv["n_sig"],
"n_bkg1": guess_pars_surv["n_bkg"],
"n_sig2": guess_pars_cut["n_sig"],
"n_bkg2": guess_pars_cut["n_bkg"],
"hstep1": pars["hstep"],
"hstep2": pars["hstep"],
"hstep3": pars["hstep"],
}

bounds = {
"n_sig1": (0, pars["n_sig"] + pars["n_bkg"]),
"n_sig2": (0, pars["n_sig"] + pars["n_bkg"]),
"n_bkg1": (0, pars["n_bkg"] + pars["n_sig"]),
"n_bkg2": (0, pars["n_bkg"] + pars["n_sig"]),
"hstep1": (-1, 1),
"hstep2": (-1, 1),
"hstep3": (-1, 1),
}

if func == hpge_peak:
parguess.update({"htail": pars["htail"], "tau": pars["tau"]})

if func == hpge_peak:
lh = (
cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (idxs)], pass_pdf_hpge)
+ cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (~idxs)], fail_pdf_hpge)
+ cost.ExtendedUnbinnedNLL(energy[(~nan_idxs)], tot_pdf_hpge)
)
elif func == gauss_on_step:
lh = (
cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (idxs)], pass_pdf_gos)
+ cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (~idxs)], fail_pdf_gos)
+ cost.ExtendedUnbinnedNLL(energy[(~nan_idxs)], tot_pdf_gos)
)

ct_n = cut_pars["n_sig"]
ct_err = cut_errs["n_sig"]
surv_n = surv_pars["n_sig"]
surv_err = surv_errs["n_sig"]
else:
raise ValueError("Unknown func")

m = Minuit(lh, **parguess)
fixed = ["x_lo", "x_hi", "mu", "sigma"]
if func == hpge_peak:
fixed += ["tau", "htail"]
if fix_step is True:
fixed += ["hstep1", "hstep2", "hstep3"]

m.fixed[fixed] = True
for arg, val in bounds.items():
m.limits[arg] = val

m.simplex().migrad()
m.hesse()

ct_n = m.values["n_sig2"]
ct_err = m.errors["n_sig2"]
surv_n = m.values["n_sig1"]
surv_err = m.errors["n_sig1"]

pc_n = ct_n + surv_n

sf = surv_n / pc_n
err = 100 * sf * (1 - sf) * np.sqrt((ct_err / ct_n) ** 2 + (surv_err / surv_n) ** 2)
sf *= 100

return sf, err, cut_pars, surv_pars
if display > 1:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
bins = np.arange(1552, 1612, 1)
ax1.hist(energy[(~nan_idxs) & (idxs)], bins=bins, histtype="step")

ax2.hist(energy[(~nan_idxs) & (~idxs)], bins=bins, histtype="step")

ax3.hist(energy[(~nan_idxs)], bins=bins, histtype="step")

if func == hpge_peak:
ax1.plot(bins, pass_pdf_hpge(bins, **m.values.to_dict())[1])
ax2.plot(bins, fail_pdf_hpge(bins, **m.values.to_dict())[1])

ax3.plot(bins, tot_pdf_hpge(bins, **m.values.to_dict())[1])
elif func == gauss_on_step:
ax1.plot(bins, pass_pdf_gos(bins, **m.values.to_dict())[1])
ax2.plot(bins, fail_pdf_gos(bins, **m.values.to_dict())[1])

ax3.plot(bins, tot_pdf_gos(bins, **m.values.to_dict())[1])

plt.show()

return sf, err, m.values, m.errors


def get_sf_sweep(
Expand Down Expand Up @@ -754,16 +896,14 @@ def get_sf_sweep(
cut_vals = np.linspace(cut_range[0], cut_range[1], n_samples)
out_df = pd.DataFrame()

(pars, _, _, _, func, _, _, _) = pgc.unbinned_staged_energy_fit(
(pars, errs, _, _, func, _, _, _) = pgc.unbinned_staged_energy_fit(
energy,
hpge_peak,
guess_func=energy_guess,
bounds_func=get_bounds,
guess_kwargs={"peak": peak, "eres": eres_pars},
fit_range=fit_range,
)
guess_pars_cut = pars
guess_pars_surv = pars

for cut_val in cut_vals:
try:
Expand All @@ -776,8 +916,7 @@ def get_sf_sweep(
fit_range=fit_range,
dt_mask=dt_mask,
mode=mode,
guess_pars_cut=guess_pars_cut,
guess_pars_surv=guess_pars_surv,
pars=pars,
func=func,
)
out_df = pd.concat(
Expand All @@ -790,7 +929,7 @@ def get_sf_sweep(
raise (e)
out_df.set_index("cut_val", inplace=True)
if final_cut_value is not None:
sf, sf_err, cut_pars, surv_pars = get_survival_fraction(
sf, sf_err, _, _ = get_survival_fraction(
energy,
cut_param,
final_cut_value,
Expand All @@ -799,8 +938,7 @@ def get_sf_sweep(
fit_range=fit_range,
dt_mask=dt_mask,
mode=mode,
guess_pars_cut=guess_pars_cut,
guess_pars_surv=guess_pars_surv,
pars=pars,
func=func,
)
else:
Expand Down

0 comments on commit 2087fde

Please sign in to comment.