From 15c3dc094988298eac563e1bb0861322e076a5e5 Mon Sep 17 00:00:00 2001 From: haddadanas Date: Tue, 27 Aug 2024 14:05:31 +0200 Subject: [PATCH] HIstScatterPlot --- hbt/tasks/plotting.py | 128 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 4 deletions(-) diff --git a/hbt/tasks/plotting.py b/hbt/tasks/plotting.py index f810af01..394fe95e 100644 --- a/hbt/tasks/plotting.py +++ b/hbt/tasks/plotting.py @@ -261,6 +261,7 @@ def get_input_as_df(self) -> pd.DataFrame: # create a column for each category events_dict = {} + dataset_map = {"tt_sl_powheg": "tt", "tt_dl_powheg": "tt", "tt_fh_powheg": "tt"} # Merge the tt datasets # read the data for cat in category_insts: events_dict[cat.name] = pd.DataFrame() @@ -271,7 +272,7 @@ def get_input_as_df(self) -> pd.DataFrame: temp = pd.DataFrame() for r in route: temp[r.string_column] = r.apply(events[mask]).to_numpy() - temp["dataset"] = dataset + temp["dataset"] = dataset_map.get(dataset, dataset) events_dict[cat.name] = pd.concat([events_dict[cat.name], temp], ignore_index=True) return events_dict @@ -279,6 +280,9 @@ def get_plot_parameters(self: PlotScatterPlots, variable_tuple: tuple) -> dict: x_inst, y_inst = self.get_variable_insts(variable_tuple) params = super().get_plot_parameters(variable_tuple) params["hue"] = "dataset" + # params["weights"] = "weights" + params["common_norm"] = False + params["levels"] = [0.1705, 0.341, 0.682, 0.954, 0.997] params["log_scale"] = tuple(var.log_x for var in (x_inst, y_inst)) params.update(self.general_settings) return params @@ -294,6 +298,8 @@ def make_pretty( fig.suptitle(plt_title, size=35, va="top", ha="center") ax.set_xlabel(x_inst.x_title, fontsize=ax.xaxis.label.get_size() + 4) ax.set_ylabel(y_inst.x_title, fontsize=ax.yaxis.label.get_size() + 4) + ax.set_xlim(x_inst.x_min, x_inst.x_max) + ax.set_ylim(y_inst.x_min, y_inst.x_max) return fig, ax def output(self) -> dict[str, list]: @@ -322,11 +328,14 @@ def run(self): print("✅") continue variable_tuple = self.variable_tuples[variable] - column_names = {self.get_variable_insts(variable).expression for variable in variable_tuple} + column_names = tuple(self.get_variable_insts(variable).expression for variable in variable_tuple) + sel_events = events[category] + for c in column_names: + sel_events = sel_events.loc[sel_events[c] > 0] # call the plot function fig, ax = self.call_plot_func( self.plot_function, - **self.get_data_args(events[category], *column_names), + **self.get_data_args(sel_events, *column_names), **self.get_plot_parameters(variable_tuple), ) # make the plot prettier @@ -341,6 +350,104 @@ def run(self): print("└── Plotting completed.") +class PlotHistScatter(PlotKDEPlots): + + plot_function = PlotBase.plot_function.copy( + default="seaborn.scatterplot", + add_default_to_description=True, + description="the full path given using the dot notation of the desired plot function.", + ) + + base_dataset = luigi.Parameter( + default="hh_ggf_hbb_htt_kl1_kt1_powheg", + description="The dataset to be used as the base for the KDE plot.", + ) + + base_plot = luigi.Parameter( + default="seaborn.histplot", + description="The plot function to be used for the base dataset.", + ) + + def update_plot_kwargs(self: PlotHistScatter, kwargs: dict) -> dict: + kwargs_new = super().update_plot_kwargs(kwargs) + kwargs_new["color"] = kwargs["color"] + if "marker" in kwargs: + kwargs_new["marker"] = kwargs["marker"] + if "cbar" in kwargs: + kwargs_new["cbar"] = kwargs["cbar"] + kwargs_new.pop("hue") + return kwargs_new + + def call_plot_func(self: PlotScatterPlots, func_name: str, ax, **kwargs) -> tuple[plt.Figure, plt.Axes]: + return PlotBaseHBT.call_plot_func(self, func_name, ax=ax, **kwargs) + + def make_pretty( + self: PlotScatterPlots, + fig: plt.Figure, + axs: tuple, + variable_tuple: tuple, + plt_title: str = "", + ) -> tuple[plt.Figure, plt.Axes]: + for ax in axs: + fig, ax = super().make_pretty(fig, ax, variable_tuple, plt_title) + return fig, axs + + @law.decorator.log + @view_output_plots + def run(self): + + events = self.get_input_as_df() + for category in self.categories: + print(f"├── plotting in {category}") + + for variable in self.variables: + print(f"│ ├── Plotting variable: {variable}", end=" ", flush=True) + if all([f.complete() for f in self.output()[category][variable]]): + print("✅") + continue + variable_tuple = self.variable_tuples[variable] + column_names = tuple(self.get_variable_insts(variable).expression for variable in variable_tuple) + sel_events = events[category] + for c in column_names: + sel_events = sel_events.loc[sel_events[c] > 0] + sel_events = self.get_weights(sel_events) + base_events = sel_events.loc[sel_events["dataset"] == self.base_dataset] + sel_events = sel_events.loc[sel_events["dataset"] != self.base_dataset] + # call the plot function + plt.style.use(mplhep.style.CMS) + plt.rcParams.update({"legend.facecolor": "white"}) + fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey="row", figsize=(35, 15)) + for ax, color in zip((ax1, ax2), ("b", "k")): + self.call_plot_func( + self.base_plot, + ax, + # cbar=True, + # thresh=5, + # color="k", + **self.get_data_args(base_events, *column_names), + **self.get_plot_parameters(variable_tuple), + ) + self.call_plot_func( + self.plot_function, + ax2, + # marker="+", + # color="r", + # legend="full", + **self.get_data_args(sel_events, *column_names), + **self.get_plot_parameters(variable_tuple), + ) + # make the plot prettier + plt_title = f"({category})" + fig, ax = self.make_pretty(fig, (ax1, ax2), variable_tuple, plt_title) + + # save the outputs + for outp in self.output()[category][variable]: + outp.dump(fig, formatter="mpl", dpi=150, bbox_inches="tight") + print("✅") + + print("└── Plotting completed.") + + class PlotFancyPlots(PlotBaseHBT): """ Task to plot scatter plots of the selection results. @@ -549,4 +656,17 @@ def call_plot_func(self, func_name: str, data, **kwargs) -> Any: bin_centers = [(bin_edges[i] + bin_edges[i + 1]) / 2. for i in range(len(bin_edges) - 1)] ax1.step(bin_centers, eff, where="mid", color="black") ax1.set_ylabel("Efficiency") - return fig, ax0 + return fig, (ax0, ax1) + + def make_pretty( + self, + fig: plt.Figure, + axs: tuple, + variable_tuple: tuple, + plt_title: str = "", + effeciency: str = "", + ) -> tuple[plt.Figure, plt.Axes]: + fig, ax0 = super().make_pretty(fig, axs[0], variable_tuple, plt_title, effeciency) + ax1 = axs[1] + ax1.set_xlim(ax0.get_xlim()) + return fig, (ax0, ax1)