diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 76eb5c78b..132de5c1d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -17,29 +17,28 @@ env: TQDM_MININTERVAL: 100 jobs: - build-and-test: name: Test pygama with Python runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ["3.9", "3.10", "3.11"] os: [ubuntu-latest, macOS-latest] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Get dependencies and install the package - run: | - python -m pip install --upgrade pip wheel setuptools - python -m pip install --upgrade .[test] - - name: Run unit tests - run: | - python -m pytest + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Get dependencies and install the package + run: | + python -m pip install --upgrade pip wheel setuptools + python -m pip install --upgrade .[test] + - name: Run unit tests + run: | + python -m pytest test-coverage: name: Calculate and upload test coverage @@ -50,7 +49,7 @@ jobs: fetch-depth: 2 - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: "3.10" - name: Generate Report run: | @@ -72,7 +71,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: "3.10" - name: Setup build environment run: | sudo apt-get install -y pandoc diff --git a/codecov.yml b/codecov.yml index 2e66c0446..55c040463 100644 --- a/codecov.yml +++ b/codecov.yml @@ -9,7 +9,7 @@ coverage: patch: false github_checks: - annotations: false + annotations: false comment: require_changes: true diff --git a/docs/source/conf.py b/docs/source/conf.py index 3bcd7d1fc..0b86f6627 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -76,7 +76,7 @@ autodoc_default_options = {"ignore-module-all": True} # Include __init__() docstring in class docstring autoclass_content = "both" -autodoc_typehints = "both" +autodoc_typehints = "description" autodoc_typehints_description_target = "documented_params" autodoc_typehints_format = "short" diff --git a/pyproject.toml b/pyproject.toml index 0b196a424..3c47baf42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,13 +34,14 @@ dependencies = [ "dspeed>=1.3", "h5py>=3.2", "iminuit", - "legend-daq2lh5>=1.2", - "legend-pydataobj>=1.5", + "legend-daq2lh5>=1.2.1", + "legend-pydataobj>=1.6", "matplotlib", "numba!=0.53.*,!=0.54.*,!=0.57", "numpy>=1.21", "pandas>=1.4.4", "pint", + "pyyaml", "scikit-learn", "scipy>=1.0.1", "tables", diff --git a/src/pygama/evt/aggregators.py b/src/pygama/evt/aggregators.py index dbcae2829..c9adee29b 100644 --- a/src/pygama/evt/aggregators.py +++ b/src/pygama/evt/aggregators.py @@ -6,135 +6,113 @@ import awkward as ak import numpy as np -from lgdo import Array, ArrayOfEqualSizedArrays, VectorOfVectors, lh5 +from lgdo import lh5, types from lgdo.lh5 import LH5Store -from numpy.typing import NDArray from . 
import utils def evaluate_to_first_or_last( - cumulength: NDArray, - idx: NDArray, - ids: NDArray, - f_hit: str, - f_dsp: str, - chns: list, - chns_rm: list, - expr: str, - exprl: list, - qry: str | NDArray, - nrows: int, - sorter: tuple, - var_ph: dict = None, - defv: bool | int | float = np.nan, + datainfo, + tcm, + channels, + channels_skip, + expr, + field_list, + query, + n_rows, + sorter, + pars_dict=None, + default_value=np.nan, is_first: bool = True, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> Array: +) -> types.Array: """Aggregates across channels by returning the expression of the channel with value of `sorter`. Parameters ---------- - idx - `tcm` index array. - ids - `tcm` id array. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - chns + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + tcm + TCM data arrays in an object that can be accessed by attribute. + channels list of channels to be aggregated. - chns_rm + channels_skip list of channels to be skipped from evaluation and set to default value. expr expression string to be evaluated. - exprl + field_list list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. - qry + query query expression to mask aggregation. - nrows + n_rows length of output array. sorter tuple of field in `hit/dsp/evt` tier to evaluate ``(tier, field)``. - var_ph + pars_dict dictionary of `evt` and additional parameters and their values. - defv + default_value default value. is_first defines if sorted by smallest or largest value of `sorter` - tcm_id_table_pattern - pattern to format `tcm` id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - dsp_group - LH5 root group in `dsp` file. - hit_group - LH5 root group in `hit` file. - evt_group - LH5 root group in `evt` file. 
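+
+    For example, a sketch with hypothetical channel and field names: ::
+
+        out = evaluate_to_first_or_last(
+            datainfo,
+            tcm,
+            channels=["ch1027200", "ch1027201"],
+            channels_skip=[],
+            expr="hit.cuspEmax_ctc_cal",
+            field_list=[("hit", "cuspEmax_ctc_cal")],
+            query=None,
+            n_rows=len(tcm.cumulative_length),
+            sorter=("dsp", "tp_0_est"),
+            is_first=True,  # keep the value at the hit with the smallest tp_0_est
+        )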
""" + f = utils.make_files_config(datainfo) - # define dimension of output array - out = np.full(nrows, defv, dtype=type(defv)) - outt = np.zeros(len(out)) + out = None + outt = None + store = LH5Store(keep_open=True) - store = LH5Store() + for ch in channels: + table_id = utils.get_tcm_id_by_pattern(f.hit.table_fmt, ch) - for ch in chns: # get index list for this channel to be loaded - idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] - evt_ids_ch = np.searchsorted( - cumulength, - np.where(ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0], - "right", - ) + idx_ch = tcm.idx[tcm.id == table_id] # evaluate at channel - res = utils.get_data_at_channel( - ch=ch, - ids=ids, - idx=idx, - expr=expr, - exprl=exprl, - var_ph=var_ph, - is_evaluated=ch not in chns_rm, - f_hit=f_hit, - f_dsp=f_dsp, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - ) + if ch not in channels_skip: + res = utils.get_data_at_channel( + datainfo=datainfo, + ch=ch, + tcm=tcm, + expr=expr, + field_list=field_list, + pars_dict=pars_dict, + ) + + if out is None: + # define dimension of output array + out = utils.make_numpy_full(n_rows, default_value, res.dtype) + outt = np.zeros(len(out)) + else: + res = np.full(len(idx_ch), default_value) # get mask from query limarr = utils.get_mask_from_query( - qry=qry, + datainfo=datainfo, + query=query, length=len(res), ch=ch, idx_ch=idx_ch, - f_hit=f_hit, - f_dsp=f_dsp, - hit_group=hit_group, - dsp_group=dsp_group, ) # find if sorter is in hit or dsp t0 = store.read( f"{ch}/{sorter[0]}/{sorter[1]}", - f_hit if f"{hit_group}" == sorter[0] else f_dsp, + f.hit.file if f"{f.hit.group}" == sorter[0] else f.dsp.file, idx=idx_ch, )[0].view_as("np") if t0.ndim > 1: raise ValueError(f"sorter '{sorter[0]}/{sorter[1]}' must be a 1D array") + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, + np.where(tcm.id == table_id)[0], + "right", + ) + if is_first: - if ch == chns[0]: + if ch == channels[0]: outt[:] = np.inf out[evt_ids_ch] = np.where( @@ -152,292 +130,236 @@ def evaluate_to_first_or_last( (t0 > outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch] ) - return Array(nda=out, dtype=type(defv)) + return types.Array(nda=out) def evaluate_to_scalar( - mode: str, - cumulength: NDArray, - idx: NDArray, - ids: NDArray, - f_hit: str, - f_dsp: str, - chns: list, - chns_rm: list, - expr: str, - exprl: list, - qry: str | NDArray, - nrows: int, - var_ph: dict = None, - defv: bool | int | float = np.nan, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> Array: + datainfo, + tcm, + mode, + channels, + channels_skip, + expr, + field_list, + query, + n_rows, + pars_dict=None, + default_value=np.nan, +) -> types.Array: """Aggregates by summation across channels. Parameters ---------- + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + tcm + TCM data arrays in an object that can be accessed by attribute. mode aggregation mode. - idx - `tcm` index array. - ids - `tcm` id array. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - chns + channels list of channels to be aggregated. - chns_rm + channels_skip list of channels to be skipped from evaluation and set to default value. expr expression string to be evaluated. - exprl + field_list list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. - qry + query query expression to mask aggregation. 
- nrows + n_rows length of output array - var_ph + pars_dict dictionary of `evt` and additional parameters and their values. - defv + default_value default value. - tcm_id_table_pattern - pattern to format `tcm` id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - dsp_group - LH5 root group in `dsp` file. - hit_group - LH5 root group in `hit` file. - evt_group - LH5 root group in `evt` file. """ + f = utils.make_files_config(datainfo) + out = None - # define dimension of output array - out = np.full(nrows, defv, dtype=type(defv)) + for ch in channels: + table_id = utils.get_tcm_id_by_pattern(f.hit.table_fmt, ch) - for ch in chns: # get index list for this channel to be loaded - idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] - evt_ids_ch = np.searchsorted( - cumulength, - np.where(ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0], - "right", - ) + idx_ch = tcm.idx[tcm.id == table_id] + + if ch not in channels_skip: + res = utils.get_data_at_channel( + datainfo=datainfo, + ch=ch, + tcm=tcm, + expr=expr, + field_list=field_list, + pars_dict=pars_dict, + ) - res = utils.get_data_at_channel( - ch=ch, - ids=ids, - idx=idx, - expr=expr, - exprl=exprl, - var_ph=var_ph, - is_evaluated=ch not in chns_rm, - f_hit=f_hit, - f_dsp=f_dsp, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - ) + if out is None: + # define dimension of output array + out = utils.make_numpy_full(n_rows, default_value, res.dtype) + else: + res = np.full(len(idx_ch), default_value) # get mask from query limarr = utils.get_mask_from_query( - qry=qry, + datainfo=datainfo, + query=query, length=len(res), ch=ch, idx_ch=idx_ch, - f_hit=f_hit, - f_dsp=f_dsp, - hit_group=hit_group, - dsp_group=dsp_group, + ) + + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, + np.where(tcm.id == table_id)[0], + side="right", ) # switch through modes if "sum" == mode: if res.dtype == bool: res = res.astype(int) + out[evt_ids_ch] = np.where(limarr, res + out[evt_ids_ch], out[evt_ids_ch]) + if "any" == mode: if res.dtype != bool: res = res.astype(bool) + out[evt_ids_ch] = out[evt_ids_ch] | (res & limarr) + if "all" == mode: if res.dtype != bool: res = res.astype(bool) + out[evt_ids_ch] = out[evt_ids_ch] & res & limarr - return Array(nda=out, dtype=type(defv)) + return types.Array(nda=out) def evaluate_at_channel( - cumulength: NDArray, - idx: NDArray, - ids: NDArray, - f_hit: str, - f_dsp: str, - chns_rm: list, - expr: str, - exprl: list, - ch_comp: Array, - var_ph: dict = None, - defv: bool | int | float = np.nan, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> Array: + datainfo, + tcm, + channels_skip, + expr, + field_list, + ch_comp, + pars_dict=None, + default_value=np.nan, +) -> types.Array: """Aggregates by evaluating the expression at a given channel. Parameters ---------- - idx - `tcm` index array. - ids - `tcm` id array. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - chns_rm + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + tcm + TCM data arrays in an object that can be accessed by attribute. + channels_skip list of channels to be skipped from evaluation and set to default value. expr expression string to be evaluated. - exprl + field_list list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. 
ch_comp array of rawids at which the expression is evaluated. - var_ph + pars_dict dictionary of `evt` and additional parameters and their values. - defv + default_value default value. - tcm_id_table_pattern - pattern to format `tcm` id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - dsp_group - LH5 root group in `dsp` file. - hit_group - LH5 root group in `hit` file. - evt_group - LH5 root group in `evt` file. """ + f = utils.make_files_config(datainfo) + table_id_fmt = f.hit.table_fmt - out = np.full(len(ch_comp.nda), defv, dtype=type(defv)) + out = None for ch in np.unique(ch_comp.nda.astype(int)): + table_name = utils.get_table_name_by_pattern(table_id_fmt, ch) # skip default value - if utils.get_table_name_by_pattern(tcm_id_table_pattern, ch) not in lh5.ls( - f_hit - ): + if table_name not in lh5.ls(f.hit.file): continue - idx_ch = idx[ids == ch] - evt_ids_ch = np.searchsorted(cumulength, np.where(ids == ch)[0], "right") - res = utils.get_data_at_channel( - ch=utils.get_table_name_by_pattern(tcm_id_table_pattern, ch), - ids=ids, - idx=idx, - expr=expr, - exprl=exprl, - var_ph=var_ph, - is_evaluated=utils.get_table_name_by_pattern(tcm_id_table_pattern, ch) - not in chns_rm, - f_hit=f_hit, - f_dsp=f_dsp, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, + + idx_ch = tcm.idx[tcm.id == ch] + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, np.where(tcm.id == ch)[0], "right" ) + if table_name not in channels_skip: + res = utils.get_data_at_channel( + datainfo=datainfo, + ch=table_name, + tcm=tcm, + expr=expr, + field_list=field_list, + pars_dict=pars_dict, + ) + else: + res = np.full(len(idx_ch), default_value) + + if out is None: + out = utils.make_numpy_full(len(ch_comp.nda), default_value, res.dtype) out[evt_ids_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[evt_ids_ch]) - return Array(nda=out, dtype=type(defv)) + return types.Array(nda=out) def evaluate_at_channel_vov( - cumulength: NDArray, - idx: NDArray, - ids: NDArray, - f_hit: str, - f_dsp: str, - expr: str, - exprl: list, - ch_comp: VectorOfVectors, - chns_rm: list, - var_ph: dict = None, - defv: bool | int | float = np.nan, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> VectorOfVectors: + datainfo, + tcm, + expr, + field_list, + ch_comp, + channels_skip, + pars_dict=None, + default_value=np.nan, +) -> types.VectorOfVectors: """Same as :func:`evaluate_at_channel` but evaluates expression at non flat channels :class:`.VectorOfVectors`. Parameters ---------- - idx - `tcm` index array. - ids - `tcm` id array. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + tcm + TCM data arrays in an object that can be accessed by attribute. expr expression string to be evaluated. - exprl + field_list list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. ch_comp array of "rawid"s at which the expression is evaluated. - chns_rm + channels_skip list of channels to be skipped from evaluation and set to default value. - var_ph + pars_dict dictionary of `evt` and additional parameters and their values. - defv + default_value default value. - tcm_id_table_pattern - pattern to format `tcm` id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - dsp_group - LH5 root group in `dsp` file. 
- hit_group - LH5 root group in `hit` file. - evt_group - LH5 root group in `evt` file. """ + f = utils.make_files_config(datainfo) # blow up vov to aoesa out = ak.Array([[] for _ in range(len(ch_comp))]) - chns = np.unique(ch_comp.flattened_data.nda).astype(int) + channels = np.unique(ch_comp.flattened_data.nda).astype(int) ch_comp = ch_comp.view_as("ak") type_name = None - for ch in chns: - evt_ids_ch = np.searchsorted(cumulength, np.where(ids == ch)[0], "right") - res = utils.get_data_at_channel( - ch=utils.get_table_name_by_pattern(tcm_id_table_pattern, ch), - ids=ids, - idx=idx, - expr=expr, - exprl=exprl, - var_ph=var_ph, - is_evaluated=utils.get_table_name_by_pattern(tcm_id_table_pattern, ch) - not in chns_rm, - f_hit=f_hit, - f_dsp=f_dsp, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, + for ch in channels: + table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, ch) + + evt_ids_ch = np.searchsorted( + tcm.cumulative_length, np.where(tcm.id == ch)[0], "right" ) + if table_name not in channels_skip: + res = utils.get_data_at_channel( + datainfo=datainfo, + ch=table_name, + tcm=tcm, + expr=expr, + field_list=field_list, + pars_dict=pars_dict, + ) + else: + idx_ch = tcm.idx[tcm.id == ch] + res = np.full(len(idx_ch), default_value) # see in which events the current channel is present mask = ak.to_numpy(ak.any(ch_comp == ch, axis=-1), allow_missing=False) @@ -448,231 +370,181 @@ def evaluate_at_channel_vov( out = ak.concatenate((out, cv), axis=-1) - if ch == chns[0]: + if ch == channels[0]: type_name = res.dtype - return VectorOfVectors(ak.values_astype(out, type_name), dtype=type_name) + return types.VectorOfVectors(ak.values_astype(out, type_name)) def evaluate_to_aoesa( - cumulength: NDArray, - idx: NDArray, - ids: NDArray, - f_hit: str, - f_dsp: str, - chns: list, - chns_rm: list, - expr: str, - exprl: list, - qry: str | NDArray, - nrows: int, - var_ph: dict = None, - defv: bool | int | float = np.nan, - missv=np.nan, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> ArrayOfEqualSizedArrays: + datainfo, + tcm, + channels, + channels_skip, + expr, + field_list, + query, + n_rows, + pars_dict=None, + default_value=np.nan, + missing_value=np.nan, +) -> types.ArrayOfEqualSizedArrays: """Aggregates by returning an :class:`.ArrayOfEqualSizedArrays` of evaluated expressions of channels that fulfill a query expression. Parameters ---------- - idx - `tcm` index array. - ids - `tcm` id array. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - chns + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + tcm + TCM data arrays in an object that can be accessed by attribute. + channels list of channels to be aggregated. - chns_rm + channels_skip list of channels to be skipped from evaluation and set to default value. expr expression string to be evaluated. - exprl + field_list list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. - qry + query query expression to mask aggregation. - nrows + n_rows length of output :class:`.VectorOfVectors`. ch_comp array of "rawid"s at which the expression is evaluated. - var_ph + pars_dict dictionary of `evt` and additional parameters and their values. - defv + default_value default value. - missv + missing_value missing value. sorter sorts the entries in the vector according to sorter expression. 
- tcm_id_table_pattern - pattern to format `tcm` id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - dsp_group - LH5 root group in `dsp` file. - hit_group - LH5 root group in `hit` file. - evt_group - LH5 root group in `evt` file. """ + f = utils.make_files_config(datainfo) + # define dimension of output array - out = np.full((nrows, len(chns)), missv) + dtype = None + out = None + + for i, ch in enumerate(channels): + table_id = utils.get_tcm_id_by_pattern(f.hit.table_fmt, ch) + idx_ch = tcm.idx[tcm.id == table_id] - i = 0 - for ch in chns: - idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] evt_ids_ch = np.searchsorted( - cumulength, - np.where(ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0], + tcm.cumulative_length, + np.where(tcm.id == table_id)[0], "right", ) - res = utils.get_data_at_channel( - ch=ch, - ids=ids, - idx=idx, - expr=expr, - exprl=exprl, - var_ph=var_ph, - is_evaluated=ch not in chns_rm, - f_hit=f_hit, - f_dsp=f_dsp, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - ) + + if ch not in channels_skip: + res = utils.get_data_at_channel( + datainfo=datainfo, + ch=ch, + tcm=tcm, + expr=expr, + field_list=field_list, + pars_dict=pars_dict, + ) + + if dtype is None: + dtype = res.dtype + + if out is None: + out = utils.make_numpy_full( + (n_rows, len(channels)), missing_value, res.dtype + ) + else: + res = np.full(len(idx_ch), default_value) # get mask from query limarr = utils.get_mask_from_query( - qry=qry, + datainfo=datainfo, + query=query, length=len(res), ch=ch, idx_ch=idx_ch, - f_hit=f_hit, - f_dsp=f_dsp, - hit_group=hit_group, - dsp_group=dsp_group, ) out[evt_ids_ch, i] = np.where(limarr, res, out[evt_ids_ch, i]) - i += 1 - - return ArrayOfEqualSizedArrays(nda=out) + return out, dtype def evaluate_to_vector( - cumulength: NDArray, - idx: NDArray, - ids: NDArray, - f_hit: str, - f_dsp: str, - chns: list, - chns_rm: list, - expr: str, - exprl: list, - qry: str | NDArray, - nrows: int, - var_ph: dict = None, - defv: bool | int | float = np.nan, - sorter: str = None, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> VectorOfVectors: + datainfo, + tcm, + channels, + channels_skip, + expr, + field_list, + query, + n_rows, + pars_dict=None, + default_value=np.nan, + sorter=None, +) -> types.VectorOfVectors: """Aggregates by returning a :class:`.VectorOfVector` of evaluated expressions of channels that fulfill a query expression. Parameters ---------- - idx - `tcm` index array. - ids - `tcm` id array. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - chns + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + tcm + TCM data arrays in an object that can be accessed by attribute. + channels list of channels to be aggregated. - chns_rm + channels_skip list of channels to be skipped from evaluation and set to default value. expr expression string to be evaluated. - exprl + field_list list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``. - qry + query query expression to mask aggregation. - nrows + n_rows length of output :class:`.VectorOfVectors`. ch_comp array of "rawids" at which the expression is evaluated. - var_ph + pars_dict dictionary of `evt` and additional parameters and their values. - defv + default_value default value. 
sorter sorts the entries in the vector according to sorter expression. ``ascend_by:`` results in an vector ordered ascending, ``decend_by:`` sorts descending. - tcm_id_table_pattern - pattern to format `tcm` id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - dsp_group - LH5 root group in `dsp` file. - hit_group - LH5 root group in `hit` file. - evt_group - LH5 root group in `evt` file. """ - out = evaluate_to_aoesa( - cumulength=cumulength, - idx=idx, - ids=ids, - f_hit=f_hit, - f_dsp=f_dsp, - chns=chns, - chns_rm=chns_rm, + out, dtype = evaluate_to_aoesa( + datainfo=datainfo, + tcm=tcm, + channels=channels, + channels_skip=channels_skip, expr=expr, - exprl=exprl, - qry=qry, - nrows=nrows, - var_ph=var_ph, - defv=defv, - missv=np.nan, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - ).view_as("np") + field_list=field_list, + query=query, + n_rows=n_rows, + pars_dict=pars_dict, + default_value=default_value, + missing_value=np.nan, + ) # if a sorter is given sort accordingly if sorter is not None: md, fld = sorter.split(":") - s_val = evaluate_to_aoesa( - cumulength=cumulength, - idx=idx, - ids=ids, - f_hit=f_hit, - f_dsp=f_dsp, - chns=chns, - chns_rm=chns_rm, + s_val, _ = evaluate_to_aoesa( + datainfo=datainfo, + tcm=tcm, + channels=channels, + channels_skip=channels_skip, expr=fld, - exprl=[tuple(fld.split("."))], - qry=None, - nrows=nrows, - missv=np.nan, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - ).view_as("np") + field_list=[tuple(fld.split("."))], + query=None, + n_rows=n_rows, + missing_value=np.nan, + ) + if "ascend_by" == md: out = out[np.arange(len(out))[:, None], np.argsort(s_val)] @@ -683,7 +555,6 @@ def evaluate_to_vector( "sorter values can only have 'ascend_by' or 'descend_by' prefixes" ) - return VectorOfVectors( - ak.values_astype(ak.drop_none(ak.nan_to_none(ak.Array(out))), type(defv)), - dtype=type(defv), + return types.VectorOfVectors( + ak.values_astype(ak.drop_none(ak.nan_to_none(ak.Array(out))), dtype) ) diff --git a/src/pygama/evt/build_evt.py b/src/pygama/evt/build_evt.py index 5f7949bdb..3620dd373 100644 --- a/src/pygama/evt/build_evt.py +++ b/src/pygama/evt/build_evt.py @@ -4,59 +4,67 @@ from __future__ import annotations +import importlib import itertools -import json import logging import re -from importlib import import_module +from collections.abc import Mapping, Sequence +from typing import Any import awkward as ak import numpy as np from lgdo import Array, ArrayOfEqualSizedArrays, Table, VectorOfVectors, lh5 -from lgdo.lh5 import LH5Store +from ..utils import load_dict from . import aggregators, utils log = logging.getLogger(__name__) def build_evt( - f_tcm: str, - f_dsp: str, - f_hit: str, - evt_config: str | dict, - f_evt: str | None = None, + datainfo: utils.DataInfo | Mapping[str, Sequence[str, ...]], + config: str | Mapping[str, ...], wo_mode: str = "write_safe", - evt_group: str = "evt", - tcm_group: str = "hardware_tcm_1", - dsp_group: str = "dsp", - hit_group: str = "hit", - tcm_id_table_pattern: str = "ch{}", ) -> None | Table: - """Transform data from the `hit` and `dsp` levels which a channel sorted to a - event sorted data format. + r"""Transform data from hit-structured tiers to event-structured data. Parameters ---------- - f_tcm - input LH5 file of the `tcm` level. - f_dsp - input LH5 file of the `dsp` level. - f_hit - input LH5 file of the `hit` level. 
-    evt_config
+    datainfo
+        input and output LH5 datainfo with HDF5 groups where tables are found
+        (see :obj:`.utils.DataInfo`). Example: ::
+
+            # syntax: {"tier-name": ("file-name", "hdf5-group"[, "table-format"])}
+            {
+                "tcm": ("data-tier_tcm.lh5", "hardware_tcm_1"),
+                "dsp": ("data-tier_dsp.lh5", "dsp", "ch{}"),
+                "hit": ("data-tier_hit.lh5", "hit", "ch{}"),
+                "evt": ("data-tier_evt.lh5", "evt")
+            }
+
+    config
         name of configuration file or dictionary defining event fields. Channel
         lists can be defined by importing a metadata module.
 
-        - ``operations`` defines the fields ``name=key``, where ``channels``
-          specifies the channels used to for this field (either a string or a
-          list of strings),
+        - ``channels`` specifies the channels used for this field (either a
+          string or a list of strings).
+        - ``operations`` defines the event fields (``name=key``). If the key
+          contains slashes it will be interpreted as the path to the output
+          field inside nested sub-tables.
+        - ``outputs`` defines the fields that are actually included in the
+          output table.
+
+        Inside the ``operations`` block:
+
+        - ``aggregation_mode`` defines how the channels should be combined
+          (see :func:`evaluate_expression`).
-        - ``expression`` defnies the mathematical/special function to apply
+        - ``expression`` defines the expression or function call to apply
           (see :func:`evaluate_expression`),
         - ``query`` defines an expression to mask the aggregation.
         - ``parameters`` defines any other parameter used in expression.
+        - ``dtype`` defines the NumPy data type of the resulting data.
+        - ``initial`` defines the initial/default value. Useful with some types
+          of aggregators.
 
         For example:
 
@@ -68,6 +76,7 @@ def build_evt(
                 "spms_on": ["ch1057600", "ch1059201", "ch1062405"],
                 "muon": "ch1027202",
             },
+            "outputs": ["energy_id", "multiplicity"],
             "operations": {
                 "energy_id":{
                     "channels": "geds_on",
@@ -83,90 +92,69 @@ def build_evt(
                 "is_muon_rejected":{
                     "channels": "muon",
                     "aggregation_mode": "any",
-                    "expression": "dsp.wf_max>a",
-                    "parameters": {"a":15100},
+                    "expression": "dsp.wf_max > a",
+                    "parameters": {"a": 15100},
                     "initial": false
                 },
                 "multiplicity":{
                     "channels": ["geds_on", "geds_no_psd", "geds_ac"],
                     "aggregation_mode": "sum",
                     "expression": "hit.cuspEmax_ctc_cal > a",
-                    "parameters": {"a":25},
+                    "parameters": {"a": 25},
                     "initial": 0
                 },
                 "t0":{
                     "aggregation_mode": "keep_at_ch:evt.energy_id",
-                    "expression": "dsp.tp_0_est"
+                    "expression": "dsp.tp_0_est",
+                    "initial": "np.nan"
                 },
                 "lar_energy":{
                     "channels": "spms_on",
                     "aggregation_mode": "function",
-                    "expression": ".modules.spm.get_energy(0.5, evt.t0, 48000, 1000, 5000)"
+                    "expression": "pygama.evt.modules.spms.gather_pulse_data(<...>, observable='hit.energy_in_pe')"
                 },
             }
         }
 
-    f_evt
-        name of the output file. If ``None``, return the output :class:`.Table`
-        instead of writing to disk.
     wo_mode
-        writing mode.
-    evt group
-        LH5 root group name of `evt` tier.
-    tcm_group
-        LH5 root group in `tcm` file.
-    dsp_group
-        LH5 root group in `dsp` file.
-    hit_group
-        LH5 root group in `hit` file.
-    tcm_id_table_pattern
-        pattern to format `tcm` id values to table name in higher tiers. Must
-        have one placeholder which is the `tcm` id.
+        writing mode, see :func:`lgdo.lh5.core.write`.
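+
+    A minimal invocation sketch (the file names, including the configuration
+    file, are hypothetical placeholders): ::
+
+        from pygama.evt import build_evt
+
+        datainfo = {
+            "tcm": ("data-tier_tcm.lh5", "hardware_tcm_1"),
+            "dsp": ("data-tier_dsp.lh5", "dsp", "ch{}"),
+            "hit": ("data-tier_hit.lh5", "hit", "ch{}"),
+            "evt": ("data-tier_evt.lh5", "evt"),
+        }
+
+        # config may equally be passed as a dictionary
+        build_evt(datainfo, config="evt-config.json", wo_mode="write_safe")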
""" + if not isinstance(config, dict): + config = load_dict(config) - store = LH5Store() - tbl_cfg = evt_config - if not isinstance(tbl_cfg, (str, dict)): - raise TypeError() - if isinstance(tbl_cfg, str): - with open(tbl_cfg) as f: - tbl_cfg = json.load(f) - - if "channels" not in tbl_cfg.keys(): + if "channels" not in config.keys(): raise ValueError("channel field needs to be specified in the config") - if "operations" not in tbl_cfg.keys(): + if "operations" not in config.keys(): raise ValueError("operations field needs to be specified in the config") - # check tcm_id_table_pattern validity - pattern_check = re.findall(r"{([^}]*?)}", tcm_id_table_pattern) + # convert into a nice named tuple + f = utils.make_files_config(datainfo) + + # check chname_fmt validity + chname_fmt = f.hit.table_fmt + pattern_check = re.findall(r"{([^}]*?)}", chname_fmt) if len(pattern_check) != 1: - raise ValueError( - f"tcm_id_table_pattern must have exactly one placeholder. {tcm_id_table_pattern} is invalid." - ) + raise ValueError("chname_fmt must have exactly one placeholder {}") elif "{" in pattern_check[0] or "}" in pattern_check[0]: - raise ValueError( - f"tcm_id_table_pattern {tcm_id_table_pattern} has an invalid placeholder." - ) + raise ValueError(f"{chname_fmt=} has an invalid placeholder.") if ( utils.get_table_name_by_pattern( - tcm_id_table_pattern, - utils.get_tcm_id_by_pattern(tcm_id_table_pattern, lh5.ls(f_hit)[0]), + chname_fmt, + utils.get_tcm_id_by_pattern(chname_fmt, lh5.ls(f.hit.file)[0]), ) - != lh5.ls(f_hit)[0] + != lh5.ls(f.hit.file)[0] ): - raise ValueError( - f"tcm_id_table_pattern {tcm_id_table_pattern} does not match keys in data!" - ) + raise ValueError(f"chname_fmt {chname_fmt} does not match keys in data!") # create channel list according to config # This can be either read from the meta data # or a list of channel names - log.debug("Creating channel dictionary") + log.debug("creating channel dictionary") - chns = {} + channels = {} - for k, v in tbl_cfg["channels"].items(): + for key, v in config["channels"].items(): if isinstance(v, dict): # it is a meta module. 
module_name must exist if "module" not in v.keys(): @@ -175,10 +163,9 @@ def build_evt( ) attr = {} - # the time_key argument is set to the time key of the DSP file - # in case it is not provided by the config + # the time_key argument is mandatory if "time_key" not in v.keys(): - attr["time_key"] = re.search(r"\d{8}T\d{6}Z", f_dsp).group(0) + raise RuntimeError("the 'time_key' configuration field is mandatory") # if "None" do None elif "None" == v["time_key"]: @@ -186,160 +173,179 @@ def build_evt( # load module p, m = v["module"].rsplit(".", 1) - met = getattr(import_module(p, package=__package__), m) - chns[k] = met(v | attr) + met = getattr(importlib.import_module(p, package=__package__), m) + channels[key] = met(v | attr) elif isinstance(v, str): - chns[k] = [v] + channels[key] = [v] elif isinstance(v, list): - chns[k] = [e for e in v] - - nrows = store.read_n_rows(f"/{tcm_group}/cumulative_length", f_tcm) + channels[key] = [e for e in v] + + # load tcm data from disk + tcm = utils.TCMData( + id=lh5.read_as(f"/{f.tcm.group}/array_id", f.tcm.file, library="np"), + idx=lh5.read_as(f"/{f.tcm.group}/array_idx", f.tcm.file, library="np"), + cumulative_length=lh5.read_as( + f"/{f.tcm.group}/cumulative_length", f.tcm.file, library="np" + ), + ) - table = Table(size=nrows) + # get number of events in file (ask the TCM) + n_rows = len(tcm.cumulative_length) + table = Table(size=n_rows) - for k, v in tbl_cfg["operations"].items(): - log.debug("Processing field " + k) + # now loop over operations (columns in evt table) + for field, v in config["operations"].items(): + log.debug(f"processing field: '{field}'") - # if mode not defined in operation, it can only be an operation on the evt level. + # if mode not defined in operation, it can only be an operation on the + # evt level if "aggregation_mode" not in v.keys(): - var = {} - if "parameters" in v.keys(): - var = var | v["parameters"] - res = table.eval(v["expression"].replace(f"{evt_group}.", ""), var) + # compute and eventually get rid of evt. 
suffix + obj = table.eval( + v["expression"].replace("evt.", ""), v.get("parameters", {}) + ) - # add attribute if present + # add attributes if present if "lgdo_attrs" in v.keys(): - res.attrs |= v["lgdo_attrs"] - - table.add_field(k, res) + obj.attrs |= v["lgdo_attrs"] - # Else we build the event entry + # else we build the event entry else: if "channels" not in v.keys(): - chns_e = [] + channels_e = [] elif isinstance(v["channels"], str): - chns_e = chns[v["channels"]] + channels_e = channels[v["channels"]] elif isinstance(v["channels"], list): - chns_e = list( - itertools.chain.from_iterable([chns[e] for e in v["channels"]]) + channels_e = list( + itertools.chain.from_iterable([channels[e] for e in v["channels"]]) ) - chns_rm = [] + channels_skip = [] if "exclude_channels" in v.keys(): if isinstance(v["exclude_channels"], str): - chns_rm = chns[v["exclude_channels"]] + channels_skip = channels[v["exclude_channels"]] elif isinstance(v["exclude_channels"], list): - chns_rm = list( + channels_skip = list( itertools.chain.from_iterable( - [chns[e] for e in v["exclude_channels"]] + [channels[e] for e in v["exclude_channels"]] ) ) - pars, qry, defaultv, srter = None, None, np.nan, None - if "parameters" in v.keys(): - pars = v["parameters"] - if "query" in v.keys(): - qry = v["query"] - if "initial" in v.keys(): - defaultv = v["initial"] - if isinstance(defaultv, str) and ( - defaultv in ["np.nan", "np.inf", "-np.inf"] - ): - defaultv = eval(defaultv) - if "sort" in v.keys(): - srter = v["sort"] + defaultv = v.get("initial", np.nan) + if isinstance(defaultv, str) and ( + defaultv in ["np.nan", "np.inf", "-np.inf"] + ): + defaultv = eval(defaultv) obj = evaluate_expression( - f_tcm=f_tcm, - f_hit=f_hit, - f_dsp=f_dsp, - chns=chns_e, - chns_rm=chns_rm, + datainfo, + tcm, + channels=channels_e, + channels_skip=channels_skip, mode=v["aggregation_mode"], expr=v["expression"], - nrows=nrows, + n_rows=n_rows, table=table, - para=pars, - qry=qry, - defv=defaultv, - sorter=srter, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - tcm_group=tcm_group, + parameters=v.get("parameters", None), + query=v.get("query", None), + default_value=defaultv, + sorter=v.get("sort", None), ) # add attribute if present if "lgdo_attrs" in v.keys(): obj.attrs |= v["lgdo_attrs"] - table.add_field(k, obj) - - # write output fields into f_evt - if "outputs" in tbl_cfg.keys(): - if len(tbl_cfg["outputs"]) < 1: - log.warning("No output fields specified, no file will be written.") - return table - else: - clms_to_remove = [e for e in table.keys() if e not in tbl_cfg["outputs"]] - for fld in clms_to_remove: - table.remove_field(fld, True) - - if f_evt: - store.write( - obj=table, name=f"/{evt_group}/", lh5_file=f_evt, wo_mode=wo_mode - ) - else: - return table + # cast to type, if required + # hijack the poor LGDO + if "dtype" in v: + type_ = v["dtype"] + + if isinstance(obj, Array): + obj.nda = obj.nda.astype(type_) + if isinstance(obj, VectorOfVectors): + fldata_ptr = obj.flattened_data + while isinstance(fldata_ptr, VectorOfVectors): + fldata_ptr = fldata_ptr.flattened_data + + fldata_ptr.nda = fldata_ptr.nda.astype(type_) + + log.debug(f"new column {field!s} = {obj!r}") + table.add_field(field, obj) + + # might need to re-organize fields in subtables, create a new object for that + nested_tbl = Table(size=n_rows) + output_fields = config.get("outputs", table.keys()) + + for field, obj in table.items(): + # also only add fields requested by the user + if 
field not in output_fields:
+            continue
+
+        # if names contain slashes, put in sub-tables
+        lvl_ptr = nested_tbl
+        subfields = field.strip("/").split("/")
+        for level in subfields:
+            # if we are at the end, just add the field
+            if level == subfields[-1]:
+                lvl_ptr.add_field(level, obj)
+                break
+
+            if not level:
+                msg = f"invalid field name '{field}'"
+                raise RuntimeError(msg)
+
+            # otherwise, increase nesting
+            if level not in lvl_ptr:
+                lvl_ptr.add_field(level, Table(size=n_rows))
+            lvl_ptr = lvl_ptr[level]
+
+    # write output fields into outfile
+    if output_fields:
+        if f.evt.file is None:
+            return nested_tbl
+
+        lh5.write(
+            obj=nested_tbl,
+            name=f.evt.group,
+            lh5_file=f.evt.file,
+            wo_mode=wo_mode,
+        )
     else:
-        log.warning("No output fields specified, no file will be written.")
-
-        key = re.search(r"\d{8}T\d{6}Z", f_hit).group(0)
-        log.info(
-            f"Applied {len(tbl_cfg['operations'])} operations to key {key} and saved "
-            f"{len(tbl_cfg['outputs'])} evt fields across {len(chns)} channel groups"
-        )
+        log.warning("no output fields specified, no file will be written.")
+        return nested_tbl
 
 
 def evaluate_expression(
-    f_tcm: str,
-    f_hit: str,
-    f_dsp: str,
-    chns: list,
-    chns_rm: list,
+    datainfo: utils.DataInfo | Mapping[str, Sequence[str, ...]],
+    tcm: utils.TCMData,
+    channels: Sequence[str],
+    channels_skip: Sequence[str],
     mode: str,
     expr: str,
-    nrows: int,
+    n_rows: int,
    table: Table = None,
-    para: dict = None,
-    qry: str = None,
-    defv: bool | int | float = np.nan,
+    parameters: Mapping[str, Any] = None,
+    query: str = None,
+    default_value: bool | int | float = np.nan,
     sorter: str = None,
-    tcm_id_table_pattern: str = "ch{}",
-    evt_group: str = "evt",
-    hit_group: str = "hit",
-    dsp_group: str = "dsp",
-    tcm_group: str = "tcm",
 ) -> Array | ArrayOfEqualSizedArrays | VectorOfVectors:
     """Evaluates the expression defined by the user across all channels
     according to the mode.
 
     Parameters
     ----------
-    f_tcm
-        path to `tcm` tier file.
-    f_hit
-        path to `hit` tier file.
-    f_dsp
-        path to `dsp` tier file.
-    chns
-        list of channel names across which expression gets evaluated (form:
-        ``ch``).
-    chns_rm
+    datainfo
+        input and output LH5 files with HDF5 groups where tables are found
+        (see :obj:`.utils.DataInfo`).
+    tcm
+        tcm data structure (see :obj:`.utils.TCMData`).
+    channels
+        list of channel names across which expression gets evaluated.
+    channels_skip
         list of channels which get set to default value during evaluation. In
-        function mode they are removed entirely (form: ``ch``)
+        function mode they are removed entirely.
     mode
         The mode determines how the event entry is calculated across channels.
         Options are:
@@ -354,118 +360,131 @@ def evaluate_expression(
 
         - ``keep_at_ch:ch_field``: aggregates according to passed ch_field.
         - ``keep_at_idx:tcm_idx_field``: aggregates according to passed tcm
           index field.
-        - ``gather``: Channels are not combined, but result saved as
+        - ``gather``: channels are not combined, but result saved as
           :class:`.VectorOfVectors`.
+        - ``function``: the function call specified in `expr` is evaluated, and
+          the resulting column is inserted into the output table.
 
-    qry
+    query
         a query that can mask the aggregation.
     expr
         the expression. That can be any mathematical equation/comparison. If
         `mode` is ``function``, the expression needs to be a special processing
-        function defined in modules (e.g. :func:`.modules.spm.get_energy`). In
-        the expression parameters from either hit, dsp, evt tier (from
-        operations performed before this one!
Dictionary operations order - matters), or from the ``parameters`` field can be used. - nrows + function defined in :mod:`.modules`. In the expression, parameters from + either `evt` or lower tiers (from operations performed before this one! + Dictionary operations order matters), or from the ``parameters`` field + can be used. Fields can be prefixed with the tier id (e.g. + ``evt.energy`` or `hit.quality_flag``). + n_rows number of rows to be processed. table table of `evt` tier data. - para + parameters dictionary of parameters defined in the ``parameters`` field in the configuration dictionary. - defv + default_value default value of evaluation. sorter can be used to sort vector outputs according to sorter expression (see - :func:`evaluate_to_vector`). - tcm_id_table_pattern - pattern to format tcm id values to table name in higher tiers. Must have one - placeholder which is the `tcm` id. - evt group - LH5 root group name of `evt` tier. - tcm_group - LH5 root group in `tcm` file. - dsp_group - LH5 root group in `dsp` file. - hit_group - LH5 root group in `hit` file. + :func:`.evaluate_to_vector`). + + Note + ---- + The specification of custom functions that can be used as expression is + documented in :mod:`.modules`. """ + f = utils.make_files_config(datainfo) - store = LH5Store() + # build dictionary of parameter names and their values + # a parameter can be a column in the existing table... + pars_dict = {} - # find parameters in evt file or in parameters - exprl = re.findall( - rf"({evt_group}|{hit_group}|{dsp_group}).([a-zA-Z_$][\w$]*)", expr - ) - var_ph = {} - if table: - var_ph = var_ph | { - e: table[e].view_as("ak") - for e in table.keys() - if isinstance(table[e], (Array, ArrayOfEqualSizedArrays, VectorOfVectors)) + if table is not None: + pars_dict = { + k: v for k, v in table.items() if isinstance(v, (Array, VectorOfVectors)) } - if para: - var_ph = var_ph | para + + # ...or defined through the configuration + if parameters: + pars_dict = pars_dict | parameters if mode == "function": - # evaluate expression - func, params = expr.split("(") - params = ( - params.replace(f"{dsp_group}.", f"{dsp_group}_") - .replace(f"{hit_group}.", f"{hit_group}_") - .replace(f"{evt_group}.", "") + # syntax: + # + # pygama.evt.modules.spms.my_func([...], arg1=val, arg2=val) + + # get arguments list passed to the function (outermost parentheses) + args_str = re.search(r"\((.*)\)$", expr.strip()).group(1) + + # handle tier scoping: evt.<> + args_str = args_str.replace("evt.", "") + + good_chns = [x for x in channels if x not in channels_skip] + + # replace stuff before first comma with list of mandatory args + full_args_str = "datainfo, tcm, table_names," + ",".join( + args_str.split(",")[1:] ) - params = [ - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - tcm_group, - tcm_id_table_pattern, - [x for x in chns if x not in chns_rm], - ] + [utils.num_and_pars(e, var_ph) for e in params[:-1].split(",")] - - # load function dynamically - p, m = func.rsplit(".", 1) - met = getattr(import_module(p, package=__package__), m) - return met(*params) + + # get module and function names + func_call = expr.strip().split("(")[0] + subpackage, func = func_call.rsplit(".", 1) + package = subpackage.split(".")[0] + + # import function into current namespace + log.debug(f"importing module {subpackage}") + importlib.import_module(subpackage, package=__package__) + + # declare imported package as globals (see eval() call later) + globs = { + package: importlib.import_module(package), + } + + # lookup dictionary 
for variables used in function arguments (see eval() call later) + locs = {"datainfo": f, "tcm": tcm, "table_names": good_chns} | pars_dict + + # evil eval() to avoid annoying args casting logic + call_str = f"{func_call}({full_args_str})" + log.debug(f"evaluating {call_str}") + log.debug(f"...globals={globs} and locals={locs}") + log.debug(f"...locals={locs}") + + return eval(call_str, globs, locs) else: + # find parameters in evt file or in parameters + field_list = re.findall( + rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)", expr + ) + # check if query is either on channel basis or evt basis (and not a mix) - qry_mask = qry - if qry is not None: - if f"{evt_group}." in qry and ( - f"{hit_group}." in qry or f"{dsp_group}." in qry - ): + query_mask = query + if query is not None: + hit_tiers = [k for k in f._asdict() if k != "evt"] + if "evt." in query and (any([t in query for t in hit_tiers])): raise ValueError( - f"Query can't be a mix of {evt_group} tier and lower tiers." + f"Query can't be a mix of {f.evt.group} tier and lower tiers." ) # if it is an evt query we can evaluate it directly here - if table and f"{evt_group}." in qry: - qry_mask = eval(qry.replace(f"{evt_group}.", ""), table) - - # load TCM data to define an event - ids = store.read(f"/{tcm_group}/array_id", f_tcm)[0].view_as("np") - idx = store.read(f"/{tcm_group}/array_idx", f_tcm)[0].view_as("np") - cumulength = store.read(f"/{tcm_group}/cumulative_length", f_tcm)[0].view_as( - "np" - ) + if table and "evt." in query: + query_mask = eval(query.replace("evt.", ""), table) # switch through modes - if table and (("keep_at_ch:" == mode[:11]) or ("keep_at_idx:" == mode[:12])): - if "keep_at_ch:" == mode[:11]: - ch_comp = table[mode[11:].replace(f"{evt_group}.", "")] + if table and ( + mode.startswith("keep_at_ch:") or mode.startswith("keep_at_idx:") + ): + if mode.startswith("keep_at_ch:"): + ch_comp = table[mode[11:].replace("evt.", "")] else: - ch_comp = table[mode[12:].replace(f"{evt_group}.", "")] + ch_comp = table[mode[12:].replace("evt.", "")] if isinstance(ch_comp, Array): - ch_comp = Array(nda=ids[ch_comp.view_as("np")]) + ch_comp = Array(tcm.id[ch_comp.view_as("np")]) elif isinstance(ch_comp, VectorOfVectors): ch_comp = ch_comp.view_as("ak") ch_comp = VectorOfVectors( - array=ak.unflatten( - ids[ak.flatten(ch_comp)], ak.count(ch_comp, axis=-1) + ak.unflatten( + tcm.id[ak.flatten(ch_comp)], ak.count(ch_comp, axis=-1) ) ) else: @@ -476,114 +495,82 @@ def evaluate_expression( if isinstance(ch_comp, Array): return aggregators.evaluate_at_channel( - cumulength=cumulength, - idx=idx, - ids=ids, - f_hit=f_hit, - f_dsp=f_dsp, - chns_rm=chns_rm, + datainfo=datainfo, + tcm=tcm, + channels_skip=channels_skip, expr=expr, - exprl=exprl, + field_list=field_list, ch_comp=ch_comp, - var_ph=var_ph, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, + pars_dict=pars_dict, + default_value=default_value, ) - elif isinstance(ch_comp, VectorOfVectors): + + if isinstance(ch_comp, VectorOfVectors): return aggregators.evaluate_at_channel_vov( - cumulength=cumulength, - idx=idx, - ids=ids, - f_hit=f_hit, - f_dsp=f_dsp, + datainfo=datainfo, + tcm=tcm, expr=expr, - exprl=exprl, + field_list=field_list, ch_comp=ch_comp, - chns_rm=chns_rm, - var_ph=var_ph, - defv=defv, - tcm_id_table_pattern=tcm_id_table_pattern, - evt_group=evt_group, - hit_group=hit_group, - dsp_group=dsp_group, - ) - else: - raise NotImplementedError( - type(ch_comp) - + " not supported 
(only Array and VectorOfVectors are supported)"
+                channels_skip=channels_skip,
+                pars_dict=pars_dict,
+                default_value=default_value,
             )
-    elif "first_at:" in mode or "last_at:" in mode:
+
+        raise NotImplementedError(
+            f"{type(ch_comp).__name__} not supported "
+            "(only Array and VectorOfVectors are supported)"
+        )
+
+    if "first_at:" in mode or "last_at:" in mode:
         sorter = tuple(
             re.findall(
-                rf"({evt_group}|{hit_group}|{dsp_group}).([a-zA-Z_$][\w$]*)",
+                rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)",
                 mode.split("first_at:")[-1],
             )[0]
         )
 
         return aggregators.evaluate_to_first_or_last(
-            cumulength=cumulength,
-            idx=idx,
-            ids=ids,
-            f_hit=f_hit,
-            f_dsp=f_dsp,
-            chns=chns,
-            chns_rm=chns_rm,
+            datainfo=datainfo,
+            tcm=tcm,
+            channels=channels,
+            channels_skip=channels_skip,
             expr=expr,
-            exprl=exprl,
-            qry=qry_mask,
-            nrows=nrows,
+            field_list=field_list,
+            query=query_mask,
+            n_rows=n_rows,
             sorter=sorter,
-            var_ph=var_ph,
-            defv=defv,
+            pars_dict=pars_dict,
+            default_value=default_value,
             is_first=True if "first_at:" in mode else False,
-            tcm_id_table_pattern=tcm_id_table_pattern,
-            evt_group=evt_group,
-            hit_group=hit_group,
-            dsp_group=dsp_group,
         )
-    elif mode in ["sum", "any", "all"]:
+
+    if mode in ["sum", "any", "all"]:
         return aggregators.evaluate_to_scalar(
+            datainfo=datainfo,
+            tcm=tcm,
             mode=mode,
-            cumulength=cumulength,
-            idx=idx,
-            ids=ids,
-            f_hit=f_hit,
-            f_dsp=f_dsp,
-            chns=chns,
-            chns_rm=chns_rm,
+            channels=channels,
+            channels_skip=channels_skip,
             expr=expr,
-            exprl=exprl,
-            qry=qry_mask,
-            nrows=nrows,
-            var_ph=var_ph,
-            defv=defv,
-            tcm_id_table_pattern=tcm_id_table_pattern,
-            evt_group=evt_group,
-            hit_group=hit_group,
-            dsp_group=dsp_group,
+            field_list=field_list,
+            query=query_mask,
+            n_rows=n_rows,
+            pars_dict=pars_dict,
+            default_value=default_value,
         )
-    elif "gather" == mode:
+    if mode == "gather":
         return aggregators.evaluate_to_vector(
-            cumulength=cumulength,
-            idx=idx,
-            ids=ids,
-            f_hit=f_hit,
-            f_dsp=f_dsp,
-            chns=chns,
-            chns_rm=chns_rm,
+            datainfo=datainfo,
+            tcm=tcm,
+            channels=channels,
+            channels_skip=channels_skip,
             expr=expr,
-            exprl=exprl,
-            qry=qry_mask,
-            nrows=nrows,
-            var_ph=var_ph,
-            defv=defv,
+            field_list=field_list,
+            query=query_mask,
+            n_rows=n_rows,
+            pars_dict=pars_dict,
+            default_value=default_value,
             sorter=sorter,
-            tcm_id_table_pattern=tcm_id_table_pattern,
-            evt_group=evt_group,
-            hit_group=hit_group,
-            dsp_group=dsp_group,
         )
-    else:
-        raise ValueError(mode + " not a valid mode")
+
+    raise ValueError(f"'{mode}' is not a valid mode")
diff --git a/src/pygama/evt/modules/__init__.py b/src/pygama/evt/modules/__init__.py
index bd80462f8..a17d33d7a 100644
--- a/src/pygama/evt/modules/__init__.py
+++ b/src/pygama/evt/modules/__init__.py
@@ -1,21 +1,34 @@
-"""
-Contains submodules for evt processing
-"""
+"""This subpackage provides some custom processors to process hit-structured
+data into event-structured data.
+
+Custom processors must adhere to the following signature: ::
+
+    def my_evt_processor(
+        datainfo,
+        tcm,
+        table_names,
+        *,  # all following arguments are keyword-only
+        arg1,
+        arg2,
+        ...
+    ) -> LGDO:
+        # ...
 
-from .spm import (
-    get_energy,
-    get_energy_dplms,
-    get_etc,
-    get_majority,
-    get_majority_dplms,
-    get_time_shift,
-)
+The first three arguments are automatically supplied by :func:`.build_evt`,
+when the function is called from the :func:`.build_evt` configuration.
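+
+For instance, a minimal (purely illustrative) processor following this
+contract, which fills one row per event with a constant value, could read: ::
+
+    import numpy as np
+    from lgdo import types
+
+    def my_evt_processor(datainfo, tcm, table_names, *, value):
+        # one output row per event, as defined by the TCM
+        n_rows = len(tcm.cumulative_length)
+        return types.Array(np.full(n_rows, value))
+
+The three mandatory arguments are described below.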
-__all__ = [ - "get_energy", - "get_majority", - "get_energy_dplms", - "get_majority_dplms", - "get_etc", - "get_time_shift", -] +- `datainfo`: a :obj:`.DataInfo` object that specifies tier names, file names, + HDF5 groups in which data is found and pattern used by hit table names to + encode the channel identifier (e.g. ``ch{}``). +- `tcm`: :obj:`.TCMData` object that holds the TCM data, to be used for event + reconstruction. +- `table_names`: a list of hit table names to read the data from. + +The remaining arguments are characteristic to the processor and can be supplied +in the function call from the :func:`.build_evt` configuration. + +The function must return an :class:`~lgdo.types.lgdo.LGDO` object suitable for +insertion in the final table with event data. + +For examples, have a look at the existing processors provided by this subpackage. +""" diff --git a/src/pygama/evt/modules/geds.py b/src/pygama/evt/modules/geds.py new file mode 100644 index 000000000..a655eb56f --- /dev/null +++ b/src/pygama/evt/modules/geds.py @@ -0,0 +1,46 @@ +"""Event processors for HPGe data.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from lgdo import lh5, types + +from .. import utils + + +def apply_xtalk_correction( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + table_names: Sequence[str], + *, + energy_observable: types.VectorOfVectors, + rawids: types.VectorOfVectors, + xtalk_matrix_filename: str, +) -> types.VectorOfVectors: + """Applies the cross-talk correction to the energy observable. + + The format of `xtalk_matrix_filename` should be... + + Parameters + ---------- + datainfo, tcm, table_names + positional arguments automatically supplied by :func:`.build_evt`. + energy_observable + array of energy values to correct, one event per row. The detector + identifier is stored in `rawids`, which has the same layout. + rawids + array of detector identifiers for each energy in `energy_observable`. + xtalk_matrix_filename + name of the file containing the cross-talk matrices. + """ + # read in xtalk matrices + lh5.read_as("", xtalk_matrix_filename, "ak") + + # do the correction + energies_corr = ... + + # return the result as LGDO + return types.VectorOfVectors( + energies_corr, attrs=utils.copy_lgdo_attrs(energy_observable) + ) diff --git a/src/pygama/evt/modules/larveto.py b/src/pygama/evt/modules/larveto.py new file mode 100644 index 000000000..429076a84 --- /dev/null +++ b/src/pygama/evt/modules/larveto.py @@ -0,0 +1,160 @@ +"""Routines to evaluate the correlation between HPGe and SiPM signals.""" + +from __future__ import annotations + +import awkward as ak +import numpy as np +import scipy +from numpy.typing import ArrayLike + + +def l200_combined_test_stat( + t0: ak.Array, + amp: ak.Array, + geds_t0: ak.Array, +) -> ak.Array: + """Combined L200 LAr veto classifier. + + Where combined means taking channel-specific parameters into account. + + `t0` and `amp` must be in the format of a 3-dimensional Awkward array, + where the innermost dimension labels the SiPM pulse, the second one labels + the SiPM channel and the outermost one labels the event. + + Parameters + ---------- + t0 + arrival times of pulses in ns, split by channel. + amp + amplitude of pulses in p.e., split by channel. + geds_t0 + t0 (ns) of the HPGe signal. + """ + # flatten the data in the last axis (i.e. 
merge all channels together)
+    # TODO: implement channel distinction
+    t0 = ak.flatten(t0, axis=-1)
+    amp = ak.flatten(amp, axis=-1)
+
+    # subtract the HPGe t0 from the SiPM pulse t0s
+    # HACK: remove 16 when units will be fixed
+    rel_t0 = 16 * t0 - geds_t0
+
+    return l200_test_stat(rel_t0, amp)
+
+
+def l200_test_stat(relative_t0, amp):
+    """Compute the test statistics.
+
+    Parameters
+    ----------
+    relative_t0
+        t0 (ns) of the SiPM pulses relative to the HPGe t0.
+    amp
+        amplitude in p.e. of the SiPM pulses.
+    """
+    return -ak.sum(ak.transform(_ak_l200_test_stat_terms, relative_t0, amp), axis=-1)
+
+
+# need to define this function and use it with ak.transform() because scipy
+# routines are not NumPy universal functions
+def _ak_l200_test_stat_terms(layouts, **kwargs):
+    """Awkward transform to compute the per-pulse terms of the test statistics.
+
+    The two arguments are the pulse times `t0` and their amplitude `amp`. The
+    function has to be invoked as ``ak.transform(_ak_l200_test_stat_terms, t0,
+    amp, ...)``.
+    """
+    # sanity check
+    assert len(layouts) == 2
+
+    if not all([lay.is_numpy for lay in layouts]):
+        return
+
+    # these are the two supported arguments
+    t0 = layouts[0].data
+    amp = layouts[1].data
+
+    # sanity check
+    assert len(t0) == len(amp)
+
+    # if there are no pulses return NaN
+    if len(t0) == 0 or any(np.isnan(t0)):
+        return ak.contents.NumpyArray([np.nan])
+
+    # convert to integer number of pes
+    n_pes = pulse_amp_round(amp)
+    n_pe_tot = np.sum(n_pes)
+
+    t_stat = n_pes * np.log(l200_tc_time_pdf(t0)) / n_pe_tot + np.log(
+        l200_rc_amp_pdf(n_pe_tot)
+    )
+
+    return ak.contents.NumpyArray(t_stat)
+
+
+def pulse_amp_round(amp: float | ArrayLike):
+    """Get the most likely (integer) number of photo-electrons."""
+    # promote all amps < 1 to 1, standard rounding to nearest for amps > 1
+    return ak.where(amp < 1, np.ceil(amp), np.floor(amp + 0.5))
+
+
+def l200_tc_time_pdf(
+    t0: float | ArrayLike,
+    *,
+    domain_ns: tuple[float] = (-1_000, 5_000),
+    tau_singlet_ns: float = 6,
+    tau_triplet_ns: float = 1100,
+    sing2trip_ratio: float = 1 / 3,
+    t0_res_ns: float = 35,
+    t0_bias_ns: float = -80,
+    bkg_prob: float = 0.42,
+) -> float | ArrayLike:
+    """The L200 experimental LAr scintillation pdf.
+
+    The theoretical scintillation pdf convolved with a Normal distribution
+    (experimental effects) and summed with a uniform distribution
+    (uncorrelated pulses).
+
+    Parameters
+    ----------
+    t0
+        arrival times of the SiPM pulses in ns.
+    domain_ns
+        validity domain (ns) of the pdf; `t0` values must lie within it.
+    tau_singlet_ns
+        the lifetime of the LAr singlet state in ns.
+    tau_triplet_ns
+        the lifetime of the LAr triplet state in ns.
+    sing2trip_ratio
+        the singlet-to-triplet excitation probability ratio.
+    t0_res_ns
+        sigma (ns) of the normal distribution.
+    t0_bias_ns
+        mean (ns) of the normal distribution.
+    bkg_prob
+        probability for a pulse coming from some uncorrelated physics (uniform
+        distribution).
+    """
+    if not np.all((t0 >= domain_ns[0]) & (t0 <= domain_ns[1])):
+        msg = f"{t0=} out of bounds for {domain_ns=}"
+        raise ValueError(msg)
+
+    # TODO: make this a true pdf, i.e.
+
+
+def l200_tc_time_pdf(
+    t0: float | ArrayLike,
+    *,
+    domain_ns: tuple[float, float] = (-1_000, 5_000),
+    tau_singlet_ns: float = 6,
+    tau_triplet_ns: float = 1100,
+    sing2trip_ratio: float = 1 / 3,
+    t0_res_ns: float = 35,
+    t0_bias_ns: float = -80,
+    bkg_prob: float = 0.42,
+) -> float | ArrayLike:
+    """The L200 experimental LAr scintillation pdf.
+
+    The theoretical scintillation pdf convolved with a normal distribution
+    (experimental effects) and summed with a uniform distribution
+    (uncorrelated pulses).
+
+    Parameters
+    ----------
+    t0
+        arrival times of the SiPM pulses in ns.
+    domain_ns
+        domain of the pdf (ns); `t0` values outside this interval raise a
+        :class:`ValueError`.
+    tau_singlet_ns
+        lifetime of the LAr singlet state in ns.
+    tau_triplet_ns
+        lifetime of the LAr triplet state in ns.
+    sing2trip_ratio
+        singlet-to-triplet excitation probability ratio.
+    t0_res_ns
+        sigma (ns) of the normal distribution.
+    t0_bias_ns
+        mean (ns) of the normal distribution.
+    bkg_prob
+        probability for a pulse coming from some uncorrelated physics (uniform
+        distribution).
+    """
+    if not np.all((t0 >= domain_ns[0]) & (t0 <= domain_ns[1])):
+        msg = f"{t0=} out of bounds for {domain_ns=}"
+        raise ValueError(msg)
+
+    # TODO: make this a true pdf, i.e. normalize to integral 1
+    return (
+        # the triplet
+        (1 - sing2trip_ratio)
+        * scipy.stats.exponnorm.pdf(
+            t0, tau_triplet_ns / t0_res_ns, loc=t0_bias_ns, scale=t0_res_ns
+        )
+        # the singlet
+        + sing2trip_ratio
+        * scipy.stats.exponnorm.pdf(
+            t0, tau_singlet_ns / t0_res_ns, loc=t0_bias_ns, scale=t0_res_ns
+        )
+        # the random coincidences (uniform pdf)
+        + bkg_prob
+        * scipy.stats.uniform.pdf(t0, domain_ns[0], domain_ns[1] - domain_ns[0])
+    )
+
+
+# amplitude pdf of uncorrelated (random-coincidence) pulses
+def l200_rc_amp_pdf(n):
+    return np.exp(-n)
diff --git a/src/pygama/evt/modules/spm.py b/src/pygama/evt/modules/spm.py
deleted file mode 100644
index 843c7935e..000000000
--- a/src/pygama/evt/modules/spm.py
+++ /dev/null
@@ -1,526 +0,0 @@
-"""
-Module for special event level routines for SiPMs
-
-functions must take as the first 8 args in order:
-- path to the hit file
-- path to the dsp ak.Array:
-    if isinstance(trgr, Array):
-        return ak.fill_none(ak.nan_to_none(trgr.view_as("ak")), tdefault)
-
-    elif isinstance(trgr, (VectorOfVectors)):
-        return ak.fill_none(
-            ak.min(ak.fill_none(trgr.view_as("ak"), tdefault), axis=-1), tdefault
-        )
-
-    elif isinstance(trgr, (ak.Array, ak.highlevel.Array)):
-        if trgr.ndim == 1:
-            return ak.fill_none(ak.nan_to_none(trgr), tdefault)
-        elif trgr.ndim == 2:
-            return ak.fill_none(
-                ak.min(ak.fill_none(ak.nan_to_none(trgr), tdefault), axis=-1), tdefault
-            )
-        else:
-            raise ValueError(f"Too many dimensions: {trgr.ndim}")
-    elif isinstance(trgr, (float, int)) and isinstance(length, int):
-        return ak.Array([trgr] * length)
-    else:
-        raise ValueError(f"Can't deal with t0 of type {type(trgr)}")
-
-
-# get SiPM coincidence window mask
-def get_spm_mask(
-    lim: float, trgr: ak.Array, tmin: float, tmax: float, pe: ak.Array, times: ak.Array
-) -> ak.Array:
-    if trgr.ndim != 1:
-        raise ValueError("trigger array muse be 1 dimensional!")
-    if (len(trgr) != len(pe)) or (len(trgr) != len(times)):
-        raise ValueError(
-            f"All arrays must have same dimension across first axis len(pe)={len(pe)}, len(times)={len(times)}, len(trgr)={len(trgr)}"
-        )
-
-    tmi = trgr - tmin
-    tma = trgr + tmax
-
-    mask = (
-        ((times * 16.0) < tma[:, None]) & ((times * 16.0) > tmi[:, None]) & (pe > lim)
-    )
-    return mask
-
-
-# get LAr indices according to mask per event over all channels
-# mode 0 -> return pulse indices
-# mode 1 -> return tcm indices
-# mode 2 -> return rawids
-# mode 3 -> return tcm_idx
-def get_masked_tcm_idx(
-    f_hit,
-    f_dsp,
-    f_tcm,
-    hit_group,
-    dsp_group,
-    tcm_group,
-    tcm_id_table_pattern,
-    chs,
-    lim,
-    trgr,
-    tdefault,
-    tmin,
-    tmax,
-    mode=0,
-) -> VectorOfVectors:
-    # load TCM data to define an event
-    store = LH5Store()
-    ids = store.read(f"/{tcm_group}/array_id", f_tcm)[0].view_as("np")
-    idx = store.read(f"/{tcm_group}/array_idx", f_tcm)[0].view_as("np")
-
-    arr_lst = []
-
-    if isinstance(trgr, (float, int)):
-        tge = cast_trigger(trgr, tdefault, length=np.max(idx) + 1)
-    else:
-        tge = cast_trigger(trgr, tdefault, length=None)
-
-    for ch in chs:
-        idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)]
-
-        pe = store.read(f"{ch}/{hit_group}/energy_in_pe", f_hit, idx=idx_ch)[0].view_as(
-            "np"
-        )
-        tmp = np.full((np.max(idx) + 1, len(pe[0])), np.nan)
-        tmp[idx_ch] = pe
-        pe = ak.drop_none(ak.nan_to_none(ak.Array(tmp)))
-
-        # times are in sample units
-        times = store.read(f"{ch}/{hit_group}/trigger_pos", f_hit, idx=idx_ch)[
-            0
-        ].view_as("np")
-        tmp = np.full((np.max(idx) + 1, len(times[0])), np.nan)
-        tmp[idx_ch] = times
-        times = ak.drop_none(ak.nan_to_none(ak.Array(tmp)))
-
-        mask =
get_spm_mask(lim, tge, tmin, tmax, pe, times) - - if mode == 0: - out_idx = ak.local_index(mask)[mask] - - elif mode == 1: - out_idx = np.full((np.max(idx) + 1), np.nan) - out_idx[idx_ch] = np.where( - ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch) - )[0] - out_idx = ak.drop_none(ak.nan_to_none(ak.Array(out_idx)[:, None])) - out_idx = out_idx[mask[mask] - 1] - - elif mode == 2: - out_idx = ak.Array( - [utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] * len(mask) - ) - out_idx = out_idx[:, None][mask[mask] - 1] - - elif mode == 3: - out_idx = np.full((np.max(idx) + 1), np.nan) - out_idx[idx_ch] = idx_ch - out_idx = ak.drop_none(ak.nan_to_none(ak.Array(out_idx)[:, None])) - out_idx = out_idx[mask[mask] - 1] - - else: - raise ValueError("Unknown mode") - - arr_lst.append(out_idx) - - return VectorOfVectors(array=ak.concatenate(arr_lst, axis=-1)) - - -def get_spm_ene_or_maj( - f_hit, - f_tcm, - hit_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, - mode, -): - if mode not in ["energy_hc", "energy_dplms", "majority_hc", "majority_dplms"]: - raise ValueError("Unknown mode") - - # load TCM data to define an event - store = LH5Store() - ids = store.read(f"/{tcm_group}/array_id", f_tcm)[0].view_as("np") - idx = store.read(f"/{tcm_group}/array_idx", f_tcm)[0].view_as("np") - out = np.zeros(np.max(idx) + 1) - - if isinstance(trgr, (float, int)): - tge = cast_trigger(trgr, tdefault, length=np.max(idx) + 1) - else: - tge = cast_trigger(trgr, tdefault, length=None) - - for ch in chs: - idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] - - if mode in ["energy_dplms", "majority_dplms"]: - pe = ak.drop_none( - ak.nan_to_none( - store.read( - f"{ch}/{hit_group}/energy_in_pe_dplms", f_hit, idx=idx_ch - )[0].view_as("ak") - ) - ) - - # times are in sample units - times = ak.drop_none( - ak.nan_to_none( - store.read( - f"{ch}/{hit_group}/trigger_pos_dplms", f_hit, idx=idx_ch - )[0].view_as("ak") - ) - ) - - else: - pe = ak.drop_none( - ak.nan_to_none( - store.read(f"{ch}/{hit_group}/energy_in_pe", f_hit, idx=idx_ch)[ - 0 - ].view_as("ak") - ) - ) - - # times are in sample units - times = ak.drop_none( - ak.nan_to_none( - store.read(f"{ch}/{hit_group}/trigger_pos", f_hit, idx=idx_ch)[ - 0 - ].view_as("ak") - ) - ) - - mask = get_spm_mask(lim, tge[idx_ch], tmin, tmax, pe, times) - pe = pe[mask] - - if mode in ["energy_hc", "energy_dplms"]: - out[idx_ch] = out[idx_ch] + ak.to_numpy(ak.nansum(pe, axis=-1)) - - else: - out[idx_ch] = out[idx_ch] + ak.to_numpy( - ak.where(ak.nansum(pe, axis=-1) > lim, 1, 0) - ) - - return Array(nda=out) - - -# get LAr energy per event over all channels -def get_energy( - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, -) -> Array: - return get_spm_ene_or_maj( - f_hit, - f_tcm, - hit_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, - "energy_hc", - ) - - -# get LAr majority per event over all channels -def get_majority( - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, -) -> Array: - return get_spm_ene_or_maj( - f_hit, - f_tcm, - hit_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, - "majority_hc", - ) - - -# get LAr energy per event over all channels -def get_energy_dplms( - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - 
tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, -) -> Array: - return get_spm_ene_or_maj( - f_hit, - f_tcm, - hit_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, - "energy_dplms", - ) - - -# get LAr majority per event over all channels -def get_majority_dplms( - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, -) -> Array: - return get_spm_ene_or_maj( - f_hit, - f_tcm, - hit_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, - "majority_dplms", - ) - - -# Calculate the ETC in different trailing modes: -# trail = 0: Singlet window = [tge,tge+swin] -# trail = 1: Singlet window = [t_first_lar_pulse, t_first_lar_pulse+ swin] -# trail = 2: Like trail = 1, but t_first_lar_pulse <= tge is ensured -# min_first_pls_ene sets the minimum energy of the first pulse (only used in trail > 0) -# max_per_channel, maximum number of pes a channel is allowed to have, if above it gets excluded -def get_etc( - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, - swin, - trail, - min_first_pls_ene, - max_per_channel, -) -> Array: - # load TCM data to define an event - store = LH5Store() - ids = store.read(f"/{tcm_group}/array_id", f_tcm)[0].view_as("np") - idx = store.read(f"/{tcm_group}/array_idx", f_tcm)[0].view_as("np") - pe_lst = [] - time_lst = [] - - if isinstance(trgr, (float, int)): - tge = cast_trigger(trgr, tdefault, length=np.max(idx) + 1) - else: - tge = cast_trigger(trgr, tdefault, length=None) - - for ch in chs: - idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] - - pe = store.read(f"{ch}/{hit_group}/energy_in_pe", f_hit, idx=idx_ch)[0].view_as( - "np" - ) - tmp = np.full((np.max(idx) + 1, len(pe[0])), np.nan) - tmp[idx_ch] = pe - pe = ak.drop_none(ak.nan_to_none(ak.Array(tmp))) - - # times are in sample units - times = store.read(f"{ch}/{hit_group}/trigger_pos", f_hit, idx=idx_ch)[ - 0 - ].view_as("np") - tmp = np.full((np.max(idx) + 1, len(times[0])), np.nan) - tmp[idx_ch] = times - times = ak.drop_none(ak.nan_to_none(ak.Array(tmp))) - - mask = get_spm_mask(lim, tge, tmin, tmax, pe, times) - - pe = pe[mask] - - # max pe mask - max_pe_mask = ak.nansum(pe, axis=-1) < max_per_channel - pe = ak.drop_none( - ak.nan_to_none(ak.where(max_pe_mask, pe, ak.Array([[np.nan]]))) - ) - pe_lst.append(pe) - - times = times[mask] * 16 - times = ak.drop_none( - ak.nan_to_none(ak.where(max_pe_mask, times, ak.Array([[np.nan]]))) - ) - time_lst.append(times) - - pe_all = ak.concatenate(pe_lst, axis=-1) - time_all = ak.concatenate(time_lst, axis=-1) - - if trail > 0: - t1d = ak.min(time_all[pe_all > min_first_pls_ene], axis=-1) - - if trail == 2: - t1d = ak.where(t1d > tge, tge, t1d) - - mask_total = time_all > t1d - mask_singlet = (time_all > t1d) & (time_all < t1d + swin) - - else: - mask_total = time_all > tge - mask_singlet = (time_all > tge) & (time_all < tge + swin) - - pe_singlet = ak.to_numpy( - ak.fill_none(ak.nansum(pe_all[mask_singlet], axis=-1), 0), allow_missing=False - ) - pe_total = ak.to_numpy( - ak.fill_none(ak.nansum(pe_all[mask_total], axis=-1), 0), allow_missing=False - ) - etc = np.divide( - pe_singlet, pe_total, out=np.full_like(pe_total, np.nan), where=pe_total != 0 - ) - - return Array(nda=etc) - - -# returns relative time shift of the first LAr pulse relative to the Ge 
trigger -def get_time_shift( - f_hit, - f_dsp, - f_tcm, - hit_group, - dsp_group, - tcm_group, - tcm_id_table_pattern, - chs, - lim, - trgr, - tdefault, - tmin, - tmax, -) -> Array: - store = LH5Store() - # load TCM data to define an event - ids = store.read(f"/{tcm_group}/array_id", f_tcm)[0].view_as("np") - idx = store.read(f"/{tcm_group}/array_idx", f_tcm)[0].view_as("np") - time_all = ak.Array([[] for x in range(np.max(idx) + 1)]) - - if isinstance(trgr, (float, int)): - tge = cast_trigger(trgr, tdefault, length=np.max(idx) + 1) - else: - tge = cast_trigger(trgr, tdefault, length=None) - - for ch in chs: - idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] - - pe = store.read(f"{ch}/{hit_group}/energy_in_pe", f_hit, idx=idx_ch)[0].view_as( - "np" - ) - tmp = np.full((np.max(idx) + 1, len(pe[0])), np.nan) - tmp[idx_ch] = pe - pe = ak.drop_none(ak.nan_to_none(ak.Array(tmp))) - - # times are in sample units - times = store.read(f"{ch}/{hit_group}/trigger_pos", f_hit, idx=idx_ch)[ - 0 - ].view_as("np") - tmp = np.full((np.max(idx) + 1, len(times[0])), np.nan) - tmp[idx_ch] = times - times = ak.drop_none(ak.nan_to_none(ak.Array(tmp))) - - mask = get_spm_mask(lim, tge, tmin, tmax, pe, times) - - # apply mask and convert sample units to ns - times = times[mask] * 16 - - time_all = ak.concatenate((time_all, times), axis=-1) - - out = ak.min(time_all, axis=-1) - - # Convert to 1D numpy array - out = ak.to_numpy(ak.fill_none(out, np.inf), allow_missing=False) - tge = ak.to_numpy(tge, allow_missing=False) - - return Array(out - tge) diff --git a/src/pygama/evt/modules/spms.py b/src/pygama/evt/modules/spms.py new file mode 100644 index 000000000..2f40c472a --- /dev/null +++ b/src/pygama/evt/modules/spms.py @@ -0,0 +1,381 @@ +"""Event processors for SiPM data.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import awkward as ak +import numpy as np +from lgdo import lh5, types + +from .. import utils +from . import larveto + + +def gather_pulse_data( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + table_names: Sequence[str], + *, + observable: str, + pulse_mask: types.VectorOfVectors = None, + a_thr_pe: float = None, + t_loc_ns: float = None, + dt_range_ns: Sequence[float] = None, + t_loc_default_ns: float = None, + drop_empty: bool = True, +) -> types.VectorOfVectors: + """Gathers SiPM pulse data into a 3D :class:`~lgdo.types.vectorofvectors.VectorOfVectors`. + + The returned data structure specifies the event in the first axis, the SiPM + channel in the second and the pulse index in the last. + + Pulse data can be optionally masked with `pulse_mask` or a mask can be + built on the fly from the `a_thr_pe`, `t_loc_ns`, `dt_range_ns`, + `t_loc_default_ns` arguments (see :func:`make_pulse_data_mask`). + + If `pulse_mask`, `a_thr_pe`, `t_loc_ns`, `dt_range_ns`, `t_loc_default_ns` + are all ``None``, no masking is applied and the full data set is returned. + + Parameters + ---------- + datainfo, tcm, table_names + positional arguments automatically supplied by :func:`.build_evt`. + observable + name of the pulse parameter to be gathered, optionally prefixed by tier + name (e.g. ``hit.energy_in_pe``). If no tier is specified, it defaults + to ``hit``. + pulse_mask + 3D mask object used to filter out pulse data. See + :func:`make_pulse_data_mask`. + a_thr_pe + amplitude threshold (in photoelectrons) used to build a pulse mask with + :func:`make_pulse_data_mask`, if `pulse_mask` is ``None``. 
The output + pulse data will be such that the pulse amplitude is above this value. + t_loc_ns + location of the time window in which pulses must sit. If a 1D array is + provided, it is interpreted as a list of locations for each event (can + be employed to e.g. provide the actual HPGe pulse position) + dt_range_ns + tuple with dimension of the time window in which pulses must sit + relative to `t_loc_ns`. If, for example, `t_loc_ns` is 48000 ns and + `dt_range_ns` is (-1000, 5000) ns, the resulting window will be (47000, + 53000) ns. + t_loc_default_ns + default value for `t_loc_ns`, in case the supplied value is + :any:`numpy.nan`. + drop_empty + if ``True``, drop empty arrays at the last axis (the pulse axis), i.e. + drop channels with no pulse data. The filtering is applied after the + application of the mask. + """ + # parse observables string. default to hit tier + p = observable.split(".") + tier = p[0] if len(p) > 1 else "hit" + column = p[1] if len(p) > 1 else p[0] + + tierinfo = datainfo._asdict()[tier] + + # loop over selected table_names and load hit data + concatme = [] + for channel in table_names: + table_id = utils.get_tcm_id_by_pattern(tierinfo.table_fmt, channel) + + # determine list of indices found in the TCM that we want to load for channel + idx = tcm.idx[tcm.id == table_id] + + # read the data in + lgdo_obj = lh5.read( + f"/{channel}/{tierinfo.group}/{column}", tierinfo.file, idx=idx + ) + data = lgdo_obj.view_as(library="ak") + + # remove nans (this happens when SiPM data is stored as ArrayOfEqualSizedArrays) + data = ak.drop_none(ak.nan_to_none(data)) + + # increase the dimensionality by one (events) + data = ak.unflatten(data, np.full(data.layout.length, 1, dtype="uint8")) + + concatme.append(data) + + # concatenate along the event axes (i.e. gather table_names together) + data = ak.concatenate(concatme, axis=1) + + # check if user wants to apply a mask + if pulse_mask is None and any( + [kwarg is not None for kwarg in (a_thr_pe, t_loc_ns, dt_range_ns)] + ): + # generate the time/amplitude mask from parameters + pulse_mask = make_pulse_data_mask( + datainfo, + tcm, + table_names, + a_thr_pe=a_thr_pe, + t_loc_ns=t_loc_ns, + dt_range_ns=dt_range_ns, + t_loc_default_ns=t_loc_default_ns, + ) + + if pulse_mask is not None: + if not isinstance(pulse_mask, ak.Array): + pulse_mask = pulse_mask.view_as("ak") + + # apply the mask + data = data[pulse_mask] + + # remove empty arrays = table_names with no pulses + if drop_empty: + data = data[ak.count(data, axis=-1) > 0] + + return types.VectorOfVectors(data, attrs=utils.copy_lgdo_attrs(lgdo_obj)) + + +def gather_tcm_data( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + table_names: Sequence[str], + *, + tcm_field="id", + pulse_mask=None, + a_thr_pe=None, + t_loc_ns=None, + dt_range_ns=None, + t_loc_default_ns=None, + drop_empty=True, +) -> types.VectorOfVectors: + """Gather TCM data into a 2D :class:`~lgdo.types.vectorofvectors.VectorOfVectors`. + + The returned data structure specifies the event on the first axis and the + TCM data (`id` or `idx`) on the second. Can be used to filter out data from + :func:`gather_pulse_data` based on SiPM channel provenance (`id`) or to + load hit data from lower tiers (with `idx`). + + If `drop_empty` is ``True``, channel ids with no pulse data associated are + removed. + + See :func:`gather_pulse_data` for documentation about the other function + arguments. + """ + # unflatten the tcm data with cumulative_length, i.e. 
make a VoV + tcm_vov = {} + for field in ("id", "idx"): + tcm_vov[field] = types.VectorOfVectors( + flattened_data=tcm._asdict()[field], cumulative_length=tcm.cumulative_length + ).view_as("ak") + + # list user wanted table names + table_ids = [ + utils.get_tcm_id_by_pattern(datainfo.hit.table_fmt, id) for id in table_names + ] + # find them in tcm.id (we'll filter the rest out) + locs = np.isin(tcm_vov["id"], table_ids) + + # select tcm field requested by the user + data = tcm_vov[tcm_field] + + # apply mask + # NOTE: need to cast to irregular axes, otherwise the masking result is + # non-nested + data = data[ak.from_regular(locs)] + + # check if user wants to apply a custom mask + if drop_empty: + if pulse_mask is None: + # generate the time/amplitude mask from parameters + # if all parameters are None, a dummy mask (the identity) will be made + pulse_mask = make_pulse_data_mask( + datainfo, + tcm, + table_names, + a_thr_pe=a_thr_pe, + t_loc_ns=t_loc_ns, + dt_range_ns=dt_range_ns, + t_loc_default_ns=t_loc_default_ns, + ) + + if not isinstance(pulse_mask, ak.Array): + pulse_mask = pulse_mask.view_as("ak") + + if pulse_mask.ndim != 3: + msg = "pulse_mask must be 3D" + raise ValueError(msg) + + # convert the 3D mask to a 2D mask (can be used to filter table_ids) + ch_mask = ak.sum(pulse_mask, axis=-1) > 0 + + # apply the mask + data = data[ch_mask] + + return types.VectorOfVectors(data) + + +# NOTE: the mask never gets the empty arrays removed +def make_pulse_data_mask( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + table_names: Sequence[str], + *, + a_thr_pe=None, + t_loc_ns=None, + dt_range_ns=None, + t_loc_default_ns=None, +) -> types.VectorOfVectors: + """Calculate a 3D :class:`~lgdo.types.vectorofvectors.VectorOfVectors` pulse data mask. + + Useful to filter any pulse data based on pulse amplitude and start time. + + Parameters + ---------- + datainfo, tcm, table_names + positional arguments automatically supplied by :func:`.build_evt`. + a_thr_pe + amplitude threshold (in photoelectrons) used to build a pulse mask with + :func:`make_pulse_data_mask`, if `pulse_mask` is ``None``. The output + pulse data will be such that the pulse amplitude is above this value. + t_loc_ns + location of the time window in which pulses must sit. If a 1D array is + provided, it is interpreted as a list of locations for each event (can + be employed to e.g. provide the actual HPGe pulse position) + dt_range_ns + tuple with dimension of the time window in which pulses must sit + relative to `t_loc_ns`. If, for example, `t_loc_ns` is 48000 ns and + `dt_range_ns` is (-1000, 5000) ns, the resulting window will be (47000, + 53000) ns. + t_loc_default_ns + default value for `t_loc_ns`, in case the supplied value is + :any:`numpy.nan`. + """ + # get the t0 of each single pulse + pulse_t0 = gather_pulse_data( + datainfo, + tcm, + table_names, + observable="hit.trigger_pos", + drop_empty=False, + ) + + # HACK: handle units + # HACK: remove me once units are fixed in the dsp tier + if "units" in pulse_t0.attrs and pulse_t0.attrs["units"] == "ns": + pulse_t0_ns = pulse_t0.view_as("ak") + else: + pulse_t0_ns = pulse_t0.view_as("ak") * 16 + + pulse_amp = gather_pulse_data( + datainfo, + tcm, + table_names, + observable="hit.energy_in_pe", + drop_empty=False, + ).view_as("ak") + + # (HPGe) trigger position can vary among events! 
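+    # NOTE: a per-event location may be NaN (e.g. no HPGe trigger in that
+    # event); such entries are replaced by t_loc_default_ns below, so the
+    # window falls back to a fixed position for those events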
+ if isinstance(t_loc_ns, types.Array): + t_loc_ns = t_loc_ns.view_as("ak") + + if isinstance(t_loc_ns, ak.Array): + if t_loc_ns.ndim != 1: + msg = "t_loc_ns must be 0- or 1-dimensional" + raise ValueError(msg) + + # NOTE: the assumption is that t0 is np.nan when missing -> replace + # with default value + t_loc_ns = ak.fill_none(ak.nan_to_none(t_loc_ns), t_loc_default_ns) + + # start with all-true mask + mask = pulse_t0_ns == pulse_t0_ns + + # apply p.e. threshold + if a_thr_pe is not None: + mask = mask & (pulse_amp > a_thr_pe) + + # apply time windowing + if t_loc_ns is not None and dt_range_ns is not None: + if not isinstance(dt_range_ns, (tuple, list)): + msg = "dt_range_ns must be a tuple" + raise ValueError(msg) + + mask = mask & ( + (pulse_t0_ns < (t_loc_ns + dt_range_ns[1])) + & (pulse_t0_ns > (t_loc_ns + dt_range_ns[0])) + ) + + return types.VectorOfVectors(mask) + + +def geds_coincidence_classifier( + datainfo: utils.DataInfo, + tcm: utils.TCMData, + table_names: Sequence[str], + *, + geds_t0_ns: types.Array, +) -> types.Array: + """Calculate the HPGe / SiPMs coincidence classifier. + + The value represents the likelihood of a physical correlation between HPGe + and SiPM signals. + + Parameters + ---------- + datainfo, tcm, table_names + positional arguments automatically supplied by :func:`.build_evt`. + """ + # mask for windowing data around the HPGe t0 + pulse_mask = make_pulse_data_mask( + datainfo, + tcm, + table_names, + a_thr_pe=None, + t_loc_ns=geds_t0_ns, + dt_range_ns=(-1_000, 5_000), + t_loc_default_ns=48_000, + ) + + # load the data + data = {} + for k, obs in {"amp": "hit.energy_in_pe", "t0": "hit.trigger_pos"}.items(): + data[k] = gather_pulse_data( + datainfo, + tcm, + table_names, + observable=obs, + pulse_mask=pulse_mask, + drop_empty=True, + ).view_as("ak") + + # load the channel info + # rawids = spms.gather_tcm_id_data( + # datainfo, + # tcm, + # table_names, + # pulse_mask=pulse_mask, + # drop_empty=True, + # ) + + # (HPGe) trigger position can vary among events! 
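+    # e.g. in the spms-module-config.yaml test config below, geds_t0_ns is
+    # evt.t0 (built from dsp.tp_0_est), which arrives here as an LGDO Array
+    # and is viewed as awkward before entering the test statistic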
+ if isinstance(geds_t0_ns, types.Array): + geds_t0_ns = geds_t0_ns.view_as("ak") + + ts_data = larveto.l200_combined_test_stat(data["t0"], data["amp"], geds_t0_ns) + + return types.Array(ts_data) + + +# REMOVE ME: not needed anymore with VectorOfVectors DSP outputs +def gather_is_valid_hit(datainfo, tcm, table_names): + data = {} + for field in ("is_valid_hit", "trigger_pos"): + data[field] = gather_pulse_data( + datainfo, + tcm, + table_names, + observable=f"hit.{field}", + pulse_mask=None, + drop_empty=False, + ).view_as("ak") + + return types.VectorOfVectors( + data["is_valid_hit"][ + ak.local_index(data["is_valid_hit"]) < ak.num(data["trigger_pos"], axis=-1) + ] + ) diff --git a/src/pygama/evt/utils.py b/src/pygama/evt/utils.py index 175cd868a..30d14639a 100644 --- a/src/pygama/evt/utils.py +++ b/src/pygama/evt/utils.py @@ -4,192 +4,189 @@ from __future__ import annotations +import copy import re +from collections import namedtuple import awkward as ak import numpy as np -from lgdo.lh5 import LH5Store +from lgdo import lh5 from numpy.typing import NDArray +H5DataLoc = namedtuple( + "H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,) +) -def get_tcm_id_by_pattern(tcm_id_table_pattern: str, ch: str) -> int: - pre = tcm_id_table_pattern.split("{")[0] - post = tcm_id_table_pattern.split("}")[1] +DataInfo = namedtuple( + "DataInfo", ("raw", "tcm", "dsp", "hit", "evt"), defaults=5 * (None,) +) + +TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length")) + + +def make_files_config(data: dict): + if not isinstance(data, DataInfo): + return DataInfo( + *[ + H5DataLoc(*data[tier]) if tier in data else H5DataLoc() + for tier in DataInfo._fields + ] + ) + + return data + + +def make_numpy_full(size, fill_value, try_dtype): + if np.can_cast(fill_value, try_dtype): + return np.full(size, fill_value, dtype=try_dtype) + else: + return np.full(size, fill_value) + + +def copy_lgdo_attrs(obj): + attrs = copy.copy(obj.attrs) + attrs.pop("datatype") + return attrs + + +def get_tcm_id_by_pattern(table_id_fmt: str, ch: str) -> int: + pre = table_id_fmt.split("{")[0] + post = table_id_fmt.split("}")[1] return int(ch.strip(pre).strip(post)) -def get_table_name_by_pattern(tcm_id_table_pattern: str, ch_id: int) -> str: - # check tcm_id_table_pattern validity - pattern_check = re.findall(r"{([^}]*?)}", tcm_id_table_pattern)[0] +def get_table_name_by_pattern(table_id_fmt: str, ch_id: int) -> str: + # check table_id_fmt validity + pattern_check = re.findall(r"{([^}]*?)}", table_id_fmt)[0] if pattern_check == "" or ":" == pattern_check[0]: - return tcm_id_table_pattern.format(ch_id) + return table_id_fmt.format(ch_id) else: raise NotImplementedError( - "Only empty placeholders with format specifications are currently implemented" + "only empty placeholders {} in format specifications are currently supported" ) -def num_and_pars(value: str, par_dic: dict): - # function tries to convert a string to a int, float, bool - # or returns the value if value is a key in par_dic - if value in par_dic.keys(): - return par_dic[value] - try: - value = int(value) - except ValueError: - try: - value = float(value) - except ValueError: - try: - value = bool(value) - except ValueError: - pass - return value - - def find_parameters( - f_hit: str, - f_dsp: str, - ch: str, - idx_ch: NDArray, - exprl: list, - hit_group: str = "hit", - dsp_group: str = "dsp", + datainfo, + ch, + idx_ch, + field_list, ) -> dict: - """Wraps :func:`load_vars_to_nda` to return parameters from `hit` and `dsp` - tiers. 
+ """Finds and returns parameters from `hit` and `dsp` tiers. Parameters ---------- - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. ch "rawid" in the tiers. idx_ch - index array of entries to be read from files. - exprl + index array of entries to be read from datainfo. + field_list list of tuples ``(tier, field)`` to be found in the `hit/dsp` tiers. - dsp_group - LH5 root group in dsp file. - hit_group - LH5 root group in hit file. """ + f = make_files_config(datainfo) # find fields in either dsp, hit - dsp_flds = [e[1] for e in exprl if e[0] == dsp_group] - hit_flds = [e[1] for e in exprl if e[0] == hit_group] + dsp_flds = [e[1] for e in field_list if e[0] == f.dsp.group] + hit_flds = [e[1] for e in field_list if e[0] == f.hit.group] - store = LH5Store() hit_dict, dsp_dict = {}, {} + if len(hit_flds) > 0: - hit_ak = store.read( - f"{ch.replace('/','')}/{hit_group}/", f_hit, field_mask=hit_flds, idx=idx_ch - )[0].view_as("ak") + hit_ak = lh5.read_as( + f"{ch.replace('/','')}/{f.hit.group}/", + f.hit.file, + field_mask=hit_flds, + idx=idx_ch, + library="ak", + ) + hit_dict = dict( - zip([f"{hit_group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak)) + zip([f"{f.hit.group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak)) ) + if len(dsp_flds) > 0: - dsp_ak = store.read( - f"{ch.replace('/','')}/{dsp_group}/", f_dsp, field_mask=dsp_flds, idx=idx_ch - )[0].view_as("ak") + dsp_ak = lh5.read_as( + f"{ch.replace('/','')}/{f.dsp.group}/", + f.dsp.file, + field_mask=dsp_flds, + idx=idx_ch, + library="ak", + ) + dsp_dict = dict( - zip([f"{dsp_group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak)) + zip([f"{f.dsp.group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak)) ) return hit_dict | dsp_dict def get_data_at_channel( - ch: str, - ids: NDArray, - idx: NDArray, - expr: str, - exprl: list, - var_ph: dict, - is_evaluated: bool, - f_hit: str, - f_dsp: str, - defv, - tcm_id_table_pattern: str = "ch{}", - evt_group: str = "evt", - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> np.ndarray: + datainfo, + ch, + tcm, + expr, + field_list, + pars_dict, +) -> NDArray: """Evaluates an expression and returns the result. Parameters ---------- + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. ch "rawid" of channel to be evaluated. - idx - `tcm` index array. - ids - `tcm` id array. + tcm + TCM data arrays in an object that can be accessed by attribute. expr expression to be evaluated. - exprl + field_list list of parameter-tuples ``(root_group, field)`` found in the expression. - var_ph + pars_dict dict of additional parameters that are not channel dependent. is_evaluated if false, the expression does not get evaluated but an array of default values is returned. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - defv + default_value default value. - tcm_id_table_pattern - Pattern to format tcm id values to table name in higher tiers. Must have one - placeholder which is the tcm id. - dsp_group - LH5 root group in dsp file. - hit_group - LH5 root group in hit file. - evt_group - LH5 root group in evt file. 
""" + f = make_files_config(datainfo) + table_id = get_tcm_id_by_pattern(f.hit.table_fmt, ch) # get index list for this channel to be loaded - idx_ch = idx[ids == get_tcm_id_by_pattern(tcm_id_table_pattern, ch)] + idx_ch = tcm.idx[tcm.id == table_id] outsize = len(idx_ch) - if not is_evaluated: - res = np.full(outsize, defv, dtype=type(defv)) - elif "tcm.array_id" == expr: - res = np.full( - outsize, get_tcm_id_by_pattern(tcm_id_table_pattern, ch), dtype=int - ) - elif "tcm.index" == expr: - res = np.where(ids == get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0] + if expr == "tcm.array_id": + res = np.full(outsize, table_id, dtype=int) + elif expr == "tcm.array_idx": + res = idx_ch + elif expr == "tcm.index": + res = np.where(tcm.id == table_id)[0] else: var = find_parameters( - f_hit=f_hit, - f_dsp=f_dsp, + datainfo=datainfo, ch=ch, idx_ch=idx_ch, - exprl=exprl, - hit_group=hit_group, - dsp_group=dsp_group, + field_list=field_list, ) - if var_ph is not None: - var = var | var_ph + if pars_dict is not None: + var = var | pars_dict # evaluate expression # move tier+dots in expression to underscores (e.g. evt.foo -> evt_foo) res = eval( - expr.replace(f"{dsp_group}.", f"{dsp_group}_") - .replace(f"{hit_group}.", f"{hit_group}_") - .replace(f"{evt_group}.", ""), + expr.replace(f"{f.dsp.group}.", f"{f.dsp.group}_") + .replace(f"{f.hit.group}.", f"{f.hit.group}_") + .replace(f"{f.evt.group}.", ""), var, ) # in case the expression evaluates to a single value blow it up - if (not hasattr(res, "__len__")) or (isinstance(res, str)): + if not hasattr(res, "__len__") or isinstance(res, str): return np.full(outsize, res) # the resulting arrays need to be 1D from the operation, @@ -200,27 +197,28 @@ def get_data_at_channel( # in this method only 1D values are allowed if res.ndim > 1: raise ValueError( - f"expression '{expr}' must return 1D array. If you are using VectorOfVectors or ArrayOfEqualSizedArrays, use awkward reduction functions to reduce the dimension" + f"expression '{expr}' must return 1D array. If you are using " + "VectorOfVectors or ArrayOfEqualSizedArrays, use awkward " + "reduction functions to reduce the dimension" ) return res def get_mask_from_query( - qry: str | NDArray, - length: int, - ch: str, - idx_ch: NDArray, - f_hit: str, - f_dsp: str, - hit_group: str = "hit", - dsp_group: str = "dsp", -) -> np.ndarray: + datainfo, + query, + length, + ch, + idx_ch, +) -> NDArray: """Evaluates a query expression and returns a mask accordingly. Parameters ---------- - qry + datainfo + input and output LH5 datainfo with HDF5 groups where tables are found. + query query expression. length length of the return mask. @@ -228,33 +226,23 @@ def get_mask_from_query( "rawid" of channel to be evaluated. idx_ch channel indices to be read. - f_hit - path to `hit` tier file. - f_dsp - path to `dsp` tier file. - hit_group - LH5 root group in hit file. - dsp_group - LH5 root group in dsp file. 
""" + f = make_files_config(datainfo) # get sub evt based query condition if needed - if isinstance(qry, str): - qry_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", qry) - qry_var = find_parameters( - f_hit=f_hit, - f_dsp=f_dsp, + if isinstance(query, str): + query_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", query) + query_var = find_parameters( + datainfo=datainfo, ch=ch, idx_ch=idx_ch, - exprl=qry_lst, - hit_group=hit_group, - dsp_group=dsp_group, + field_list=query_lst, ) limarr = eval( - qry.replace(f"{dsp_group}.", f"{dsp_group}_").replace( - f"{hit_group}.", f"{hit_group}_" + query.replace(f"{f.dsp.group}.", f"{f.dsp.group}_").replace( + f"{f.hit.group}.", f"{f.hit.group}_" ), - qry_var, + query_var, ) # in case the expression evaluates to a single value blow it up @@ -264,12 +252,14 @@ def get_mask_from_query( limarr = ak.to_numpy(limarr, allow_missing=False) if limarr.ndim > 1: raise ValueError( - f"query '{qry}' must return 1D array. If you are using VectorOfVectors or ArrayOfEqualSizedArrays, use awkward reduction functions to reduce the dimension" + f"query '{query}' must return 1D array. If you are using " + "VectorOfVectors or ArrayOfEqualSizedArrays, use awkward " + "reduction functions to reduce the dimension" ) # or forward the array - elif isinstance(qry, np.ndarray): - limarr = qry + elif isinstance(query, np.ndarray): + limarr = query # if no condition, it must be true else: diff --git a/src/pygama/flow/data_loader.py b/src/pygama/flow/data_loader.py index 7e5c38616..aa65d2a3c 100644 --- a/src/pygama/flow/data_loader.py +++ b/src/pygama/flow/data_loader.py @@ -17,7 +17,7 @@ from lgdo.lh5 import LH5Iterator, LH5Store from lgdo.lh5.utils import expand_vars from lgdo.types import Array, Struct, Table -from lgdo.types.vectorofvectors import build_cl, explode_arrays, explode_cl +from lgdo.types.vovutils import build_cl, explode_arrays, explode_cl from tqdm.auto import tqdm from . import utils diff --git a/src/pygama/hit/build_hit.py b/src/pygama/hit/build_hit.py index 2fad9c981..7a8c6a241 100644 --- a/src/pygama/hit/build_hit.py +++ b/src/pygama/hit/build_hit.py @@ -4,7 +4,6 @@ from __future__ import annotations -import json import logging import os from collections import OrderedDict @@ -14,6 +13,8 @@ import numpy as np from lgdo.lh5 import LH5Iterator, LH5Store, ls +from .. 
import utils + log = logging.getLogger(__name__) @@ -96,20 +97,17 @@ def build_hit( tbl_cfg = lh5_tables_config # sanitize config if isinstance(tbl_cfg, str): - with open(tbl_cfg) as f: - tbl_cfg = json.load(f) + tbl_cfg = utils.load_dict(tbl_cfg) for k, v in tbl_cfg.items(): if isinstance(v, str): - with open(v) as f: - tbl_cfg[k] = json.load(f) + tbl_cfg[k] = utils.load_dict(v) lh5_tables_config = tbl_cfg else: if isinstance(hit_config, str): # sanitize config - with open(hit_config) as f: - hit_config = json.load(f) + hit_config = utils.load_dict(hit_config) if lh5_tables is None: lh5_tables_config = {} diff --git a/src/pygama/utils.py b/src/pygama/utils.py new file mode 100644 index 000000000..b35f109fd --- /dev/null +++ b/src/pygama/utils.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path + +import yaml + +log = logging.getLogger(__name__) + +__file_extensions__ = {"json": [".json"], "yaml": [".yaml", ".yml"]} + + +def load_dict(fname: str, ftype: str | None = None) -> dict: + """Load a text file as a Python dict.""" + fname = Path(fname) + + # determine file type from extension + if ftype is None: + for _ftype, exts in __file_extensions__.items(): + if fname.suffix in exts: + ftype = _ftype + + msg = f"loading {ftype} dict from: {fname}" + log.debug(msg) + + with fname.open() as f: + if ftype == "json": + return json.load(f) + if ftype == "yaml": + return yaml.safe_load(f) + + msg = f"unsupported file format {ftype}" + raise NotImplementedError(msg) diff --git a/tests/evt/configs/basic-evt-config.json b/tests/evt/configs/basic-evt-config.json deleted file mode 100644 index 3a8c62753..000000000 --- a/tests/evt/configs/basic-evt-config.json +++ /dev/null @@ -1,90 +0,0 @@ -{ - "channels": { - "geds_on": ["ch1084803", "ch1084804", "ch1121600"] - }, - "outputs": [ - "multiplicity", - "energy", - "energy_id", - "energy_idx", - "energy_any_above1MeV", - "energy_all_above1MeV", - "energy_aux", - "energy_sum", - "is_usable_aoe", - "aoe", - "is_aoe_rejected" - ], - "operations": { - "multiplicity": { - "channels": "geds_on", - "aggregation_mode": "sum", - "expression": "hit.cuspEmax_ctc_cal > a", - "parameters": { "a": 25 }, - "initial": 0, - "lgdo_attrs": { "statement": "0bb decay is real" } - }, - "energy": { - "channels": "geds_on", - "aggregation_mode": "first_at:dsp.tp_0_est", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal", - "initial": "np.nan" - }, - "energy_id": { - "channels": "geds_on", - "aggregation_mode": "first_at:dsp.tp_0_est", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "tcm.array_id", - "initial": 0 - }, - "energy_idx": { - "channels": "geds_on", - "aggregation_mode": "first_at:dsp.tp_0_est", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "tcm.index", - "initial": 999999999999 - }, - "energy_any_above1MeV": { - "channels": "geds_on", - "aggregation_mode": "any", - "expression": "hit.cuspEmax_ctc_cal>1000", - "initial": false - }, - "energy_all_above1MeV": { - "channels": "geds_on", - "aggregation_mode": "all", - "expression": "hit.cuspEmax_ctc_cal>1000", - "initial": false - }, - "energy_aux": { - "channels": "geds_on", - "aggregation_mode": "last_at:dsp.tp_0_est", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal", - "initial": "np.nan" - }, - "energy_sum": { - "channels": "geds_on", - "aggregation_mode": "sum", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal", - "initial": 0.0 - }, - "is_usable_aoe": { - 
"aggregation_mode": "keep_at_ch:evt.energy_id", - "expression": "True", - "initial": false - }, - "aoe": { - "aggregation_mode": "keep_at_ch:evt.energy_id", - "expression": "hit.AoE_Classifier", - "initial": "np.nan" - }, - "is_aoe_rejected": { - "aggregation_mode": "keep_at_ch:evt.energy_id", - "expression": "~(hit.AoE_Double_Sided_Cut)", - "initial": false - } - } -} diff --git a/tests/evt/configs/basic-evt-config.yaml b/tests/evt/configs/basic-evt-config.yaml new file mode 100644 index 000000000..bf229504e --- /dev/null +++ b/tests/evt/configs/basic-evt-config.yaml @@ -0,0 +1,85 @@ +channels: + geds_on: + - ch1084803 + - ch1084804 + - ch1121600 +outputs: + - timestamp + - multiplicity + - energy + - energy_id + - energy_idx + - energy_hit_idx + - energy_any_above1MeV + - energy_all_above1MeV + - energy_sum + - is_usable_aoe + - aoe + - is_aoe_rejected +operations: + timestamp: + channels: geds_on + aggregation_mode: first_at:dsp.tp_0_est + expression: dsp.timestamp + lgdo_attrs: + units: s + multiplicity: + channels: geds_on + aggregation_mode: sum + expression: hit.cuspEmax_ctc_cal > a + parameters: + a: 25 + initial: 0 + lgdo_attrs: + statement: 0bb decay is real + energy: + channels: geds_on + aggregation_mode: gather + query: hit.cuspEmax_ctc_cal>25 + expression: hit.cuspEmax_ctc_cal + energy_id: + channels: geds_on + aggregation_mode: first_at:dsp.tp_0_est + query: hit.cuspEmax_ctc_cal>25 + expression: tcm.array_id + initial: 0 + energy_idx: + channels: geds_on + aggregation_mode: first_at:dsp.tp_0_est + query: hit.cuspEmax_ctc_cal>25 + expression: tcm.index + initial: 999999999999 + energy_hit_idx: + channels: geds_on + aggregation_mode: first_at:dsp.tp_0_est + query: hit.cuspEmax_ctc_cal>25 + expression: tcm.array_idx + initial: 999999999999 + energy_any_above1MeV: + channels: geds_on + aggregation_mode: any + expression: hit.cuspEmax_ctc_cal>1000 + initial: false + energy_all_above1MeV: + channels: geds_on + aggregation_mode: all + expression: hit.cuspEmax_ctc_cal>1000 + initial: false + energy_sum: + channels: geds_on + aggregation_mode: sum + query: hit.cuspEmax_ctc_cal>25 + expression: hit.cuspEmax_ctc_cal + initial: 0 + is_usable_aoe: + aggregation_mode: keep_at_ch:evt.energy_id + expression: "True" + initial: false + aoe: + aggregation_mode: keep_at_ch:evt.energy_id + expression: hit.AoE_Classifier + initial: np.nan + is_aoe_rejected: + aggregation_mode: keep_at_ch:evt.energy_id + expression: ~(hit.AoE_Double_Sided_Cut) + initial: false diff --git a/tests/evt/configs/module-test-evt-config.json b/tests/evt/configs/module-test-evt-config.json deleted file mode 100644 index 0daa94658..000000000 --- a/tests/evt/configs/module-test-evt-config.json +++ /dev/null @@ -1,72 +0,0 @@ -{ - "channels": { - "spms_on": ["ch1057600", "ch1059201", "ch1062405"], - "geds_on": ["ch1084803", "ch1084804", "ch1121600"] - }, - "outputs": [ - "energy_first", - "energy_first_id", - "t0", - "lar_energy", - "lar_multiplicity", - "is_lar_rejected", - "lar_classifier", - "lar_energy_dplms", - "lar_multiplicity_dplms", - "lar_time_shift" - ], - "operations": { - "energy_first": { - "channels": "geds_on", - "aggregation_mode": "first_at:dsp.tp_0_est", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal", - "initial": "np.nan" - }, - "energy_first_id": { - "channels": "geds_on", - "aggregation_mode": "first_at:dsp.tp_0_est", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "tcm.array_id", - "initial": 0 - }, - "t0": { - "aggregation_mode": "keep_at_ch:evt.energy_first_id", - 
"expression": "dsp.tp_0_est", - "initial": 0.0 - }, - "lar_energy": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": "pygama.evt.modules.spm.get_energy(0.5,evt.t0,48000,1000,5000)" - }, - "lar_multiplicity": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_majority(0.5,evt.t0,48000,1000,5000)" - }, - "is_lar_rejected": { - "expression": "(evt.lar_energy >4) | (evt.lar_multiplicity > 4) " - }, - "lar_classifier": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_etc(0.5,evt.t0,48000,100,6000,80,1,0,50)" - }, - "lar_energy_dplms": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_energy_dplms(0.5,evt.t0,48000,1000,5000)" - }, - "lar_multiplicity_dplms": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_majority_dplms(0.5,evt.t0,48000,1000,5000)" - }, - "lar_time_shift": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_time_shift(0.5,evt.t0,48000,1000,5000)" - } - } -} diff --git a/tests/evt/configs/module-test-t0-vov-evt-config.json b/tests/evt/configs/module-test-t0-vov-evt-config.json deleted file mode 100644 index cda042337..000000000 --- a/tests/evt/configs/module-test-t0-vov-evt-config.json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "channels": { - "spms_on": ["ch1057600", "ch1059201", "ch1062405"], - "geds_on": ["ch1084803", "ch1084804", "ch1121600"] - }, - "outputs": [ - "energy", - "energy_id", - "t0", - "lar_energy", - "lar_multiplicity", - "is_lar_rejected", - "lar_classifier", - "lar_energy_dplms", - "lar_multiplicity_dplms", - "lar_time_shift", - "lar_tcm_index", - "lar_pulse_index" - ], - "operations": { - "energy": { - "channels": "geds_on", - "aggregation_mode": "gather", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal" - }, - "energy_id": { - "channels": "geds_on", - "aggregation_mode": "gather", - "query": "hit.cuspEmax_ctc_cal>25", - "expression": "tcm.array_id" - }, - "t0": { - "aggregation_mode": "keep_at_ch:evt.energy_id", - "expression": "dsp.tp_0_est", - "initial": 0.0 - }, - "lar_energy": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_energy(0.5,evt.t0,48000,1000,5000)" - }, - "lar_multiplicity": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_majority(0.5,evt.t0,48000,1000,5000)" - }, - "is_lar_rejected": { - "expression": "(evt.lar_energy >4) | (evt.lar_multiplicity > 4) " - }, - "lar_classifier": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_etc(0.5,evt.t0,48000,100,6000,80,1,0,50)" - }, - "lar_energy_dplms": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_energy_dplms(0.5,evt.t0,48000,1000,5000)" - }, - "lar_multiplicity_dplms": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_majority_dplms(0.5,evt.t0,48000,1000,5000)" - }, - "lar_time_shift": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_time_shift(0.5,evt.t0,48000,1000,5000)" - }, - "lar_tcm_index": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": ".modules.spm.get_masked_tcm_idx(0.5,evt.t0,48000,1000,5000,1)" - }, - "lar_pulse_index": { - "channels": "spms_on", - "aggregation_mode": "function", - "expression": 
".modules.spm.get_masked_tcm_idx(0.5,evt.t0,48000,1000,5000,0)" - } - } -} diff --git a/tests/evt/configs/spms-module-config.yaml b/tests/evt/configs/spms-module-config.yaml new file mode 100644 index 000000000..2e9b3119a --- /dev/null +++ b/tests/evt/configs/spms-module-config.yaml @@ -0,0 +1,99 @@ +channels: + spms_on: + - ch1057600 + - ch1059201 + - ch1062405 + geds_on: + - ch1084803 + - ch1084804 + - ch1121600 +outputs: + - t0 + - _pulse_mask + - spms_amp + - rawid + - hit_idx + - rawid_wo_empty + - spms_amp_full + - spms_amp_wo_empty + - trigger_pos + - is_valid_hit + - lar_coinc_class +operations: + t0: + channels: geds_on + aggregation_mode: first_at:dsp.tp_0_est + expression: dsp.tp_0_est + query: hit.cuspEmax_ctc_cal > 25 + initial: np.nan + _pulse_mask: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.make_pulse_data_mask( + <...>, + a_thr_pe=0.1, + t_loc_ns=evt.t0, + dt_range_ns=(-30_000, 30_000), + t_loc_default_ns=48_000) + trigger_pos_full: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_pulse_data( + <...>, + observable='hit.trigger_pos', + drop_empty=False) + trigger_pos: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_pulse_data( + <...>, + observable='hit.trigger_pos', + pulse_mask=evt._pulse_mask, + drop_empty=False) + is_valid_hit: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_is_valid_hit(<...>) + spms_amp: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_pulse_data( + <...>, + observable='hit.energy_in_pe', + pulse_mask=evt._pulse_mask, + drop_empty=False) + rawid: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_tcm_data(<...>, drop_empty=False) + hit_idx: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_tcm_data(<...>, + tcm_field='idx', + drop_empty=False) + spms_amp_full: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_pulse_data( + <...>, + observable='hit.energy_in_pe', + drop_empty=False) + spms_amp_wo_empty: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_pulse_data( + <...>, + observable='hit.energy_in_pe', + pulse_mask=evt._pulse_mask) + rawid_wo_empty: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.gather_tcm_data( + <...>, + pulse_mask=evt._pulse_mask, + drop_empty=True) + lar_coin_class: + channels: spms_on + aggregation_mode: function + expression: pygama.evt.modules.spms.geds_coincidence_classifier(<...>, geds_t0_ns=evt.t0) diff --git a/tests/evt/configs/vov-test-evt-config.json b/tests/evt/configs/vov-test-evt-config.json index 31334101e..6de44075b 100644 --- a/tests/evt/configs/vov-test-evt-config.json +++ b/tests/evt/configs/vov-test-evt-config.json @@ -28,14 +28,15 @@ "channels": "geds_on", "aggregation_mode": "gather", "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal" + "expression": "hit.cuspEmax_ctc_cal", + "dtype": "float32" }, "energy_sum": { "channels": "geds_on", "aggregation_mode": "sum", "query": "hit.cuspEmax_ctc_cal>25", "expression": "hit.cuspEmax_ctc_cal", - "initial": 0.0 + "initial": 0 }, "energy_idx": { "channels": "geds_on", @@ -66,7 +67,8 @@ "aggregation_mode": "sum", "expression": "hit.cuspEmax_ctc_cal > a", "parameters": { "a": 25 }, - "initial": 0 + "initial": 0, + 
"dtype": "int16" }, "is_saturated": { "aggregation_mode": "keep_at_ch:evt.energy_id", diff --git a/tests/evt/modules/larveto.py b/tests/evt/modules/larveto.py new file mode 100644 index 000000000..79f580234 --- /dev/null +++ b/tests/evt/modules/larveto.py @@ -0,0 +1,14 @@ +import numpy as np +import pytest + +from pygama.evt.modules import larveto + + +def test_tc_time_pdf(): + assert isinstance(larveto.l200_tc_time_pdf(0), float) + assert isinstance( + larveto.l200_tc_time_pdf(np.array([0, -0.5, 3]) * 1e3), np.ndarray + ) + + with pytest.raises(ValueError): + assert isinstance(larveto.l200_tc_time_pdf(-10000), float) diff --git a/tests/evt/test_build_evt.py b/tests/evt/test_build_evt.py index 80a40d9a8..99bf66d6f 100644 --- a/tests/evt/test_build_evt.py +++ b/tests/evt/test_build_evt.py @@ -4,7 +4,7 @@ import awkward as ak import numpy as np import pytest -from lgdo import Array, VectorOfVectors, lh5 +from lgdo import Array, Table, VectorOfVectors, lh5 from lgdo.lh5 import LH5Store from pygama.evt import build_evt @@ -13,197 +13,224 @@ store = LH5Store() -def test_basics(lgnd_test_data, tmptestdir): - outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" +@pytest.fixture(scope="module") +def files_config(lgnd_test_data, tmptestdir): tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" - if os.path.exists(outfile): - os.remove(outfile) + outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" + + return { + "tcm": (lgnd_test_data.get_path(tcm_path), "hardware_tcm_1"), + "dsp": (lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), "dsp", "ch{}"), + "hit": (lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), "hit", "ch{}"), + "evt": (outfile, "evt"), + } + +def test_basics(lgnd_test_data, files_config): build_evt( - f_tcm=lgnd_test_data.get_path(tcm_path), - f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - evt_config=f"{config_dir}/basic-evt-config.json", - f_evt=outfile, - wo_mode="o", - evt_group="evt", - hit_group="hit", - dsp_group="dsp", - tcm_group="hardware_tcm_1", + files_config, + config=f"{config_dir}/basic-evt-config.yaml", + wo_mode="of", ) - assert "statement" in store.read("/evt/multiplicity", outfile)[0].getattrs().keys() - assert ( - store.read("/evt/multiplicity", outfile)[0].getattrs()["statement"] - == "0bb decay is real" - ) + outfile = files_config["evt"][0] + f_tcm = files_config["tcm"][0] + + evt = lh5.read("evt", outfile) + + assert "statement" in evt.multiplicity.attrs + assert evt.multiplicity.attrs["statement"] == "0bb decay is real" + assert os.path.exists(outfile) - assert len(lh5.ls(outfile, "/evt/")) == 11 - nda = { - e: store.read(f"/evt/{e}", outfile)[0].view_as("np") - for e in ["energy", "energy_aux", "energy_sum", "multiplicity"] - } - assert ( - nda["energy"][nda["multiplicity"] == 1] - == nda["energy_aux"][nda["multiplicity"] == 1] - ).all() - assert ( - nda["energy"][nda["multiplicity"] == 1] - == nda["energy_sum"][nda["multiplicity"] == 1] - ).all() - assert ( - nda["energy_aux"][nda["multiplicity"] == 1] - == nda["energy_sum"][nda["multiplicity"] == 1] - ).all() + assert sorted(evt.keys()) == [ + "aoe", + "energy", + "energy_all_above1MeV", + "energy_any_above1MeV", + "energy_hit_idx", + "energy_id", + "energy_idx", + "energy_sum", + "is_aoe_rejected", + "is_usable_aoe", + "multiplicity", + "timestamp", + ] + + ak_evt = evt.view_as("ak") + + assert 
ak.all(ak_evt.energy_sum == ak.sum(ak_evt.energy, axis=-1)) eid = store.read("/evt/energy_id", outfile)[0].view_as("np") eidx = store.read("/evt/energy_idx", outfile)[0].view_as("np") eidx = eidx[eidx != 999999999999] - ids = store.read("hardware_tcm_1/array_id", lgnd_test_data.get_path(tcm_path))[ - 0 - ].view_as("np") + ids = store.read("hardware_tcm_1/array_id", f_tcm)[0].view_as("np") ids = ids[eidx] assert ak.all(ids == eid[eid != 0]) + ehidx = store.read("/evt/energy_hit_idx", outfile)[0].view_as("np") + ids = store.read("hardware_tcm_1/array_idx", f_tcm)[0].view_as("np") + ids = ids[eidx] + assert ak.all(ids == ehidx[ehidx != 999999999999]) + + +def test_field_nesting(lgnd_test_data, files_config): + config = { + "channels": {"geds_on": ["ch1084803", "ch1084804", "ch1121600"]}, + "outputs": [ + "sub1___timestamp", + "sub2___multiplicity", + "sub2___dummy", + ], + "operations": { + "sub1___timestamp": { + "channels": "geds_on", + "aggregation_mode": "first_at:dsp.tp_0_est", + "expression": "dsp.timestamp", + }, + "sub2___multiplicity": { + "channels": "geds_on", + "aggregation_mode": "sum", + "expression": "hit.cuspEmax_ctc_cal > 25", + "initial": 0, + }, + "sub2___dummy": { + "channels": "geds_on", + "aggregation_mode": "sum", + "expression": "hit.cuspEmax_ctc_cal > evt.sub1___timestamp", + "initial": 0, + }, + }, + } -def test_lar_module(lgnd_test_data, tmptestdir): - outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" - tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" - if os.path.exists(outfile): - os.remove(outfile) build_evt( - f_tcm=lgnd_test_data.get_path(tcm_path), - f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - evt_config=f"{config_dir}/module-test-evt-config.json", - f_evt=outfile, - wo_mode="o", - evt_group="evt", - hit_group="hit", - dsp_group="dsp", - tcm_group="hardware_tcm_1", + files_config, + config=config, + wo_mode="of", ) - assert os.path.exists(outfile) - assert len(lh5.ls(outfile, "/evt/")) == 10 - nda = { - e: store.read(f"/evt/{e}", outfile)[0].view_as("np") - for e in ["lar_multiplicity", "lar_multiplicity_dplms", "t0", "lar_time_shift"] - } - assert np.max(nda["lar_multiplicity"]) <= 3 - assert np.max(nda["lar_multiplicity_dplms"]) <= 3 - assert ((nda["lar_time_shift"] + nda["t0"]) >= 0).all() + outfile = files_config["evt"][0] + evt = lh5.read("/evt", outfile) + assert isinstance(evt, Table) + assert isinstance(evt.sub1, Table) + assert isinstance(evt.sub2, Table) + assert isinstance(evt.sub1.timestamp, Array) -def test_lar_t0_vov_module(lgnd_test_data, tmptestdir): - outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" - tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" - if os.path.exists(outfile): - os.remove(outfile) + assert sorted(evt.keys()) == ["sub1", "sub2"] + assert sorted(evt.sub1.keys()) == ["timestamp"] + assert sorted(evt.sub2.keys()) == ["dummy", "multiplicity"] + + +def test_spms_module(lgnd_test_data, files_config): build_evt( - f_tcm=lgnd_test_data.get_path(tcm_path), - f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - evt_config=f"{config_dir}/module-test-t0-vov-evt-config.json", - f_evt=outfile, - wo_mode="o", - evt_group="evt", - hit_group="hit", - dsp_group="dsp", - tcm_group="hardware_tcm_1", + files_config, 
+ config=f"{config_dir}/spms-module-config.yaml", + wo_mode="of", ) - assert os.path.exists(outfile) - assert len(lh5.ls(outfile, "/evt/")) == 12 - nda = { - e: store.read(f"/evt/{e}", outfile)[0].view_as("np") - for e in ["lar_multiplicity", "lar_multiplicity_dplms", "lar_time_shift"] - } - assert np.max(nda["lar_multiplicity"]) <= 3 - assert np.max(nda["lar_multiplicity_dplms"]) <= 3 + outfile = files_config["evt"][0] - ch_idx = store.read("/evt/lar_tcm_index", outfile)[0].view_as("ak") - pls_idx = store.read("/evt/lar_pulse_index", outfile)[0].view_as("ak") - assert ak.count(ch_idx) == ak.count(pls_idx) - assert ak.all(ak.count(ch_idx, axis=-1) == ak.count(pls_idx, axis=-1)) + evt = lh5.read("/evt", outfile) + t0 = ak.fill_none(ak.nan_to_none(evt.t0.view_as("ak")), 48_000) + tr_pos = evt.trigger_pos.view_as("ak") * 16 + assert ak.all(tr_pos > t0 - 30_000) + assert ak.all(tr_pos < t0 + 30_000) -def test_vov(lgnd_test_data, tmptestdir): - outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" - tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" - if os.path.exists(outfile): - os.remove(outfile) + mask = evt._pulse_mask + assert isinstance(mask, VectorOfVectors) + assert len(mask) == 10 + assert mask.ndim == 3 + + full = evt.spms_amp_full.view_as("ak") + amp = evt.spms_amp.view_as("ak") + assert ak.all(amp > 0.1) + + assert ak.all(full[mask.view_as("ak")] == amp) + + wo_empty = evt.spms_amp_wo_empty.view_as("ak") + assert ak.all(wo_empty == amp[ak.count(amp, axis=-1) > 0]) + + rawids = evt.rawid.view_as("ak") + assert rawids.ndim == 2 + assert ak.count(rawids) == 30 + + idx = evt.hit_idx.view_as("ak") + assert idx.ndim == 2 + assert ak.count(idx) == 30 + + rawids_wo_empty = evt.rawid_wo_empty.view_as("ak") + assert ak.count(rawids_wo_empty) == 7 + + vhit = evt.is_valid_hit.view_as("ak") + vhit.show() + assert ak.all(ak.num(vhit, axis=-1) == ak.num(full, axis=-1)) + + +def test_vov(lgnd_test_data, files_config): build_evt( - f_tcm=lgnd_test_data.get_path(tcm_path), - f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - evt_config=f"{config_dir}/vov-test-evt-config.json", - f_evt=outfile, - wo_mode="o", - evt_group="evt", - hit_group="hit", - dsp_group="dsp", - tcm_group="hardware_tcm_1", + files_config, + config=f"{config_dir}/vov-test-evt-config.json", + wo_mode="of", ) + outfile = files_config["evt"][0] + f_tcm = files_config["tcm"][0] + assert os.path.exists(outfile) assert len(lh5.ls(outfile, "/evt/")) == 12 + + timestamp, _ = store.read("/evt/timestamp", outfile) + assert np.all(~np.isnan(timestamp.nda)) + vov_ene, _ = store.read("/evt/energy", outfile) vov_aoe, _ = store.read("/evt/aoe", outfile) arr_ac, _ = store.read("/evt/multiplicity", outfile) vov_aoeene, _ = store.read("/evt/energy_times_aoe", outfile) vov_eneac, _ = store.read("/evt/energy_times_multiplicity", outfile) arr_ac2, _ = store.read("/evt/multiplicity_squared", outfile) + assert isinstance(vov_ene, VectorOfVectors) assert isinstance(vov_aoe, VectorOfVectors) assert isinstance(arr_ac, Array) assert isinstance(vov_aoeene, VectorOfVectors) assert isinstance(vov_eneac, VectorOfVectors) assert isinstance(arr_ac2, Array) + + assert vov_ene.dtype == "float32" + assert vov_aoe.dtype == "float64" + assert arr_ac.dtype == "int16" + assert (np.diff(vov_ene.cumulative_length.nda, prepend=[0]) == arr_ac.nda).all() vov_eid = store.read("/evt/energy_id", outfile)[0].view_as("ak") 
     vov_eidx = store.read("/evt/energy_idx", outfile)[0].view_as("ak")
     vov_aoe_idx = store.read("/evt/aoe_idx", outfile)[0].view_as("ak")

-    ids = store.read("hardware_tcm_1/array_id", lgnd_test_data.get_path(tcm_path))[
-        0
-    ].view_as("ak")
+    ids = store.read("hardware_tcm_1/array_id", f_tcm)[0].view_as("ak")
     ids = ak.unflatten(ids[ak.flatten(vov_eidx)], ak.count(vov_eidx, axis=-1))
     assert ak.all(ids == vov_eid)

     arr_ene = store.read("/evt/energy_sum", outfile)[0].view_as("ak")
-    assert ak.all(arr_ene == ak.nansum(vov_ene.view_as("ak"), axis=-1))
+    assert ak.all(
+        ak.isclose(arr_ene, ak.nansum(vov_ene.view_as("ak"), axis=-1), rtol=1e-3)
+    )
     assert ak.all(vov_aoe.view_as("ak") == vov_aoe_idx)


-def test_graceful_crashing(lgnd_test_data, tmptestdir):
-    outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5"
-    tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5"
-    if os.path.exists(outfile):
-        os.remove(outfile)
-    f_tcm = lgnd_test_data.get_path(tcm_path)
-    f_dsp = lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp"))
-    f_hit = lgnd_test_data.get_path(tcm_path.replace("tcm", "hit"))
-    f_config = f"{config_dir}/basic-evt-config.json"
-
-    with pytest.raises(KeyError):
-        build_evt(f_dsp, f_tcm, f_hit, f_config, outfile)
-
-    with pytest.raises(KeyError):
-        build_evt(f_tcm, f_hit, f_dsp, f_config, outfile)
-
+def test_graceful_crashing(lgnd_test_data, files_config):
     with pytest.raises(TypeError):
-        build_evt(f_tcm, f_dsp, f_hit, None, outfile)
+        build_evt(files_config, None, wo_mode="of")

     conf = {"operations": {}}
     with pytest.raises(ValueError):
-        build_evt(f_tcm, f_dsp, f_hit, conf, outfile)
+        build_evt(files_config, conf, wo_mode="of")

     conf = {"channels": {"geds_on": ["ch1084803", "ch1084804", "ch1121600"]}}
     with pytest.raises(ValueError):
-        build_evt(f_tcm, f_dsp, f_hit, conf, outfile)
+        build_evt(files_config, conf, wo_mode="of")

     conf = {
         "channels": {"geds_on": ["ch1084803", "ch1084804", "ch1121600"]},
@@ -219,38 +246,25 @@ def test_graceful_crashing(lgnd_test_data, tmptestdir):
         },
     }
     with pytest.raises(ValueError):
-        build_evt(f_tcm, f_dsp, f_hit, conf, outfile)
+        build_evt(
+            files_config,
+            conf,
+            wo_mode="of",
+        )


-def test_query(lgnd_test_data, tmptestdir):
-    outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5"
-    tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5"
-    if os.path.exists(outfile):
-        os.remove(outfile)
+def test_query(lgnd_test_data, files_config):
     build_evt(
-        f_tcm=lgnd_test_data.get_path(tcm_path),
-        f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")),
-        f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")),
-        evt_config=f"{config_dir}/query-test-evt-config.json",
-        f_evt=outfile,
-        wo_mode="o",
-        evt_group="evt",
-        hit_group="hit",
-        dsp_group="dsp",
-        tcm_group="hardware_tcm_1",
+        files_config,
+        config=f"{config_dir}/query-test-evt-config.json",
+        wo_mode="of",
     )
+
+    outfile = files_config["evt"][0]
+
     assert len(lh5.ls(outfile, "/evt/")) == 12


-def test_vector_sort(lgnd_test_data, tmptestdir):
-    outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5"
-    tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5"
-    if os.path.exists(outfile):
-        os.remove(outfile)
-    f_tcm = lgnd_test_data.get_path(tcm_path)
-    f_dsp = lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp"))
-    f_hit = lgnd_test_data.get_path(tcm_path.replace("tcm", "hit"))
-
+def test_vector_sort(lgnd_test_data, files_config):
     conf = {
         "channels": {"geds_on": ["ch1084803", "ch1084804", "ch1121600"]},
         "outputs": ["acend_id", "t0_acend", "decend_id", "t0_decend"],
@@ -279,7 +293,14 @@
             },
         },
     }
-    build_evt(f_tcm, f_dsp, f_hit, conf, outfile)
+
+    build_evt(
+        files_config,
+        conf,
+        wo_mode="of",
+    )
+
+    outfile = files_config["evt"][0]

     assert os.path.exists(outfile)
     assert len(lh5.ls(outfile, "/evt/")) == 4
@@ -289,27 +310,3 @@
     vov_t0, _ = store.read("/evt/t0_decend", outfile)
     nda_t0 = vov_t0.to_aoesa().view_as("np")
     assert ((np.diff(nda_t0) <= 0) | (np.isnan(np.diff(nda_t0)))).all()
-
-
-def test_tcm_id_table_pattern(lgnd_test_data, tmptestdir):
-    outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5"
-    tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5"
-    if os.path.exists(outfile):
-        os.remove(outfile)
-    f_tcm = lgnd_test_data.get_path(tcm_path)
-    f_dsp = lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp"))
-    f_hit = lgnd_test_data.get_path(tcm_path.replace("tcm", "hit"))
-    f_config = f"{config_dir}/basic-evt-config.json"
-
-    with pytest.raises(ValueError):
-        build_evt(f_tcm, f_dsp, f_hit, f_config, outfile, tcm_id_table_pattern="ch{{}}")
-    with pytest.raises(ValueError):
-        build_evt(f_tcm, f_dsp, f_hit, f_config, outfile, tcm_id_table_pattern="ch{}{}")
-    with pytest.raises(NotImplementedError):
-        build_evt(
-            f_tcm, f_dsp, f_hit, f_config, outfile, tcm_id_table_pattern="ch{tcm_id}"
-        )
-    with pytest.raises(ValueError):
-        build_evt(
-            f_tcm, f_dsp, f_hit, f_config, outfile, tcm_id_table_pattern="apple{}banana"
-        )
diff --git a/tests/evt/test_utils.py b/tests/evt/test_utils.py
new file mode 100644
index 000000000..c3548289e
--- /dev/null
+++ b/tests/evt/test_utils.py
@@ -0,0 +1,22 @@
+from pygama.evt import utils
+
+
+def test_tier_data_tuple():
+    files = utils.make_files_config(
+        {
+            "tcm": ("f1", "g1"),
+            "dsp": ("f2", "g2"),
+            "hit": ("f3", "g3"),
+            "evt": ("f4", "g4"),
+        }
+    )
+
+    assert files.raw == utils.H5DataLoc()
+    assert files.tcm.file == "f1"
+    assert files.tcm.group == "g1"
+    assert files.dsp.file == "f2"
+    assert files.dsp.group == "g2"
+    assert files.hit.file == "f3"
+    assert files.hit.group == "g3"
+    assert files.evt.file == "f4"
+    assert files.evt.group == "g4"
diff --git a/tests/flow/test_filedb.py b/tests/flow/test_filedb.py
index fe8fa72cb..8a57160d3 100644
--- a/tests/flow/test_filedb.py
+++ b/tests/flow/test_filedb.py
@@ -2,6 +2,7 @@
 from pathlib import Path

 import pytest
+from lgdo.lh5.exceptions import LH5EncodeError
 from pandas.testing import assert_frame_equal

 from pygama.flow import FileDB
@@ -346,7 +347,7 @@ def test_serialization(test_filedb_full, tmptestdir):
     db = test_filedb_full
     db.to_disk(f"{tmptestdir}/filedb.lh5", wo_mode="of")

-    with pytest.raises(RuntimeError):
+    with pytest.raises(LH5EncodeError):
         db.to_disk(f"{tmptestdir}/filedb.lh5")

     db2 = FileDB(f"{tmptestdir}/filedb.lh5")
diff --git a/tests/skm/test_build_skm.py b/tests/skm/test_build_skm.py
index c60c460f0..00fda9f08 100644
--- a/tests/skm/test_build_skm.py
+++ b/tests/skm/test_build_skm.py
@@ -3,6 +3,7 @@

 import awkward as ak
 import lgdo
+import pytest
 from lgdo.lh5 import LH5Store

 from pygama.evt import build_evt
@@ -13,33 +14,35 @@
 store = LH5Store()


-def test_basics(lgnd_test_data, tmptestdir):
-    outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5"
+@pytest.fixture(scope="module") +def files_config(lgnd_test_data, tmptestdir): tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" - if os.path.exists(outfile): - os.remove(outfile) + outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" + + return { + "tcm": (lgnd_test_data.get_path(tcm_path), "hardware_tcm_1"), + "dsp": (lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), "dsp", "ch{}"), + "hit": (lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), "hit", "ch{}"), + "evt": (outfile, "evt"), + } + +def test_basics(tmptestdir, files_config): build_evt( - f_tcm=lgnd_test_data.get_path(tcm_path), - f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - evt_config=f"{evt_config_dir}/vov-test-evt-config.json", - f_evt=outfile, - wo_mode="o", - evt_group="evt", - hit_group="hit", - dsp_group="dsp", - tcm_group="hardware_tcm_1", + files_config, + config=f"{evt_config_dir}/vov-test-evt-config.json", + wo_mode="of", ) + outfile = files_config["evt"][0] skm_conf = f"{config_dir}/basic-skm-config.json" skm_out = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_skm.lh5" result = build_skm( outfile, - lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - lgnd_test_data.get_path(tcm_path), + files_config["hit"][0], + files_config["dsp"][0], + files_config["tcm"][0], skm_conf, ) @@ -47,9 +50,9 @@ def test_basics(lgnd_test_data, tmptestdir): build_skm( outfile, - lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - lgnd_test_data.get_path(tcm_path), + files_config["hit"][0], + files_config["dsp"][0], + files_config["tcm"][0], skm_conf, skm_out, wo_mode="o", @@ -71,10 +74,6 @@ def test_basics(lgnd_test_data, tmptestdir): assert "multiplicity" in df.keys() assert "energy_sum" in df.keys() assert (df.multiplicity.to_numpy() <= 3).all() - assert ( - df.energy_0.to_numpy() + df.energy_1.to_numpy() + df.energy_2.to_numpy() - == df.energy_sum.to_numpy() - ).all() vov_eid = ak.to_numpy( ak.fill_none( @@ -90,23 +89,15 @@ def test_basics(lgnd_test_data, tmptestdir): assert (vov_eid[:, 2] == df.energy_id_2.to_numpy()).all() -def test_attribute_passing(lgnd_test_data, tmptestdir): +def test_attribute_passing(tmptestdir, files_config): outfile = f"{tmptestdir}/l200-p03-r001-phy-20230322T160139Z-tier_evt.lh5" - tcm_path = "lh5/prod-ref-l200/generated/tier/tcm/phy/p03/r001/l200-p03-r001-phy-20230322T160139Z-tier_tcm.lh5" if os.path.exists(outfile): os.remove(outfile) build_evt( - f_tcm=lgnd_test_data.get_path(tcm_path), - f_dsp=lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - f_hit=lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - evt_config=f"{evt_config_dir}/vov-test-evt-config.json", - f_evt=outfile, - wo_mode="o", - evt_group="evt", - hit_group="hit", - dsp_group="dsp", - tcm_group="hardware_tcm_1", + files_config, + config=f"{evt_config_dir}/vov-test-evt-config.json", + wo_mode="of", ) skm_conf = f"{config_dir}/basic-skm-config.json" @@ -115,9 +106,9 @@ def test_attribute_passing(lgnd_test_data, tmptestdir): build_skm( outfile, - lgnd_test_data.get_path(tcm_path.replace("tcm", "hit")), - lgnd_test_data.get_path(tcm_path.replace("tcm", "dsp")), - lgnd_test_data.get_path(tcm_path), + files_config["hit"][0], + files_config["dsp"][0], + files_config["tcm"][0], skm_conf, f_skm=skm_out, wo_mode="o",