Skip to content

Commit

Permalink
fix: plot only target
Browse files Browse the repository at this point in the history
  • Loading branch information
AzulGarza committed Nov 1, 2023
1 parent 3586fa9 commit 736fe54
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions action_files/models_performance/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Experiment:
def __init__(
self,
df: pd.DataFrame,
experiment_name: str,
id_col: str,
time_col: str,
target_col: str,
Expand All @@ -38,6 +39,8 @@ def __init__(
level: Optional[List[int]] = None,
n_windows: int = 1, # @A: this should be replaced with cross validation
):
self.df = df
self.experiment_name = experiment_name
self.id_col = id_col
self.time_col = time_col
self.target_col = target_col
Expand All @@ -47,14 +50,14 @@ def __init__(
self.level = level
self.n_windows = n_windows
self.eval_index = [
"experiment_name",
"h",
"season_length",
"freq",
"level",
"n_windows",
"metric",
]
self.df = df
(
self.df_train,
self.df_test,
Expand Down Expand Up @@ -199,7 +202,7 @@ def plot_and_save_forecasts(self, cv_df: pd.DataFrame, plot_dir: str) -> str:
df[self.id_col] = "ts_0"
cv_df[self.time_col] = pd.to_datetime(cv_df[self.time_col])
fig = timegpt.plot(
df,
df[[self.id_col, self.time_col, self.target_col]],
cv_df,
max_insample_length=self.h * (self.n_windows + 4),
id_col=self.id_col,
Expand Down Expand Up @@ -258,6 +261,7 @@ def run_experiments(self):
)
exp = Experiment(
df=df,
experiment_name=experiment_name,
id_col=id_col,
time_col=time_col,
target_col=target_col,
Expand Down Expand Up @@ -292,7 +296,6 @@ def run_experiments(self):
)
eval_models_df = pd.concat(eval_models_df, axis=1)
eval_models_df["plot_path"] = plot_path
eval_models_df["experiment"] = experiment_name
eval_df.append(eval_models_df.reset_index())
eval_df = pd.concat(eval_df)
return eval_df, exp.benchmark_models
Expand All @@ -315,12 +318,13 @@ def summary_performance(
"experiment": exp_desc,
}
)
f.write(
f"## Experiment {exp_number}: {exp_metadata['experiment'].iloc[0]}\n"
)
experiment_name = exp_metadata.query("variable == 'experiment_name'")[
"experiment"
].iloc[0]
exp_metadata.query(
"variable not in ['plot_path', 'experiment']", inplace=True
)
f.write(f"## Experiment {exp_number}: {experiment_name}\n\n")
f.write("### Description:\n")
f.write(f"{exp_metadata.to_markdown(index=False)}\n\n")
f.write("### Results:\n")
Expand Down

0 comments on commit 736fe54

Please sign in to comment.