Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated pargen.utils.load_data to use LH5Iterator and field_mask to be more memory efficient #589

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 84 additions & 70 deletions src/pygama/pargen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from lgdo import lh5

log = logging.getLogger(__name__)
sto = lh5.LH5Store()


def convert_to_minuit(pars, func):
Expand All @@ -35,101 +34,116 @@ def return_nans(input):
return m.values, m.errors, np.full((len(m.values), len(m.values)), np.nan)


def get_params(file_params, param_list):
    """
    Select the available keys that occur in any of the requested parameters.

    Parameters
    ----------
    file_params
        collection of candidate keys. A dict contributes its keys; any other
        iterable (list, tuple, set, ``dict.keys()`` view, ...) contributes its
        elements. A single string is treated as one key rather than being
        split into characters.
    param_list
        iterable of requested parameters. A key is selected when
        ``key in param`` is true for at least one entry, so for string
        parameters this is a substring match.

    Returns
    -------
    list
        the selected keys, de-duplicated and sorted by :func:`numpy.unique`.
    """
    # Materialise the candidates once: they are re-scanned for every entry of
    # param_list, and binding them unconditionally avoids the unbound-variable
    # error the dict/list-only branches produced for other container types.
    if isinstance(file_params, str):
        possible_keys = [file_params]
    else:
        possible_keys = list(file_params)

    out_params = [
        key for param in param_list for key in possible_keys if key in param
    ]
    return np.unique(out_params).tolist()


def load_data(
files: list,
files: str | list | dict,
lh5_path: str,
cal_dict: dict,
params: list,
params: set,
cal_energy_param: str = "cuspEmax_ctc_cal",
threshold=None,
return_selection_mask=False,
) -> tuple(np.array, np.array, np.array, np.array):
) -> pd.DataFrame | tuple(pd.DataFrame, np.array):
"""
Loads in the A/E parameters needed and applies calibration constants to energy
Loads parameters from data files. Applies calibration to cal_energy_param
and uses this to apply a lower energy threshold.

files
file or list of files or dict pointing from timestamps to lists of files
lh5_path
path to table in files
cal_dict
dictionary with operations used to apply calibration constants
params
list of parameters to load from file
cal_energy_param
name of uncalibrated energy parameter
threshold
lower energy threshold for events to load
return_selection_mask
if True, return selection mask for threshold along with data
"""

params = set(params)
if isinstance(files, str):
files = [files]

if isinstance(files, dict):
keys = lh5.ls(
files[list(files)[0]][0],
lh5_path if lh5_path[-1] == "/" else lh5_path + "/",
)
keys = [key.split("/")[-1] for key in keys]
if list(files)[0] in cal_dict:
params = get_params(keys + list(cal_dict[list(files)[0]].keys()), params)
else:
params = get_params(keys + list(cal_dict.keys()), params)

# Go through each tstamp and recursively load_data on file lists
df = []
all_files = []
masks = np.array([], dtype=bool)
masks = []
for tstamp, tfiles in files.items():
table = sto.read(lh5_path, tfiles)[0]

file_df = pd.DataFrame(columns=params)
if tstamp in cal_dict:
cal_dict_ts = cal_dict[tstamp]
file_df = load_data(
tfiles,
lh5_path,
cal_dict.get(tstamp, cal_dict),
params,
cal_energy_param,
threshold,
return_selection_mask,
)

if return_selection_mask:
file_df[0]["run_timestamp"] = np.full(
len(file_df[0]), tstamp, dtype=object
)
df.append(file_df[0])
masks.append(file_df[1])
else:
cal_dict_ts = cal_dict
file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object)
df.append(file_df)

for outname, info in cal_dict_ts.items():
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)

for param in params:
file_df[param] = table[param]
df = pd.concat(df)
if return_selection_mask:
masks = np.concatenate(masks)

file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object)
elif isinstance(files, list):
# Get set of available fields between input table and cal_dict
file_keys = lh5.ls(
files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/"
)
file_keys = {key.split("/")[-1] for key in file_keys}

if threshold is not None:
mask = file_df[cal_energy_param] > threshold
file_df.drop(np.where(~mask)[0], inplace=True)
else:
mask = np.ones(len(file_df), dtype=bool)
masks = np.append(masks, mask)
df.append(file_df)
all_files += tfiles
# Get set of keys in calibration expressions that show up in file
cal_keys = {
name
for info in cal_dict.values()
for name in compile(info["expression"], "0vbb is real!", "eval").co_names
} & file_keys

params.append("run_timestamp")
df = pd.concat(df)
# Get set of fields to read from files
fields = cal_keys | (file_keys & params)

elif isinstance(files, list):
keys = lh5.ls(files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/")
keys = [key.split("/")[-1] for key in keys]
params = get_params(keys + list(cal_dict.keys()), params)

table = sto.read(lh5_path, files)[0]
df = pd.DataFrame(columns=params)
for outname, info in cal_dict.items():
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)
for param in params:
df[param] = table[param]
lh5_it = lh5.iterator.LH5Iterator(
files, lh5_path, field_mask=fields, buffer_len=100000
)
df_fields = params & (fields | set(cal_dict))
if df_fields != params:
log.debug(
f"load_data(): params not found in data files or cal_dict: {params-df_fields}"
)
df = pd.DataFrame(columns=list(df_fields))

for table, entry, n_rows in lh5_it:
# Evaluate all provided expressions and add to table
for outname, info in cal_dict.items():
table[outname] = table.eval(
info["expression"], info.get("parameters", None)
)

# Copy params in table into dataframe
for par in df:
# First set of entries: allocate enough memory for all entries
if entry == 0:
df[par] = np.resize(table[par], len(lh5_it))
else:
df.loc[entry : entry + n_rows - 1, par] = table[par][:n_rows]

# Evaluate threshold mask and drop events below threshold
if threshold is not None:
masks = df[cal_energy_param] > threshold
df.drop(np.where(~masks)[0], inplace=True)
else:
masks = np.ones(len(df), dtype=bool)
all_files = files

for col in list(df.keys()):
if col not in params:
df.drop(col, inplace=True, axis=1)

log.debug("data loaded")
if return_selection_mask:
Expand Down
2 changes: 1 addition & 1 deletion src/pygama/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class NumbaPygamaDefaults(MutableMapping):
"""

def __init__(self) -> None:
self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=True)
self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=False)
self.fastmath: bool = getenv_bool("PYGAMA_FASTMATH", default=True)

def __getitem__(self, item: str) -> Any:
Expand Down
Loading