diff --git a/src/NanoParticleTools/machine_learning/modules/ensemble.py b/src/NanoParticleTools/machine_learning/modules/ensemble.py index 3306c6f..5815b24 100644 --- a/src/NanoParticleTools/machine_learning/modules/ensemble.py +++ b/src/NanoParticleTools/machine_learning/modules/ensemble.py @@ -28,7 +28,7 @@ def ensemble_forward(self, data: Data, output.append(y_hat) x = torch.cat(output, dim=-1) - return {'y': x, 'y_hat': x.mean(-1), 'std': x.std()} + return {'y': x, 'y_hat': x.mean(-1), 'std': x.std(-1)} def evaluate_step(self, data: Data) -> tuple[torch.Tensor, torch.Tensor]: output = [] @@ -52,6 +52,6 @@ def predict_step( x = torch.cat(output, dim=-1) if return_stats: - return {'y': x, 'y_hat': x.mean(-1), 'std': x.std()} + return {'y': x, 'y_hat': x.mean(-1), 'std': x.std(-1)} else: return x.mean(-1)