diff --git a/src/gz21_ocean_momentum/train/base.py b/src/gz21_ocean_momentum/train/base.py index 1c577787..46beb891 100755 --- a/src/gz21_ocean_momentum/train/base.py +++ b/src/gz21_ocean_momentum/train/base.py @@ -117,16 +117,12 @@ def train_for_one_epoch( self._locked = True running_loss = RunningAverage() running_loss_ = RunningAverage() - print(len(dataloader)) for i, (feature, target) in enumerate(dataloader): - print("SAMPLE: start") # Zero the gradients self.net.zero_grad() - print("SAMPLE: zerod grad") # Move batch to the GPU (if possible) feature = feature.to(self._device, dtype=torch.float) target = target.to(self._device, dtype=torch.float) - print("SAMPLE: moved to device") # predict with input predict = self.net(feature) # Compute loss @@ -147,7 +143,6 @@ def train_for_one_epoch( # Update the learning rate via the scheduler if scheduler is not None: scheduler.step() - print("end loop") return running_loss.value def test(self, dataloader) -> float: