Skip to content

Commit

Permalink
Fix ensemble model predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
sivonxay committed Dec 19, 2023
1 parent 32ce4d2 commit 8c51a00
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/NanoParticleTools/machine_learning/modules/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def ensemble_forward(self, data: Data,
output.append(y_hat)

x = torch.cat(output, dim=-1)
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std()}
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std(-1)}

def evaluate_step(self, data: Data) -> tuple[torch.Tensor, torch.Tensor]:
output = []
Expand All @@ -52,6 +52,6 @@ def predict_step(

x = torch.cat(output, dim=-1)
if return_stats:
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std()}
return {'y': x, 'y_hat': x.mean(-1), 'std': x.std(-1)}
else:
return x.mean(-1)

0 comments on commit 8c51a00

Please sign in to comment.