Skip to content

Commit

Permalink
Merge pull request #1154 from facebookexperimental/sort_one_pager_gra…
Browse files Browse the repository at this point in the history
…ph_labelling

Fix One pager plots for accuracy and correctness
  • Loading branch information
sumalreddy17 authored Nov 22, 2024
2 parents c6f879b + b295b14 commit 0e601a0
Show file tree
Hide file tree
Showing 5 changed files with 584 additions and 425 deletions.
201 changes: 145 additions & 56 deletions python/src/robyn/reporting/onepager_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ def _setup_plotting_style(self):
sns.set_theme(style="whitegrid", context="paper")
plt.rcParams.update(
{
"figure.figsize": (22, 17), # Increased figure size
"figure.figsize": (30, 34), # Increased from (22, 17)
"figure.dpi": 100,
"savefig.dpi": 300,
"font.size": 10,
"axes.titlesize": 12,
"axes.labelsize": 10,
"xtick.labelsize": 9,
"ytick.labelsize": 9,
"legend.fontsize": 9,
"figure.titlesize": 14,
"font.size": 16, # Increased from 10
"axes.titlesize": 22, # Increased from 12
"axes.labelsize": 12, # Increased from 10
"xtick.labelsize": 11, # Increased from 9
"ytick.labelsize": 11, # Increased from 9
"legend.fontsize": 11, # Increased from 9
"figure.titlesize": 16, # Increased from 14
"axes.grid": True,
"grid.alpha": 0.3,
"axes.spines.top": False,
Expand Down Expand Up @@ -135,38 +135,129 @@ def _setup_grid(
def _get_model_info(self, solution_id: str) -> Dict[str, str]:
"""Get model performance metrics for specific solution."""
try:
model_data = self.pareto_result.plot_data_collect[solution_id]

# Extract RSQ from plot5data safely
rsq = (
model_data["plot5data"].get("rsq")
if isinstance(model_data["plot5data"], dict)
else (
model_data["plot5data"].rsq.iloc[0]
if hasattr(model_data["plot5data"], "rsq")
else 0
# Get media share data similar to R code
x_decomp_agg = self.pareto_result.x_decomp_agg
plot_media_share = x_decomp_agg[
(x_decomp_agg["sol_id"] == solution_id)
& (
x_decomp_agg["rn"].isin(
self.mmm_data.mmmdata_spec.paid_media_spends
)
)
]
if plot_media_share.empty:
raise ValueError(
f"No media share data found for solution {solution_id}"
)

# Extract metrics following R code logic
metrics = {}

# Get training metrics
metrics["rsq_train"] = self._safe_format(
plot_media_share["rsq_train"].iloc[0]
)
metrics["nrmse_train"] = self._safe_format(
plot_media_share["nrmse_train"].iloc[0]
)

# Get NRMSE and DECOMP.RSSD values safely
nrmse = model_data.get("nrmse", 0)
if isinstance(nrmse, (pd.DataFrame, pd.Series)):
nrmse = (
nrmse.iloc[0] if isinstance(nrmse, pd.Series) else nrmse.iloc[0, 0]
# Get validation metrics if available
if "rsq_val" in plot_media_share.columns:
metrics["rsq_val"] = self._safe_format(
plot_media_share["rsq_val"].iloc[0]
)
metrics["nrmse_val"] = self._safe_format(
plot_media_share["nrmse_val"].iloc[0]
)

decomp_rssd = model_data.get("decomp.rssd", 0)
if isinstance(decomp_rssd, (pd.DataFrame, pd.Series)):
decomp_rssd = (
decomp_rssd.iloc[0]
if isinstance(decomp_rssd, pd.Series)
else decomp_rssd.iloc[0, 0]
# Get test metrics if available
if "rsq_test" in plot_media_share.columns:
metrics["rsq_test"] = self._safe_format(
plot_media_share["rsq_test"].iloc[0]
)
metrics["nrmse_test"] = self._safe_format(
plot_media_share["nrmse_test"].iloc[0]
)

metrics = {
"rsq_train": self._safe_format(rsq),
"nrmse": self._safe_format(nrmse),
"decomp_rssd": self._safe_format(decomp_rssd),
# Get decomp.rssd
metrics["decomp_rssd"] = self._safe_format(
plot_media_share["decomp.rssd"].iloc[0]
)

# Get MAPE if available
if "mape" in plot_media_share.columns:
metrics["mape"] = self._safe_format(plot_media_share["mape"].iloc[0])

# Get train size
metrics["train_size"] = self._safe_format(
plot_media_share["train_size"].iloc[0]
)

# Calculate performance (ROAS/CPA)
dep_var_type = self.mmm_data.mmmdata_spec.dep_var_type
type_metric = "CPA" if dep_var_type == "conversion" else "ROAS"

perf = (
x_decomp_agg[
(x_decomp_agg["sol_id"] == solution_id)
& (
x_decomp_agg["rn"].isin(
self.mmm_data.mmmdata_spec.paid_media_spends
)
)
]
.groupby("sol_id")
.agg({"xDecompAgg": "sum", "total_spend": "sum"})
)

if not perf.empty:
if type_metric == "ROAS":
performance = (
perf["xDecompAgg"].iloc[0] / perf["total_spend"].iloc[0]
)
else: # CPA
performance = (
perf["total_spend"].iloc[0] / perf["xDecompAgg"].iloc[0]
)

metrics["performance"] = f"{performance:.3g} {type_metric}"

# Format the metrics string based on validation availability
if "rsq_val" in metrics:
metrics_text = (
f"Adj.R2: train = {metrics['rsq_train']}, "
f"val = {metrics['rsq_val']}, "
f"test = {metrics['rsq_test']} | "
f"NRMSE: train = {metrics['nrmse_train']}, "
f"val = {metrics['nrmse_val']}, "
f"test = {metrics['nrmse_test']} | "
f"DECOMP.RSSD = {metrics['decomp_rssd']}"
)
else:
metrics_text = (
f"Adj.R2: train = {metrics['rsq_train']} | "
f"NRMSE: train = {metrics['nrmse_train']} | "
f"DECOMP.RSSD = {metrics['decomp_rssd']}"
)

if "mape" in metrics:
metrics_text += f" | MAPE = {metrics['mape']}"

if "performance" in metrics:
metrics_text += f" | {metrics['performance']}"

metrics["formatted_text"] = metrics_text
return metrics

except Exception as e:
logger.error(
f"Error getting model info for solution {solution_id}: {str(e)}"
)
return {
"rsq_train": "0.0000",
"nrmse_train": "0.0000",
"decomp_rssd": "0.0000",
"formatted_text": "Error calculating metrics",
}

if hasattr(self.pareto_result, "mape"):
Expand All @@ -184,8 +275,7 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
def _generate_solution_plots(
self, solution_id: str, plots: List[PlotType], gs: GridSpec
) -> None:
"""
Generate plots for a single solution with dynamic layout.
"""Generate plots for a single solution with dynamic layout.
Args:
solution_id: Solution ID to generate plots for
Expand Down Expand Up @@ -222,7 +312,7 @@ def _generate_solution_plots(
transfor_viz = TransformationVisualizer(self.pareto_result, self.mmm_data)

# Add space at top for title
gs.update(top=0.85)
gs.update(top=0.92)

# TODO: Move the config out of the method to its own data class.

Expand Down Expand Up @@ -319,7 +409,7 @@ def _generate_solution_plots(
)
raise e

# Add model info and titles
# Add model info and titles with adjusted positioning
try:
model_info = self._get_model_info(solution_id)
metrics_text = (
Expand All @@ -333,18 +423,18 @@ def _generate_solution_plots(

fig = gs.figure
fig.suptitle(
f"MMM Analysis One-Pager (Solution {solution_id})",
fontsize=14,
f"MMM Analysis One-Pager for Model: {solution_id})",
fontsize=18,
y=0.98,
)
fig.text(0.5, 0.94, metrics_text, fontsize=12, ha="center")
fig.text(0.5, 0.96, metrics_text, fontsize=18, ha="center")
except Exception as e:
logger.error(
f"Error adding title and metrics for solution {solution_id}: {str(e)}"
)
gs.figure.suptitle(
f"MMM Analysis One-Pager (Solution {solution_id})",
fontsize=14,
f"MMM Analysis One-Pager for Model: {solution_id})",
fontsize=18,
y=0.98,
)

Expand All @@ -359,12 +449,11 @@ def generate_one_pager(
self,
solution_ids: Union[str, List[str]] = "all",
plots: Optional[List[str]] = None,
figsize: tuple = (20, 15),
figsize: tuple = (30, 34), # Reduced height from 36 to 32
save_path: Optional[str] = None,
top_pareto: bool = False,
) -> List[plt.Figure]:
"""
Generate separate one-pager for each solution ID.
"""Generate separate one-pager for each solution ID.
Args:
solution_ids: Single solution ID or list of solution IDs or 'all'
Expand Down Expand Up @@ -478,20 +567,20 @@ def generate_one_pager(

# Adjust layout with improved spacing
fig.set_constrained_layout_pads(
w_pad=0.15, # Increased padding between plots horizontally
h_pad=0.2, # Increased padding between plots vertically
hspace=0.4, # Increased height space between subplots
wspace=0.3, # Increased width space between subplots
w_pad=0.15, # Padding between plots horizontally
h_pad=0.2, # Padding between plots vertically
hspace=0.4, # Height space between subplots
wspace=0.3, # Width space between subplots
)

# Update layout to leave more space for titles and labels
# Update layout with reduced top spacing
plt.subplots_adjust(
top=0.85, # Space for main title
bottom=0.1, # Space for bottom x-labels
left=0.1, # Space for left y-labels
right=0.9, # Space for right margin
hspace=0.4, # Space between plots vertically
wspace=0.3, # Space between plots horizontally
top=0.92, # Increased from 0.88 to reduce top space
bottom=0.08, # Keep the same
left=0.08, # Keep the same
right=0.92, # Keep the same
hspace=0.35, # Keep the same
wspace=0.25, # Keep the same
)

if save_path:
Expand Down
Loading

0 comments on commit 0e601a0

Please sign in to comment.