Skip to content

Commit

Permalink
Updated pargen.utils.load_data to use LH5Iterator and field_mask to b…
Browse files Browse the repository at this point in the history
…e more memory efficient
  • Loading branch information
iguinn committed Aug 25, 2024
1 parent 981877e commit 023addd
Showing 1 changed file with 74 additions and 69 deletions.
143 changes: 74 additions & 69 deletions src/pygama/pargen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from lgdo import lh5

log = logging.getLogger(__name__)
sto = lh5.LH5Store()


def convert_to_minuit(pars, func):
Expand All @@ -35,101 +34,107 @@ def return_nans(input):
return m.values, m.errors, np.full((len(m.values), len(m.values)), np.nan)


def get_params(file_params, param_list):
out_params = []
if isinstance(file_params, dict):
possible_keys = file_params.keys()
elif isinstance(file_params, list):
possible_keys = file_params
for param in param_list:
for key in possible_keys:
if key in param:
out_params.append(key)
return np.unique(out_params).tolist()


def load_data(
files: list,
files: str | list | dict,
lh5_path: str,
cal_dict: dict,
params: list,
params: set,
cal_energy_param: str = "cuspEmax_ctc_cal",
threshold=None,
return_selection_mask=False,
) -> tuple(np.array, np.array, np.array, np.array):
) -> pd.DataFrame | tuple(pd.DataFrame, np.array):
"""
Loads in the A/E parameters needed and applies calibration constants to energy
Loads parameters from data files. Applies calibration to cal_energy_param
and uses this to apply a lower energy threshold.
files
file or list of files or dict pointing from timestamps to lists of files
lh5_path
path to table in files
cal_dict
dictionary with operations used to apply calibration constants
params
list of parameters to load from file
cal_energy_param
name of uncalibrated energy parameter
threshold
lower energy threshold for events to load
return_selection_map
if True, return selection mask for threshold along with data
"""

params = set(params)
if isinstance(files, str):
files = [files]

if isinstance(files, dict):
keys = lh5.ls(
files[list(files)[0]][0],
lh5_path if lh5_path[-1] == "/" else lh5_path + "/",
)
keys = [key.split("/")[-1] for key in keys]
if list(files)[0] in cal_dict:
params = get_params(keys + list(cal_dict[list(files)[0]].keys()), params)
else:
params = get_params(keys + list(cal_dict.keys()), params)

# Go through each tstamp and recursively load_data on file lists
df = []
all_files = []
masks = np.array([], dtype=bool)
masks = []
for tstamp, tfiles in files.items():
table = sto.read(lh5_path, tfiles)[0]

file_df = pd.DataFrame(columns=params)
if tstamp in cal_dict:
cal_dict_ts = cal_dict[tstamp]
else:
cal_dict_ts = cal_dict

for outname, info in cal_dict_ts.items():
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)

for param in params:
file_df[param] = table[param]

file_df = load_data(
tfiles,
lh5_path,
cal_dict.get(tstamp, cal_dict),
params,
cal_energy_param,
threshold,
return_selection_mask,
)
file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object)

if threshold is not None:
mask = file_df[cal_energy_param] > threshold
file_df.drop(np.where(~mask)[0], inplace=True)
if return_selection_mask:
df.append(file_df[0])
masks.append(file_df[1])
else:
mask = np.ones(len(file_df), dtype=bool)
masks = np.append(masks, mask)
df.append(file_df)
all_files += tfiles
df.append(file_df)

params.append("run_timestamp")
df = pd.concat(df)
if return_selection_mask:
masks = np.concat(masks)

elif isinstance(files, list):
keys = lh5.ls(files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/")
keys = [key.split("/")[-1] for key in keys]
params = get_params(keys + list(cal_dict.keys()), params)

table = sto.read(lh5_path, files)[0]
df = pd.DataFrame(columns=params)
for outname, info in cal_dict.items():
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)
for param in params:
df[param] = table[param]
# Get set of available fields between input table and cal_dict
file_keys = lh5.ls(
files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/"
)
file_keys = {key.split("/")[-1] for key in file_keys}

# Get set of keys in calibration expressions that show up in file
cal_keys = {
name
for expr in cal_dict.values()
for name in compile(expr, "0vbb is real!", "eval").co_names
} & file_keys

# Get set of fields to read from files
fields = cal_keys | (file_keys & params)

lh5_it = lh5.iterator.LH5Iterator(
files, lh5_path, field_mask=fields, buffer_len=100000
)
df = pd.DataFrame(columns=list(params))
for table, entry, n_rows in lh5_it:
# Evaluate all provided expressions and add to table
for outname, info in cal_dict.items():
table[outname] = table.eval(
info["expression"], local_dict=info.get("parameters", None)
)

# Copy params in table into dataframe
for par in params:
# First set of entries: allocate enough memory for all entries
if entry == 0:
df[par] = np.resize(table[par], len(lh5_it))
else:
df.loc[entry : entry + n_rows - 1, par] = table[par][:n_rows]

# Evaluate threshold mask and drop events below threshold
if threshold is not None:
masks = df[cal_energy_param] > threshold
df.drop(np.where(~masks)[0], inplace=True)
else:
masks = np.ones(len(df), dtype=bool)
all_files = files

for col in list(df.keys()):
if col not in params:
df.drop(col, inplace=True, axis=1)

log.debug("data loaded")
if return_selection_mask:
Expand Down

0 comments on commit 023addd

Please sign in to comment.