Skip to content

Commit

Permalink
fix groupbys
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Sep 10, 2023
1 parent e937250 commit 4652fd5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion skpro/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _average_df(self, df_list, weights=None):
n_df = len(df_list)
df_weighted = [df * w for df, w in zip(df_list, weights)]
df_concat = pd.concat(df_weighted, axis=1, keys=range(n_df))
df_res = df_concat.T.groupby(level=-1).sum()
df_res = df_concat.T.groupby(level=-1).T.sum()
return df_res

def pdf(self, x):
Expand Down
8 changes: 4 additions & 4 deletions skpro/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def evaluate(self, y_true, y_pred, **kwargs):
if self.score_average and multioutput == "uniform_average":
out = out.mean(axis=1).iloc[0] # average over all
if self.score_average and multioutput == "raw_values":
out = out.T.groupby(level=0).mean() # average over scores
out = out.T.groupby(level=0).T.mean() # average over scores
if not self.score_average and multioutput == "uniform_average":
out = out.T.groupby(level=1).mean() # average over variables
out = out.T.groupby(level=1).T.mean() # average over variables
if not self.score_average and multioutput == "raw_values":
out = out # don't average

Expand Down Expand Up @@ -202,9 +202,9 @@ def evaluate_by_index(self, y_true, y_pred, **kwargs):
if self.score_average and multioutput == "uniform_average":
out = out.mean(axis=1) # average over all
if self.score_average and multioutput == "raw_values":
out = out.T.groupby(level=0).mean() # average over scores
out = out.T.groupby(level=0).T.mean() # average over scores
if not self.score_average and multioutput == "uniform_average":
out = out.T.groupby(level=1).mean() # average over variables
out = out.T.groupby(level=1).T.mean() # average over variables
if not self.score_average and multioutput == "raw_values":
out = out # don't average

Expand Down

0 comments on commit 4652fd5

Please sign in to comment.