Skip to content

Commit

Permalink
Fix loss accumulation
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed May 31, 2024
1 parent c53c587 commit 2f551e8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def _test_iteration(
Returns:
torch.Tensor: The loss for the batch.
"""
# x, y = batch # type: ignore
# y_hat = self._model(x)
x, y = batch # type: ignore # noqa
y_hat = self._model(x) # type: ignore # noqa
# TODO: Compute your metrics here!
return {}

Expand Down
2 changes: 1 addition & 1 deletion src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _train_val_iteration(
x, y = batch
y_hat = self._model(x)
losses: Dict[str, torch.Tensor] = self._training_loss(y, y_hat)
loss = sum([v for v in losses.values()])
loss: torch.Tensor = torch.sum(torch.tensor([v for v in losses.values()]))
return loss, losses

def _train_epoch(
Expand Down
6 changes: 5 additions & 1 deletion src/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@
"""


from typing import Dict

import torch


class MSELoss:
def __init__(self, reduction: str):
self._reduction = reduction

def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
def __call__(
self, y_pred: torch.Tensor, y_true: torch.Tensor
) -> Dict[str, torch.Tensor]:
return {
"mse": torch.nn.functional.mse_loss(
y_pred, y_true, reduction=self._reduction
Expand Down

0 comments on commit 2f551e8

Please sign in to comment.