Skip to content

Commit

Permalink
HIstScatterPlot
Browse files Browse the repository at this point in the history
  • Loading branch information
haddadanas committed Aug 27, 2024
1 parent 750b4cb commit 15c3dc0
Showing 1 changed file with 124 additions and 4 deletions.
128 changes: 124 additions & 4 deletions hbt/tasks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def get_input_as_df(self) -> pd.DataFrame:

# create a column for each category
events_dict = {}
dataset_map = {"tt_sl_powheg": "tt", "tt_dl_powheg": "tt", "tt_fh_powheg": "tt"} # Merge the tt datasets
# read the data
for cat in category_insts:
events_dict[cat.name] = pd.DataFrame()
Expand All @@ -271,14 +272,17 @@ def get_input_as_df(self) -> pd.DataFrame:
temp = pd.DataFrame()
for r in route:
temp[r.string_column] = r.apply(events[mask]).to_numpy()
temp["dataset"] = dataset
temp["dataset"] = dataset_map.get(dataset, dataset)
events_dict[cat.name] = pd.concat([events_dict[cat.name], temp], ignore_index=True)
return events_dict

def get_plot_parameters(self: PlotScatterPlots, variable_tuple: tuple) -> dict:
    """
    Assemble the keyword arguments handed to the plot function for *variable_tuple*.

    The parameters prepared by the parent class are extended by the hue column, the
    normalisation and contour settings and one log-scale flag per axis. Entries of
    ``self.general_settings`` are applied last and therefore take precedence.

    :param variable_tuple: The pair of variables plotted against each other.
    :return: The keyword arguments for the plot function.
    """
    x_inst, y_inst = self.get_variable_insts(variable_tuple)
    params = super().get_plot_parameters(variable_tuple)

    # NOTE(review): per-event weights are not forwarded here (a ``weights`` entry was
    # left disabled) - confirm whether that is intended
    params.update({
        # colour the entries by the "dataset" column of the input frame
        "hue": "dataset",
        "common_norm": False,
        # NOTE(review): the values resemble Gaussian coverage fractions
        # (half of 34.1%, 34.1%, 1/2/3 sigma) - confirm the intended meaning
        "levels": [0.1705, 0.341, 0.682, 0.954, 0.997],
        # both flags are read from the ``log_x`` attribute of the respective variable
        "log_scale": (x_inst.log_x, y_inst.log_x),
    })
    params.update(self.general_settings)
    return params
