Skip to content

Commit

Permalink
train/base: remove debug prints
Browse files Browse the repository at this point in the history
  • Loading branch information
raehik committed Dec 8, 2023
1 parent 25c834c commit 68b4151
Showing 1 changed file with 0 additions and 5 deletions.
5 changes: 0 additions & 5 deletions src/gz21_ocean_momentum/train/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,12 @@ def train_for_one_epoch(
self._locked = True
running_loss = RunningAverage()
running_loss_ = RunningAverage()
print(len(dataloader))
for i, (feature, target) in enumerate(dataloader):
print("SAMPLE: start")
# Zero the gradients
self.net.zero_grad()
print("SAMPLE: zerod grad")
# Move batch to the GPU (if possible)
feature = feature.to(self._device, dtype=torch.float)
target = target.to(self._device, dtype=torch.float)
print("SAMPLE: moved to device")
# predict with input
predict = self.net(feature)
# Compute loss
Expand All @@ -147,7 +143,6 @@ def train_for_one_epoch(
# Update the learning rate via the scheduler
if scheduler is not None:
scheduler.step()
print("end loop")
return running_loss.value

def test(self, dataloader) -> float:
Expand Down

0 comments on commit 68b4151

Please sign in to comment.