Skip to content

Commit

Permalink
Make it runnable and fix runtime errors
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed May 31, 2024
1 parent 4a545f1 commit c53c587
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
17 changes: 9 additions & 8 deletions src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def _train_val_iteration(
torch.Tensor: The loss for the batch.
Dict[str, torch.Tensor]: The loss components for the batch.
"""
# x, y = batch
# y_hat = self._model(x)
# losses = self._training_loss(x, y, y_hat)
# loss = sum([v for v in losses.values()])
# return loss, losses
raise NotImplementedError
# TODO: You'll most likely want to override this method.
x, y = batch
y_hat = self._model(x)
losses: Dict[str, torch.Tensor] = self._training_loss(y, y_hat)
loss = sum([v for v in losses.values()])
return loss, losses

def _train_epoch(
self, description: str, visualize: bool, epoch: int, last_val_loss: float
Expand Down Expand Up @@ -134,7 +134,7 @@ def _train_epoch(
break
self._opt.zero_grad()
loss, loss_components = self._train_val_iteration(
batch, epoch
batch, epoch, validation=False
) # User implementation goes here (train.py)
loss.backward()
self._opt.step()
Expand Down Expand Up @@ -193,7 +193,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
# Blink the progress bar to indicate that the validation loop is running
blink_pbar(i, self._pbar, 4)
loss, loss_components = self._train_val_iteration(
batch
batch,
epoch,
) # User implementation goes here (train.py)
val_loss.update(loss.item())
for k, v in loss_components.items():
Expand Down
6 changes: 5 additions & 1 deletion src/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ def __init__(self, reduction: str):
self._reduction = reduction

def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.mse_loss(y_pred, y_true, reduction=self._reduction)
return {
"mse": torch.nn.functional.mse_loss(
y_pred, y_true, reduction=self._reduction
)
}

0 comments on commit c53c587

Please sign in to comment.