Expand All @@ -294,6 +298,8 @@ def make_pretty(
fig.suptitle(plt_title, size=35, va="top", ha="center")
ax.set_xlabel(x_inst.x_title, fontsize=ax.xaxis.label.get_size() + 4)
ax.set_ylabel(y_inst.x_title, fontsize=ax.yaxis.label.get_size() + 4)
ax.set_xlim(x_inst.x_min, x_inst.x_max)
ax.set_ylim(y_inst.x_min, y_inst.x_max)
return fig, ax

def output(self) -> dict[str, list]:
Expand Down Expand Up @@ -322,11 +328,14 @@ def run(self):
print("βœ…")
continue
variable_tuple = self.variable_tuples[variable]
column_names = {self.get_variable_insts(variable).expression for variable in variable_tuple}
column_names = tuple(self.get_variable_insts(variable).expression for variable in variable_tuple)
sel_events = events[category]
for c in column_names:
sel_events = sel_events.loc[sel_events[c] > 0]
# call the plot function
fig, ax = self.call_plot_func(
self.plot_function,
**self.get_data_args(events[category], *column_names),
**self.get_data_args(sel_events, *column_names),
**self.get_plot_parameters(variable_tuple),
)
# make the plot prettier
Expand All @@ -341,6 +350,104 @@ def run(self):
print("└── Plotting completed.")


class PlotHistScatter(PlotKDEPlots):
    """
    Task to draw a two-panel figure per category and variable pair.

    The events of ``base_dataset`` are drawn with ``base_plot`` on both panels, while the
    events of all other datasets are overlaid with ``plot_function`` on the second panel only.
    """

    plot_function = PlotBase.plot_function.copy(
        default="seaborn.scatterplot",
        add_default_to_description=True,
        description="the full path given using the dot notation of the desired plot function.",
    )

    base_dataset = luigi.Parameter(
        default="hh_ggf_hbb_htt_kl1_kt1_powheg",
        description="The dataset to be used as the base for the KDE plot.",
    )

    base_plot = luigi.Parameter(
        default="seaborn.histplot",
        description="The plot function to be used for the base dataset.",
    )

    def update_plot_kwargs(self: PlotHistScatter, kwargs: dict) -> dict:
        """
        Extend the keyword arguments prepared by the parent class with the styling
        options of this task.

        ``color``, ``marker`` and ``cbar`` are copied from *kwargs* when the caller
        supplied them, and ``hue`` is removed from the result.

        :param kwargs: The keyword arguments as passed by the caller.
        :return: The updated keyword arguments.
        """
        kwargs_new = super().update_plot_kwargs(kwargs)
        # forward only what was actually supplied: ``run`` does not pass ``color``
        # explicitly, so a missing key must not abort the plot with a KeyError
        for key in ("color", "marker", "cbar"):
            if key in kwargs:
                kwargs_new[key] = kwargs[key]
        # tolerate a parent implementation that already dropped the key
        kwargs_new.pop("hue", None)
        return kwargs_new

    def call_plot_func(self: PlotHistScatter, func_name: str, ax, **kwargs) -> tuple[plt.Figure, plt.Axes]:
        """
        Draw with the plot function *func_name* onto the existing axes *ax*.

        The call is forwarded directly to ``PlotBaseHBT.call_plot_func`` instead of
        ``super()``, so the implementations of the classes in between are not used.
        """
        return PlotBaseHBT.call_plot_func(self, func_name, ax=ax, **kwargs)

    def make_pretty(
        self: PlotHistScatter,
        fig: plt.Figure,
        axs: tuple,
        variable_tuple: tuple,
        plt_title: str = "",
    ) -> tuple[plt.Figure, tuple]:
        """
        Apply the styling of the parent class to every axes object in *axs*.

        :return: The figure and the unchanged *axs* tuple.
        """
        for ax in axs:
            fig, ax = super().make_pretty(fig, ax, variable_tuple, plt_title)
        return fig, axs

    @law.decorator.log
    @view_output_plots
    def run(self):
        """
        Create and save one two-panel figure per category and variable pair.
        """
        events = self.get_input_as_df()
        for category in self.categories:
            print(f"├── plotting in {category}")

            for variable in self.variables:
                print(f"│ ├── Plotting variable: {variable}", end=" ", flush=True)
                # nothing to do when all outputs of this pair already exist
                if all(f.complete() for f in self.output()[category][variable]):
                    print("✅")
                    continue

                variable_tuple = self.variable_tuples[variable]
                column_names = tuple(
                    self.get_variable_insts(var_name).expression for var_name in variable_tuple
                )

                # keep only events with strictly positive values in all plotted columns
                sel_events = events[category]
                for c in column_names:
                    sel_events = sel_events.loc[sel_events[c] > 0]
                sel_events = self.get_weights(sel_events)

                # separate the base dataset from all other datasets
                base_events = sel_events.loc[sel_events["dataset"] == self.base_dataset]
                sel_events = sel_events.loc[sel_events["dataset"] != self.base_dataset]

                plt.style.use(mplhep.style.CMS)
                plt.rcParams.update({"legend.facecolor": "white"})
                fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey="row", figsize=(35, 15))

                # draw the base dataset on both panels
                # NOTE(review): no per-panel colour or other styling (cbar, thresh) is
                # passed here - confirm whether the panels should differ
                for ax in (ax1, ax2):
                    self.call_plot_func(
                        self.base_plot,
                        ax,
                        **self.get_data_args(base_events, *column_names),
                        **self.get_plot_parameters(variable_tuple),
                    )

                # overlay the remaining datasets on the second panel only
                self.call_plot_func(
                    self.plot_function,
                    ax2,
                    **self.get_data_args(sel_events, *column_names),
                    **self.get_plot_parameters(variable_tuple),
                )

                # make the plot prettier
                plt_title = f"({category})"
                fig, _ = self.make_pretty(fig, (ax1, ax2), variable_tuple, plt_title)

                # save the outputs
                for outp in self.output()[category][variable]:
                    outp.dump(fig, formatter="mpl", dpi=150, bbox_inches="tight")
                # pyplot keeps every figure alive until it is closed explicitly
                plt.close(fig)
                print("✅")

        print("└── Plotting completed.")


class PlotFancyPlots(PlotBaseHBT):
"""
Task to plot scatter plots of the selection results.
Expand Down Expand Up @@ -549,4 +656,17 @@ def call_plot_func(self, func_name: str, data, **kwargs) -> Any:
bin_centers = [(bin_edges[i] + bin_edges[i + 1]) / 2. for i in range(len(bin_edges) - 1)]
ax1.step(bin_centers, eff, where="mid", color="black")
ax1.set_ylabel("Efficiency")
return fig, ax0
return fig, (ax0, ax1)

def make_pretty(
    self,
    fig: plt.Figure,
    axs: tuple,
    variable_tuple: tuple,
    plt_title: str = "",
    effeciency: str = "",
) -> tuple[plt.Figure, plt.Axes]:
    """
    Style a figure consisting of two axes.

    The first entry of *axs* is styled by the parent class; the second entry only
    takes over the resulting x-axis range.

    :param fig: The figure holding both axes.
    :param axs: The two axes objects, the first one being passed to the parent class.
    :param variable_tuple: Forwarded to the parent class.
    :param plt_title: Forwarded to the parent class.
    :param effeciency: Forwarded to the parent class.
    :return: The figure and a tuple with both axes.
    """
    fig, main_ax = super().make_pretty(fig, axs[0], variable_tuple, plt_title, effeciency)

    # keep the x range of the second axes in sync with the styled first one
    second_ax = axs[1]
    x_range = main_ax.get_xlim()
    second_ax.set_xlim(x_range)

    return fig, (main_ax, second_ax)

0 comments on commit 15c3dc0

Please sign in to comment.