From c93f5422d91de408de601b9a14067187fd6acf45 Mon Sep 17 00:00:00 2001
From: Dmitri Gavrilov
Date: Sat, 8 Jun 2024 18:15:35 -0400
Subject: [PATCH 1/3] Optimized memory use for lazy loading of dask arrays from tiled.
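
Data streams opened through the Tiled client are returned as
'tiled.client.array.DaskArrayClient' objects. Wrapping them in 'da.array'
keeps reductions such as 'da.sum' lazy, so the full fluorescence stack is
no longer pulled to the client at load time. Detector channels and their
sum are now kept as lazy Dask arrays and materialized only in
'save_data_to_hdf5', which downloads the data in batches of whole rows
(about 40000 pixels per batch) with up to 10 retries per batch.

A minimal sketch of the loading pattern ('data_stream0' is the Tiled node
of the primary data stream, as in 'map_data2D_srx_new_tiled'):

    import dask.array as da

    d_xs = da.array(data_stream0["xs_fluor"])  # lazy wrapper, nothing is downloaded
    d_xs_sum = da.sum(d_xs, 2)                 # builds a task graph, still lazy
    N_xs = d_xs.shape[2]                       # known without downloading the data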
---
 pyxrf/model/load_data_from_db.py | 84 ++++++++++++++++++++++++++------
 1 file changed, 68 insertions(+), 16 deletions(-)

diff --git a/pyxrf/model/load_data_from_db.py b/pyxrf/model/load_data_from_db.py
index 8cd982df..5fd6dd55 100644
--- a/pyxrf/model/load_data_from_db.py
+++ b/pyxrf/model/load_data_from_db.py
@@ -2318,20 +2318,24 @@ def map_data2D_srx_new_tiled(
 
     d_xs, d_xs_sum, N_xs, d_xs2, d_xs2_sum, N_xs2 = None, None, 0, None, None, 0
 
     if "xs_fluor" in data_stream0:
-        d_xs = data_stream0["xs_fluor"]
+        # The type of the loaded data is 'tiled.client.array.DaskArrayClient'.
+        # If the data is not explicitly converted to 'da.array', then 'da.sum'
+        # pulls the whole dataset into memory, which is not desirable. This
+        # could be fixed in future versions of Tiled.
+        d_xs = da.array(data_stream0["xs_fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]
     elif "fluor" in data_stream0:  # Old format
-        d_xs = data_stream0["fluor"]
+        d_xs = da.array(data_stream0["fluor"])
         d_xs_sum = da.sum(d_xs, 2)
         N_xs = d_xs.shape[2]
 
     if "xs_fluor_xs2" in data_stream0:
-        d_xs2 = data_stream0["xs_fluor_xs2"]
+        d_xs2 = da.array(data_stream0["xs_fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
     elif "fluor_xs2" in data_stream0:  # Old format
-        d_xs2 = data_stream0["fluor_xs2"]
+        d_xs2 = da.array(data_stream0["fluor_xs2"])
         d_xs2_sum = da.sum(d_xs2, 2)
         N_xs2 = d_xs2.shape[2]
@@ -2536,10 +2540,12 @@ def swap_axes():
 
     # Replace NaNs with 0s (in corrupt data rows).
     loaded_data = {}
-    loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    # loaded_data["det_sum"] = np.nan_to_num(tmp_data_sum.compute())
+    loaded_data["det_sum"] = tmp_data_sum
     if create_each_det:
         for i in range(num_det):
-            loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            # loaded_data["det" + str(i + 1)] = np.nan_to_num(da.squeeze(tmp_data[:, :, i, :]).compute())
+            loaded_data["det" + str(i + 1)] = da.squeeze(tmp_data[:, :, i, :])
 
     if save_scaler:
         loaded_data["scaler_data"] = sclr.compute()
@@ -3519,6 +3525,8 @@ def save_data_to_hdf5(
         Failed to write data to HDF5 file.
     """
 
+    time_start = ttime.time()
+
     fpath = os.path.expanduser(fpath)
     fpath = os.path.abspath(fpath)
 
@@ -3539,7 +3547,7 @@ def incorrect_type_msg(channel, data_type):
             f"The data is converted from '{data_type}' to 'np.float32' before saving to file."
         )
 
-    if "det_sum" in data and isinstance(data["det_sum"], np.ndarray):
+    if "det_sum" in data and isinstance(data["det_sum"], (np.ndarray, da.core.Array)):
         if data["det_sum"].dtype != np.float32:
             incorrect_type_msg("det_sum", data["det_sum"].dtype)
             data["det_sum"] = data["det_sum"].astype(np.float32, copy=False)
@@ -3547,14 +3555,15 @@ def incorrect_type_msg(channel, data_type):
             sum_data_exists = True
 
     for detname in xrf_det_list:
-        if detname in data and isinstance(data[detname], np.ndarray):
+        if detname in data and isinstance(data[detname], (np.ndarray, da.core.Array)):
            if data[detname].dtype != np.float32:
                incorrect_type_msg(detname, data[detname].dtype)
                data[detname] = data[detname].astype(np.float32, copy=False)
 
            if not sum_data_exists:  # Don't compute it if it already exists
                if sum_data is None:
-                    sum_data = np.copy(data[detname])
+                    # sum_data = np.copy(data[detname])
+                    sum_data = data[detname].copy()
                else:
                    sum_data += data[detname]
 
@@ -3598,18 +3607,59 @@ def incorrect_type_msg(channel, data_type):
             for key, value in metadata_prepared.items():
                 metadata_grp.attrs[key] = value
 
+        # The following parameters control how data is loaded from Tiled
+        n_pixels_in_batch = 40000
+        n_tiled_download_retries = 10
+
+        def compute_batch_params(shape):
+            n_rows, n_cols = shape[0], shape[1]
+            n_rows_batch = max(int(n_pixels_in_batch / n_cols), 1)  # Save at least one row
+            n_batches = int(n_rows / n_rows_batch)
+            if n_rows % n_rows_batch:
+                n_batches += 1
+            return n_rows, n_rows_batch, n_batches
+
+        def download_dataset(dset, data):
+            n_rows, n_rows_batch, n_batches = compute_batch_params(data.shape)
+            for n in range(n_batches):
+                ns, ne = n * n_rows_batch, min((n + 1) * n_rows_batch, n_rows)
+                for retry in range(n_tiled_download_retries):
+                    try:
+                        dset[ns:ne, ...] = np.array(data[ns:ne, ...])
+                        break
+                    except Exception as ex:
+                        logger.error(f"Failed to load the batch: {ex}")
+                        if retry >= n_tiled_download_retries - 1:
+                            raise TimeoutError("Failed to download data from Tiled server")
+                print(f" Number of saved rows: {ne}")
+
         if create_each_det is True:
             for detname in xrf_det_list:
-                new_data = data[detname]
-                dataGrp = f.create_group(interpath + "/" + detname)
-                ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
-                ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+                if not isinstance(data[detname], da.core.Array):
+                    new_data = data[detname]
+                    dataGrp = f.create_group(interpath + "/" + detname)
+                    ds_data = dataGrp.create_dataset("counts", data=new_data, compression="gzip")
+                    ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
+                else:
+                    new_data = data[detname]
+                    dataGrp = f.create_group(interpath + "/" + detname)
+                    ds_data = dataGrp.create_dataset("counts", new_data.shape, compression="gzip")
+                    print(f"Downloading data: channel {detname!r} ...")
+                    download_dataset(ds_data, new_data)
+                    ds_data.attrs["comments"] = "Experimental data from {}".format(detname)
 
         # summed data
         if sum_data is not None:
-            dataGrp = f.create_group(interpath + "/detsum")
-            ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
-            ds_data.attrs["comments"] = "Experimental data from channel sum"
+            if not isinstance(sum_data, da.core.Array):
+                dataGrp = f.create_group(interpath + "/detsum")
+                ds_data = dataGrp.create_dataset("counts", data=sum_data, compression="gzip")
+                ds_data.attrs["comments"] = "Experimental data from channel sum"
+            else:
+                dataGrp = f.create_group(interpath + "/detsum")
+                ds_data = dataGrp.create_dataset("counts", sum_data.shape, compression="gzip")
+                print("Downloading data: the sum of all channels ...")
+                download_dataset(ds_data, sum_data)
+                ds_data.attrs["comments"] = "Experimental data from channel sum"
 
         # add positions
         if "pos_names" in data:
@@ -3627,6 +3677,8 @@ def incorrect_type_msg(channel, data_type):
             dataGrp.create_dataset("name", data=helper_encode_list(scaler_names))
             dataGrp.create_dataset("val", data=scaler_data)
 
+    logger.info(f"Total data saving time: {ttime.time() - time_start}")
+
     return fpath
 

From 7ed934757be44616cc0e39b0fa3adf56bb3cc2ae Mon Sep 17 00:00:00 2001
From: Dmitri Gavrilov
Date: Sat, 8 Jun 2024 18:35:45 -0400
Subject: [PATCH 2/3] ENH: replace NaNs with 0s
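
With lazy loading, NaNs in corrupt data rows can no longer be zeroed at
load time (the arrays are not materialized there anymore), so the cleanup
moves into 'download_dataset' and is applied to each batch right after it
is pulled from the Tiled server:

    dset[ns:ne, ...] = np.nan_to_num(np.array(data[ns:ne, ...]))

For reference, batch sizes follow 'compute_batch_params' with
'n_pixels_in_batch = 40000': a map with 1000 rows of 300 pixels gives
'n_rows_batch = max(int(40000 / 300), 1) = 133', so the data is saved in
8 batches (7 full batches of 133 rows and a final batch of 69 rows).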
---
 pyxrf/model/load_data_from_db.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyxrf/model/load_data_from_db.py b/pyxrf/model/load_data_from_db.py
index 5fd6dd55..278a3da9 100644
--- a/pyxrf/model/load_data_from_db.py
+++ b/pyxrf/model/load_data_from_db.py
@@ -3625,7 +3625,7 @@ def download_dataset(dset, data):
                 ns, ne = n * n_rows_batch, min((n + 1) * n_rows_batch, n_rows)
                 for retry in range(n_tiled_download_retries):
                     try:
-                        dset[ns:ne, ...] = np.array(data[ns:ne, ...])
+                        dset[ns:ne, ...] = np.nan_to_num(np.array(data[ns:ne, ...]))
                         break
                     except Exception as ex:
                         logger.error(f"Failed to load the batch: {ex}")

From dcc11c4f382a9e02d0eddbd9c563676ae0bdf65e Mon Sep 17 00:00:00 2001
From: Dmitri Gavrilov
Date: Sat, 8 Jun 2024 20:14:29 -0400
Subject: [PATCH 3/3] FIX: replaced deprecated BrokenBarHCollection with PolyCollection
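
'BrokenBarHCollection' and its 'span_where' classmethod were deprecated
in Matplotlib 3.7 and removed in Matplotlib 3.9; 'PolyCollection' itself
does not provide 'span_where'. Matplotlib suggests 'Axes.fill_between' as
the replacement: it returns a 'PolyCollection', so the
'Typed(PolyCollection)' declarations remain valid. A minimal sketch of
the equivalent call (names as in lineplot.py; 'ss' is the boolean mask
selecting the highlighted energy range):

    barh = ax.fill_between(
        x_v, y_min, y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
    )
    # fill_between() attaches the artist to the axes, so a separate
    # add_collection() call is not needed.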
---
 pyxrf/model/lineplot.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/pyxrf/model/lineplot.py b/pyxrf/model/lineplot.py
index f3c961d7..ac983343 100644
--- a/pyxrf/model/lineplot.py
+++ b/pyxrf/model/lineplot.py
@@ -10,7 +10,7 @@
 import numpy as np
 from atom.api import Atom, Bool, Dict, Float, Int, List, Str, Typed, observe
 from matplotlib.axes import Axes
-from matplotlib.collections import BrokenBarHCollection
+from matplotlib.collections import PolyCollection
 from matplotlib.colors import LogNorm
 from matplotlib.figure import Figure
 from matplotlib.lines import Line2D
@@ -122,7 +122,7 @@ class LinePlotModel(Atom):
     _fig_preview = Typed(Figure)
     _ax_preview = Typed(Axes)
     _lines_preview = List()
-    _bahr_preview = Typed(BrokenBarHCollection)
+    _bahr_preview = Typed(PolyCollection)
 
     plot_type_preview = Typed(PlotTypes)
     energy_range_preview = Typed(EnergyRangePresets)
@@ -177,7 +177,7 @@ class LinePlotModel(Atom):
     show_exp_opt = Bool(False)  # Flag: show spectrum preview
 
     # Reference to artist responsible for displaying the selected range of energies on the plot
-    plot_energy_barh = Typed(BrokenBarHCollection)
+    plot_energy_barh = Typed(PolyCollection)
     t_bar = Typed(object)
 
     plot_exp_list = List()
@@ -727,7 +727,6 @@ def plot_selected_energy_range_original(self, *, e_low=None, e_high=None):
             self.plot_energy_barh.remove()
 
         # Create the new plot (based on new parameters if necessary
-        self.plot_energy_barh = BrokenBarHCollection.span_where(
-            x_v, ymin=y_min, ymax=y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
-        )
-        self._ax.add_collection(self.plot_energy_barh)
+        self.plot_energy_barh = self._ax.fill_between(
+            x_v, y_min, y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
+        )
@@ -1471,7 +1470,6 @@ def plot_selected_energy_range(self, *, axes, barh_existing, e_low=None, e_high=
             barh_existing.remove()
 
         # Create the new plot (based on new parameters if necessary
-        barh_new = BrokenBarHCollection.span_where(
-            x_v, ymin=y_min, ymax=y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
-        )
-        axes.add_collection(barh_new)
+        barh_new = axes.fill_between(
+            x_v, y_min, y_max, where=ss, facecolor="white", edgecolor="yellow", alpha=1
+        )