From e1fd196c1808ec3a59a8e59d86163853c7175669 Mon Sep 17 00:00:00 2001 From: Theo Date: Fri, 31 May 2024 17:35:22 +0100 Subject: [PATCH] Fix runtime bug --- src/base_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/base_trainer.py b/src/base_trainer.py index 931684a..1658c0e 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -102,7 +102,7 @@ def _train_val_iteration( x, y = batch y_hat = self._model(x) losses: Dict[str, torch.Tensor] = self._training_loss(y, y_hat) - loss: torch.Tensor = torch.sum(torch.tensor([v for v in losses.values()])) + loss: torch.Tensor = sum(list(losses.values())) # type: ignore return loss, losses def _train_epoch(