Skip to content

Commit

Permalink
Refactor lepton selection code to store channel IDs for each event
Browse files Browse the repository at this point in the history
  • Loading branch information
haddadanas committed Oct 17, 2024
1 parent 50e8c64 commit c6a8207
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion hbt/selection/lepton.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def tau_selection_init(self: Selector) -> None:
produces={
electron_selection, muon_selection, tau_selection,
# new columns
"channel_id", "leptons_os", "tau2_isolated", "single_triggered", "cross_triggered",
"channel_ids", "channel_id", "leptons_os", "tau2_isolated", "single_triggered", "cross_triggered",
},
)
def lepton_selection(
Expand All @@ -397,6 +397,7 @@ def lepton_selection(
# prepare vectors for output vectors
false_mask = (abs(events.event) < 0)
channel_id = np.uint8(1) * false_mask
channel_ids = []
tau2_isolated = false_mask
leptons_os = false_mask
single_triggered = false_mask
Expand Down Expand Up @@ -456,6 +457,7 @@ def lepton_selection(
# store global variables
where = (channel_id == 0) & is_etau
channel_id = ak.where(where, ch_etau.id, channel_id)
channel_ids.append(ak.where(is_etau, [[ch_etau.id]], [[]]))
tau2_isolated = ak.where(where, is_iso, tau2_isolated)
leptons_os = ak.where(where, is_os, leptons_os)
single_triggered = ak.where(where & is_single, True, single_triggered)
Expand All @@ -480,6 +482,7 @@ def lepton_selection(
# store global variables
where = (channel_id == 0) & is_mutau
channel_id = ak.where(where, ch_mutau.id, channel_id)
channel_ids.append(ak.where(is_mutau, [[ch_mutau.id]], [[]]))
tau2_isolated = ak.where(where, is_iso, tau2_isolated)
leptons_os = ak.where(where, is_os, leptons_os)
single_triggered = ak.where(where & is_single, True, single_triggered)
Expand Down Expand Up @@ -513,20 +516,24 @@ def lepton_selection(
# store global variables
where = (channel_id == 0) & is_tautau
channel_id = ak.where(where, ch_tautau.id, channel_id)
channel_ids.append(ak.where(is_tautau, [[ch_tautau.id]], [[]]))
tau2_isolated = ak.where(where, is_iso, tau2_isolated)
leptons_os = ak.where(where, is_os, leptons_os)
single_triggered = ak.where(where & is_single, True, single_triggered)
cross_triggered = ak.where(where & is_cross, True, cross_triggered)
sel_tau_indices = ak.where(where, tau_indices, sel_tau_indices)

# some final type conversions
channel_ids = ak.concatenate(channel_ids, axis=1)
channel_ids = ak.values_astype(channel_ids, np.uint8)
channel_id = ak.values_astype(channel_id, np.uint8)
leptons_os = ak.fill_none(leptons_os, False)
sel_electron_indices = ak.values_astype(sel_electron_indices, np.int32)
sel_muon_indices = ak.values_astype(sel_muon_indices, np.int32)
sel_tau_indices = ak.values_astype(sel_tau_indices, np.int32)

# save new columns
events = set_ak_column(events, "channel_ids", channel_ids)
events = set_ak_column(events, "channel_id", channel_id)
events = set_ak_column(events, "leptons_os", leptons_os)
events = set_ak_column(events, "tau2_isolated", tau2_isolated)
Expand Down

0 comments on commit c6a8207

Please sign in to comment.