Commit

Clean up and reformat
sumalreddy17 committed Nov 21, 2024
1 parent 6c36bf3 commit b295b14
Showing 3 changed files with 100 additions and 81 deletions.
11 changes: 0 additions & 11 deletions python/src/robyn/reporting/onepager_reporting.py
@@ -137,7 +137,6 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
         try:
             # Get media share data similar to R code
             x_decomp_agg = self.pareto_result.x_decomp_agg
-            print(f"x_decomp_agg= {x_decomp_agg}")
             plot_media_share = x_decomp_agg[
                 (x_decomp_agg["sol_id"] == solution_id)
                 & (
@@ -146,7 +145,6 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
                     )
                 )
             ]
-            print(f"plot_media_share= {plot_media_share}")
             if plot_media_share.empty:
                 raise ValueError(
                     f"No media share data found for solution {solution_id}"
@@ -162,7 +160,6 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
             metrics["nrmse_train"] = self._safe_format(
                 plot_media_share["nrmse_train"].iloc[0]
             )
-            print(f"metrics= {metrics}")

             # Get validation metrics if available
             if "rsq_val" in plot_media_share.columns:
@@ -172,7 +169,6 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
                 metrics["nrmse_val"] = self._safe_format(
                     plot_media_share["nrmse_val"].iloc[0]
                 )
-                print(f"metrics with rsq_val= {metrics}")

             # Get test metrics if available
             if "rsq_test" in plot_media_share.columns:
@@ -182,24 +178,20 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
                 metrics["nrmse_test"] = self._safe_format(
                     plot_media_share["nrmse_test"].iloc[0]
                 )
-                print(f"metrics with rsq_test= {metrics}")

             # Get decomp.rssd
             metrics["decomp_rssd"] = self._safe_format(
                 plot_media_share["decomp.rssd"].iloc[0]
             )
-            print(f"metrics with decomp_rssd= {metrics}")

             # Get MAPE if available
             if "mape" in plot_media_share.columns:
                 metrics["mape"] = self._safe_format(plot_media_share["mape"].iloc[0])
-                print(f"metrics with mape= {metrics}")

             # Get train size
             metrics["train_size"] = self._safe_format(
                 plot_media_share["train_size"].iloc[0]
             )
-            print(f"metrics with train_size= {metrics}")

             # Calculate performance (ROAS/CPA)
             dep_var_type = self.mmm_data.mmmdata_spec.dep_var_type
@@ -218,8 +210,6 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
                 .agg({"xDecompAgg": "sum", "total_spend": "sum"})
             )

-            print(f"perf = {perf}")
-
             if not perf.empty:
                 if type_metric == "ROAS":
                     performance = (
@@ -257,7 +247,6 @@ def _get_model_info(self, solution_id: str) -> Dict[str, str]:
                 metrics_text += f" | {metrics['performance']}"

             metrics["formatted_text"] = metrics_text
-            print(f"metrics = {metrics}")
             return metrics

         except Exception as e:
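Note: all eleven deletions above are ad-hoc print() debugging in _get_model_info. If that visibility is wanted again, a minimal sketch (not part of this commit) is to route it through the module-level logger pattern already used in pareto_visualizer.py, so output is gated by log level instead of always hitting stdout:

```python
import logging

logger = logging.getLogger(__name__)  # assumed module-level logger, as elsewhere in the repo

def log_metrics(metrics: dict) -> None:
    # Emitted only when the logger is set to DEBUG; silent in normal runs.
    logger.debug("metrics=%s", metrics)
```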
102 changes: 62 additions & 40 deletions python/src/robyn/visualization/pareto_visualizer.py
@@ -159,7 +159,7 @@ def generate_waterfall(

             # Calculate x-position as the middle of the bar
             x_pos = (row.start + row.end) / 2
-
+
             # Use y_pos[idx] to ensure alignment with bars
             ax.text(
                 x_pos,
@@ -234,18 +234,18 @@ def generate_fitted_vs_actual(
         """

         logger.debug("Starting generation of fitted vs actual plot")

         if solution_id not in self.pareto_result.plot_data_collect:
             raise ValueError(f"Invalid solution ID: {solution_id}")

-        # Get data for specific solution
+        # Get data for specific solution
         plot_data = self.pareto_result.plot_data_collect[solution_id]
         ts_data = plot_data["plot5data"]["xDecompVecPlotMelted"].copy()

         # Ensure ds column is datetime and remove any NaT values
         ts_data["ds"] = pd.to_datetime(ts_data["ds"])
         ts_data = ts_data.dropna(subset=["ds"])  # Remove rows with NaT dates

         if ts_data.empty:
             logger.warning(f"No valid date data found for solution {solution_id}")
             return None
@@ -259,7 +259,7 @@ def generate_fitted_vs_actual(
         train_size_series = self.pareto_result.x_decomp_agg[
             self.pareto_result.x_decomp_agg["sol_id"] == solution_id
         ]["train_size"]
-
+
         if not train_size_series.empty:
             train_size = float(train_size_series.iloc[0])
         else:
@@ -270,6 +270,11 @@ def generate_fitted_vs_actual(
         else:
             fig = None

+        colors = {
+            "Actual": "#FF6B00",  # Darker orange
+            "Predicted": "#0066CC",  # Darker blue
+        }
+
         # Plot lines with different styles for predicted vs actual
         for var in ts_data["variable"].unique():
             var_data = ts_data[ts_data["variable"] == var]
@@ -279,13 +284,13 @@ def generate_fitted_vs_actual(
                 var_data["value"],
                 label=var,
                 linestyle=linestyle,
-                linewidth=0.6,
-                color='orange' if var == 'Actual' else 'lightblue'
+                linewidth=1,
+                color=colors[var],
             )

         # Format y-axis with abbreviations
         ax.yaxis.set_major_formatter(ticker.FuncFormatter(self.format_number))

         # Set y-axis limits with some padding
         y_min, y_max = ax.get_ylim()
         ax.set_ylim(y_min, y_max * 1.2)  # Add 20% padding at the top
@@ -296,40 +301,43 @@ def generate_fitted_vs_actual(
             # Get unique sorted dates, excluding NaT
             unique_dates = sorted(ts_data["ds"].dropna().unique())
             total_days = len(unique_dates)

             if total_days > 0:
                 # Calculate split points
                 train_cut = int(total_days * train_size)
                 val_cut = train_cut + int(total_days * (1 - train_size) / 2)

                 # Get dates for splits
                 splits = [
                     (train_cut, "Train", train_size),
                     (val_cut, "Validation", (1 - train_size) / 2),
-                    (total_days - 1, "Test", (1 - train_size) / 2)
+                    (total_days - 1, "Test", (1 - train_size) / 2),
                 ]

                 # Get y-axis limits for text placement
                 y_min, y_max = ax.get_ylim()

                 # Add vertical lines and labels
                 for idx, label, size in splits:
                     if 0 <= idx < len(unique_dates):  # Ensure index is valid
                         date = unique_dates[idx]
                         if pd.notna(date):  # Check if date is valid
                             # Add vertical line - extend beyond the top of the plot
-                            ax.axvline(date, color="#39638b", alpha=0.8, ymin=0, ymax=1.1)
-
+                            ax.axvline(
+                                date, color="#39638b", alpha=0.8, ymin=0, ymax=1.1
+                            )
+
                             # Add rotated text label
                             ax.text(
-                                date, y_max,
+                                date,
+                                y_max,
                                 f"{label}: {size*100:.1f}%",
                                 rotation=270,
                                 color="#39638b",
                                 alpha=0.5,
                                 size=9,
-                                ha='left',
-                                va='top'
+                                ha="left",
+                                va="top",
                             )
         except Exception as e:
             logger.warning(f"Error adding split lines: {str(e)}")
@@ -342,13 +350,15 @@ def generate_fitted_vs_actual(

         # Configure legend
         ax.legend(
-            bbox_to_anchor=(0, 1.02, 1, 0.1),
+            bbox_to_anchor=(0.01, 1.02),  # Position at top-left
             loc="lower left",
-            ncol=2,
-            mode="expand",
+            ncol=2,  # Two columns side by side
             borderaxespad=0,
             frameon=False,
-            fontsize=7
+            fontsize=7,
+            handlelength=2,  # Length of the legend lines
+            handletextpad=0.5,  # Space between line and text
+            columnspacing=1.0,  # Space between columns
         )

         # Grid styling
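Note: this swaps between matplotlib's two legend placement styles. A 4-tuple bbox_to_anchor with mode="expand" stretches the legend across the box (x, y, width, height), while a 2-tuple anchors a compact legend at a single point. A minimal standalone comparison (sketch, not code from the repo):

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2)
for ax in (ax1, ax2):
    ax.plot([0, 1], [0, 1], label="Actual")
    ax.plot([0, 1], [1, 0], label="Predicted")

# Old style: legend expands to fill the anchor box.
ax1.legend(bbox_to_anchor=(0, 1.02, 1, 0.1), loc="lower left",
           ncol=2, mode="expand", frameon=False)

# New style: compact legend pinned just above the top-left corner.
ax2.legend(bbox_to_anchor=(0.01, 1.02), loc="lower left",
           ncol=2, frameon=False, handlelength=2, columnspacing=1.0)

plt.show()
```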
@@ -358,7 +368,7 @@ def generate_fitted_vs_actual(

         # Format dates on x-axis using datetime locator and formatter
         years = mdates.YearLocator()
-        years_fmt = mdates.DateFormatter('%Y')
+        years_fmt = mdates.DateFormatter("%Y")
         ax.xaxis.set_major_locator(years)
         ax.xaxis.set_major_formatter(years_fmt)

@@ -474,13 +484,15 @@ def generate_immediate_vs_carryover(

         plot_data = self.pareto_result.plot_data_collect[solution_id]
         df_imme_caov = plot_data["plot7data"].copy()

         # Ensure percentage is numeric
-        df_imme_caov['percentage'] = pd.to_numeric(df_imme_caov['percentage'], errors='coerce')
-
-        # Sort channels alphabetically
-        df_imme_caov = df_imme_caov.sort_values('rn', ascending=True)
+        df_imme_caov["percentage"] = pd.to_numeric(
+            df_imme_caov["percentage"], errors="coerce"
+        )
+
+        # Sort channels alphabetically
+        df_imme_caov = df_imme_caov.sort_values("rn", ascending=True)

         # Set up type factor levels matching R plot order
         df_imme_caov["type"] = pd.Categorical(
             df_imme_caov["type"], categories=["Immediate", "Carryover"], ordered=True
@@ -500,10 +512,12 @@ def generate_immediate_vs_carryover(

         # Normalize percentages to sum to 100% for each channel
         for channel in channels:
-            mask = df_imme_caov['rn'] == channel
-            total = df_imme_caov.loc[mask, 'percentage'].sum()
+            mask = df_imme_caov["rn"] == channel
+            total = df_imme_caov.loc[mask, "percentage"].sum()
             if total > 0:  # Avoid division by zero
-                df_imme_caov.loc[mask, 'percentage'] = df_imme_caov.loc[mask, 'percentage'] / total
+                df_imme_caov.loc[mask, "percentage"] = (
+                    df_imme_caov.loc[mask, "percentage"] / total
+                )

         for type_name in types:
             type_data = df_imme_caov[df_imme_caov["type"] == type_name]
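Note: the loop normalizes each channel's Immediate/Carryover shares to sum to 1. The same result can be had without an explicit loop via groupby/transform; a sketch with hypothetical data (an equivalent, not what the commit does):

```python
import pandas as pd

df = pd.DataFrame({
    "rn": ["tv", "tv", "search", "search"],   # channel
    "type": ["Immediate", "Carryover"] * 2,
    "percentage": [0.30, 0.50, 0.10, 0.30],
})

totals = df.groupby("rn")["percentage"].transform("sum")
# Divide only where the channel total is non-zero, mirroring the loop's guard.
df["percentage"] = df["percentage"].where(totals == 0, df["percentage"] / totals)

print(df.groupby("rn")["percentage"].sum())  # tv 1.0, search 1.0
```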
@@ -539,12 +553,12 @@ def generate_immediate_vs_carryover(
         ax.legend(
             title=None,
             bbox_to_anchor=(0, 1.02, 0.15, 0.1),  # Reduced width from 0.3 to 0.2
-            loc="lower left",
+            loc="lower left",
             ncol=2,
             mode="expand",
             borderaxespad=0,
             frameon=False,
-            fontsize=7  # Reduced from 8 to 7
+            fontsize=7,  # Reduced from 8 to 7
         )

         ax.set_xlabel("% Response")
@@ -587,10 +601,10 @@ def generate_adstock_rate(

         if self.adstock == AdstockType.GEOMETRIC:
             dt_geometric = adstock_data["dt_geometric"].copy()

             # Sort data alphabetically by channel
-            dt_geometric = dt_geometric.sort_values('channels', ascending=True)
+            dt_geometric = dt_geometric.sort_values("channels", ascending=True)

             bars = ax.barh(
                 y=range(len(dt_geometric)),
                 width=dt_geometric["thetas"],
@@ -607,13 +621,19 @@ def generate_adstock_rate(
             ax.set_yticklabels(dt_geometric["channels"])

             # Format x-axis with 25% increments
-            ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{x*100:.0f}%"))
+            ax.xaxis.set_major_formatter(
+                plt.FuncFormatter(lambda x, p: f"{x*100:.0f}%")
+            )
             ax.set_xlim(0, 1)
             ax.set_xticks(np.arange(0, 1.25, 0.25))  # Changed to 0.25 increments

             # Set title and labels
-            interval_type = self.mmm_data.mmmdata_spec.interval_type if self.mmm_data else "day"
-            ax.set_title(f"Geometric Adstock: Fixed Rate Over Time (Solution {solution_id})")
+            interval_type = (
+                self.mmm_data.mmmdata_spec.interval_type if self.mmm_data else "day"
+            )
+            ax.set_title(
+                f"Geometric Adstock: Fixed Rate Over Time (Solution {solution_id})"
+            )
             ax.set_xlabel(f"Thetas [by {interval_type}]")
             ax.set_ylabel(None)

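Note: the thetas on this chart are fixed geometric decay rates. In Robyn-style geometric adstock each period carries over theta of the accumulated effect from the previous period, so theta=0.5 on weekly data means half the effect spills into the next week. A minimal sketch of that recursion (illustrative, not code from this file):

```python
def geometric_adstock(x, theta):
    """x_decayed[t] = x[t] + theta * x_decayed[t - 1]"""
    decayed, carry = [], 0.0
    for spend in x:
        carry = spend + theta * carry
        decayed.append(carry)
    return decayed

print(geometric_adstock([100, 0, 0, 0], theta=0.5))  # [100.0, 50.0, 25.0, 12.5]
```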
@@ -622,7 +642,9 @@ def generate_adstock_rate(
             weibull_data = adstock_data["weibullCollect"]
             wb_type = adstock_data["wb_type"]

-            channels = sorted(weibull_data["channel"].unique())  # Sort channels alphabetically
+            channels = sorted(
+                weibull_data["channel"].unique()
+            )  # Sort channels alphabetically
             rows = (len(channels) + 2) // 3

             if ax is None:
