diff --git a/src/pygama/evt/modules/larveto.py b/src/pygama/evt/modules/larveto.py index 28b827c83..f26ae3b8e 100644 --- a/src/pygama/evt/modules/larveto.py +++ b/src/pygama/evt/modules/larveto.py @@ -70,8 +70,16 @@ def l200_test_stat(relative_t0, amp, ts_bkg_prob, rc_density): n_pe_tot = np.where(n_pe_tot == 0, np.nan, n_pe_tot) # calculate the test statistic term related to the time distribution - transform_function = transform_wrapper(ts_bkg_prob) - ts_time = -ak.sum(ak.transform(transform_function, relative_t0, amp), axis=-1) + ts_time = -ak.sum( + ak.transform( + lambda layouts, **kwargs: _ak_l200_test_stat_time_term( + layouts, ts_bkg_prob, **kwargs + ), + relative_t0, + amp, + ), + axis=-1, + ) # calculate the amplitude contribution ts_amp = [l200_rc_amp_logpdf(n, rc_density) for n in n_pe_tot] @@ -81,14 +89,6 @@ def l200_test_stat(relative_t0, amp, ts_bkg_prob, rc_density): return t_stat -# need this to pass ts_bkg_prob parameter with ak.transform() -def transform_wrapper(ts_bkg_prob): - def _transform(layouts, **kwargs): - return _ak_l200_test_stat_time_term(layouts, ts_bkg_prob=ts_bkg_prob) - - return _transform - - # need to define this function and use it with ak.transform() because scipy # routines are not NumPy universal functions def _ak_l200_test_stat_time_term(layouts, ts_bkg_prob, **kwargs):