Skip to content

Commit

Permalink
[Fix] Clear loss before run_iter
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Mar 5, 2024
1 parent b5aaa37 commit 4eb3244
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ def run(self) -> dict:
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()

# clear val loss
self.val_loss = dict()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

Expand All @@ -382,7 +385,7 @@ def run(self) -> dict:
metrics[loss_name] = avg_loss
if 'loss' in loss_name:
val_loss += avg_loss # type: ignore
if val_loss != 0:
if len(self.val_loss.keys()) != 0:
metrics['val_loss'] = val_loss

self.runner.call_hook('after_val_epoch', metrics=metrics)
Expand Down Expand Up @@ -468,6 +471,9 @@ def run(self) -> dict:
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
self.runner.model.eval()

# clear test loss
self.test_loss = dict()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

Expand All @@ -480,7 +486,7 @@ def run(self) -> dict:
metrics[loss_name] = avg_loss
if 'loss' in loss_name:
test_loss += avg_loss # type: ignore
if test_loss != 0:
if len(self.test_loss.keys()) != 0:
metrics['test_loss'] = test_loss

self.runner.call_hook('after_test_epoch', metrics=metrics)
Expand Down

0 comments on commit 4eb3244

Please sign in to comment.