Skip to content

Commit

Permalink
Cross talk correction code for build_evt() (#572)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: ggmarshall <[email protected]>
Co-authored-by: Luigi Pertoldi <[email protected]>
  • Loading branch information
3 people authored May 8, 2024
1 parent d8648ce commit d53c613
Show file tree
Hide file tree
Showing 12 changed files with 699 additions and 40 deletions.
3 changes: 3 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ authors:
- family-names: Marshall
given-names: George
orcid: https://orcid.org/0000-0002-5470-5132
- family-names: Dixon
given-names: Toby
orcid: https://orcid.org/0000-0001-8787-6336
- family-names: D'Andrea
given-names: Valerio
orcid: https://orcid.org/0000-0003-2037-4133
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"iminuit",
"legend-daq2lh5>=1.2.1",
"legend-pydataobj>=1.6",
"pylegendmeta>=0.9",
"matplotlib",
"numba!=0.53.*,!=0.54.*,!=0.57",
"numpy>=1.21",
Expand Down
51 changes: 30 additions & 21 deletions src/pygama/evt/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def evaluate_to_scalar(
def evaluate_at_channel(
datainfo,
tcm,
channels,
channels_skip,
expr,
field_list,
Expand All @@ -253,6 +254,8 @@ def evaluate_at_channel(
input and output LH5 datainfo with HDF5 groups where tables are found.
tcm
TCM data arrays in an object that can be accessed by attribute.
channels
list of channels to be included for evaluation.
channels_skip
list of channels to be skipped from evaluation and set to default value.
expr
Expand Down Expand Up @@ -281,7 +284,7 @@ def evaluate_at_channel(
evt_ids_ch = np.searchsorted(
tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
)
if table_name not in channels_skip:
if (table_name in channels) and (table_name not in channels_skip):
res = utils.get_data_at_channel(
datainfo=datainfo,
ch=table_name,
Expand All @@ -307,6 +310,7 @@ def evaluate_at_channel_vov(
expr,
field_list,
ch_comp,
channels,
channels_skip,
pars_dict=None,
default_value=np.nan,
Expand All @@ -326,6 +330,8 @@ def evaluate_at_channel_vov(
list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``.
ch_comp
array of "rawid"s at which the expression is evaluated.
channels
list of channels to be included for evaluation.
channels_skip
list of channels to be skipped from evaluation and set to default value.
pars_dict
Expand All @@ -335,20 +341,19 @@ def evaluate_at_channel_vov(
"""
f = utils.make_files_config(datainfo)

# blow up vov to aoesa
out = ak.Array([[] for _ in range(len(ch_comp))])
ch_comp_channels = np.unique(ch_comp.flattened_data.nda).astype(int)

channels = np.unique(ch_comp.flattened_data.nda).astype(int)
ch_comp = ch_comp.view_as("ak")
out = np.full(
len(ch_comp.flattened_data.nda), default_value, dtype=type(default_value)
)

type_name = None
for ch in channels:
for ch in ch_comp_channels:
table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, ch)

evt_ids_ch = np.searchsorted(
tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
)
if table_name not in channels_skip:
if (table_name in channels) and (table_name not in channels_skip):
res = utils.get_data_at_channel(
datainfo=datainfo,
ch=table_name,
Expand All @@ -357,23 +362,27 @@ def evaluate_at_channel_vov(
field_list=field_list,
pars_dict=pars_dict,
)
else:
idx_ch = tcm.idx[tcm.id == ch]
res = np.full(len(idx_ch), default_value)

# see in which events the current channel is present
mask = ak.to_numpy(ak.any(ch_comp == ch, axis=-1), allow_missing=False)
cv = np.full(len(ch_comp), np.nan)
cv[evt_ids_ch] = res
cv[~mask] = np.nan
cv = ak.drop_none(ak.nan_to_none(ak.Array(cv)[:, None]))
new_evt_ids_ch = np.searchsorted(
ch_comp.cumulative_length,
np.where(ch_comp.flattened_data.nda == ch)[0],
"right",
)
matches = np.isin(evt_ids_ch, new_evt_ids_ch)
out[ch_comp.flattened_data.nda == ch] = res[matches]

out = ak.concatenate((out, cv), axis=-1)
else:
length = len(np.where(ch_comp.flattened_data.nda == ch)[0])
res = np.full(length, default_value)
out[ch_comp.flattened_data.nda == ch] = res

if ch == channels[0]:
if ch == ch_comp_channels[0]:
out = out.astype(res.dtype)
type_name = res.dtype

return types.VectorOfVectors(ak.values_astype(out, type_name))
return types.VectorOfVectors(
flattened_data=types.Array(out, dtype=type_name),
cumulative_length=ch_comp.cumulative_length,
)


def evaluate_to_aoesa(
Expand Down
2 changes: 2 additions & 0 deletions src/pygama/evt/build_evt.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def evaluate_expression(
return aggregators.evaluate_at_channel(
datainfo=datainfo,
tcm=tcm,
channels=channels,
channels_skip=channels_skip,
expr=expr,
field_list=field_list,
Expand All @@ -512,6 +513,7 @@ def evaluate_expression(
expr=expr,
field_list=field_list,
ch_comp=ch_comp,
channels=channels,
channels_skip=channels_skip,
pars_dict=pars_dict,
default_value=default_value,
Expand Down
180 changes: 163 additions & 17 deletions src/pygama/evt/modules/geds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

from collections.abc import Sequence

import awkward as ak
import numpy as np
from lgdo import lh5, types

from .. import utils
from . import xtalk


def apply_recovery_cut(
Expand All @@ -26,7 +28,7 @@ def apply_recovery_cut(
is_recovering = is_recovering | np.where(
(
((timestamps.nda - tstamp) < time_window)
& ((timestamps.nda - tstamp) > 0)
& ((timestamps.nda - tstamp) >= 0)
),
True,
False,
Expand All @@ -41,33 +43,177 @@ def apply_xtalk_correction(
tcm: utils.TCMData,
table_names: Sequence[str],
*,
energy_observable: types.VectorOfVectors,
rawids: types.VectorOfVectors,
return_mode: str,
uncal_energy_expr: str,
cal_energy_expr: str,
multiplicity_expr: str,
xtalk_threshold: float = None,
xtalk_matrix_filename: str = "",
xtalk_rawid_obj: str = "xtc/rawid_index",
xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative",
positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive",
) -> types.VectorOfVectors:
"""Applies the cross-talk correction to the energy observable.
The format of `xtalk_matrix_filename` should be currently be a path to a lh5 file.
The correction is applied using matrix algebra for all triggers above the threshold.
Parameters
----------
datainfo, tcm, table_names
positional arguments automatically supplied by :func:`.build_evt`.
return_mode
string which can be either energy to return corrected energy or tcm_index
uncal_energy_expr
expression for the pulse parameter to be gathered for the uncalibrated energy (used for correction),
can be a combination of different fields.
cal_energy_expr
expression for the pulse parameter to be gathered for the calibrated energy, used for the xtalk threshold,
can be a combination of different fields.
xtalk_threshold
threshold used for xtalk correction, hits below this energy will not
be used to correct the other hits.
xtalk_matrix_filename
name of the file containing the xtalk matrices.
xtalk_matrix_obj
name of the lh5 object containing the xtalk matrix
positive_xtalk_matrix_obj
name of the lh5 object containing the positive polarity xtalk matrix
xtalk_rawids_obj
name of the lh5 object containing the name of the rawids
"""

xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_obj, xtalk_matrix_filename, "np")
tcm_index_array = xtalk.build_tcm_index_array(tcm, datainfo, xtalk_matrix_rawids)

energy_corr = xtalk.get_xtalk_correction(
tcm,
datainfo,
uncal_energy_expr,
cal_energy_expr,
xtalk_threshold,
xtalk_matrix_filename,
xtalk_rawid_obj,
xtalk_matrix_obj,
positive_xtalk_matrix_obj,
)

multiplicity_mask = xtalk.filter_hits(
datainfo,
tcm,
multiplicity_expr,
energy_corr,
xtalk_matrix_rawids,
)
energy_corr = ak.from_regular(energy_corr)
multiplicity_mask = ak.from_regular(multiplicity_mask)
tcm_index_array = ak.from_regular(tcm_index_array)

if return_mode == "energy":
return types.VectorOfVectors(energy_corr[multiplicity_mask])
elif return_mode == "tcm_index":
return types.VectorOfVectors(tcm_index_array[multiplicity_mask])
else:
raise ValueError(f"Unknown mode: {return_mode}")


def apply_xtalk_correction_and_calibrate(
datainfo: utils.DataInfo,
tcm: utils.TCMData,
table_names: Sequence[str],
*,
return_mode: str,
uncal_energy_expr: str,
cal_energy_expr: str,
cal_par_files: str | Sequence[str],
multiplicity_expr: str,
xtalk_matrix_filename: str,
xtalk_threshold: float = None,
xtalk_rawid_obj: str = "xtc/rawid_index",
xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative",
positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive",
uncal_var: str = "dsp.cuspEmax",
recal_var: str = "hit.cuspEmax_ctc_cal",
) -> types.VectorOfVectors:
"""Applies the cross-talk correction to the energy observable.
The format of `xtalk_matrix_filename` should be...
The correction is applied using matrix algebra for all triggers above the
xalk threshold.
Parameters
----------
datainfo, tcm, table_names
positional arguments automatically supplied by :func:`.build_evt`.
energy_observable
array of energy values to correct, one event per row. The detector
identifier is stored in `rawids`, which has the same layout.
rawids
array of detector identifiers for each energy in `energy_observable`.
return_mode
string which can be either ``energy`` to return corrected energy or
``tcm_index``.
uncal_energy_expr
expression for the pulse parameter to be gathered for the uncalibrated
energy (used for correction), can be a combination of different fields.
cal_energy_expr
expression for the pulse parameter to be gathered for the calibrated
energy, used for the xtalk threshold, can be a combination of different
fields.
cal_par_files
path to the generated hit tier par-files defining the calibration
curves. Used to recalibrate the data after xtalk correction.
multiplicity_expr:
expression defining the logic used to compute the event multiplicity.
xtalk_threshold
threshold used for xtalk correction, hits below this energy will not be
used to correct the other hits.
xtalk_matrix_filename
name of the file containing the cross-talk matrices.
path to the file containing the xtalk matrices.
xtalk_matrix_obj
name of the lh5 object containing the xtalk matrix.
positive_xtalk_matrix_obj
name of the lh5 object containing the positive polarity xtalk matrix.
xtalk_matrix_rawids
name of the lh5 object containing the name of the rawids.
recal_var
name of the energy variable to use for re-calibration.
"""
# read in xtalk matrices
lh5.read_as("", xtalk_matrix_filename, "ak")

# do the correction
energies_corr = ...
xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_obj, xtalk_matrix_filename, "np")
tcm_index_array = xtalk.build_tcm_index_array(tcm, datainfo, xtalk_matrix_rawids)

# return the result as LGDO
return types.VectorOfVectors(
energies_corr, attrs=utils.copy_lgdo_attrs(energy_observable)
energy_corr = xtalk.get_xtalk_correction(
tcm,
datainfo,
uncal_energy_expr,
cal_energy_expr,
xtalk_threshold,
xtalk_matrix_filename,
xtalk_rawid_obj,
xtalk_matrix_obj,
positive_xtalk_matrix_obj,
)

calibrated_corr = xtalk.calibrate_energy(
datainfo,
tcm,
energy_corr,
xtalk_matrix_rawids,
cal_par_files,
uncal_var,
recal_var,
)

multiplicity_mask = xtalk.filter_hits(
datainfo,
tcm,
multiplicity_expr,
calibrated_corr,
xtalk_matrix_rawids,
)

calibrated_corr = ak.from_regular(calibrated_corr)
multiplicity_mask = ak.from_regular(multiplicity_mask)
tcm_index_array = ak.from_regular(tcm_index_array)

if return_mode == "energy":
return types.VectorOfVectors(calibrated_corr[multiplicity_mask])
elif return_mode == "tcm_index":
return types.VectorOfVectors(tcm_index_array[multiplicity_mask])
else:
raise ValueError(f"Unknown mode: {return_mode}")
Loading

0 comments on commit d53c613

Please sign in to comment.