diff --git a/src/base_trainer.py b/src/base_trainer.py index 931684a..1658c0e 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -102,7 +102,7 @@ def _train_val_iteration( x, y = batch y_hat = self._model(x) losses: Dict[str, torch.Tensor] = self._training_loss(y, y_hat) - loss: torch.Tensor = torch.sum(torch.tensor([v for v in losses.values()])) + loss: torch.Tensor = sum(list(losses.values())) # type: ignore return loss, losses def _train_epoch(