diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 17beaf8d95..270819d1bd 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader from mmengine.evaluator import Evaluator -from mmengine.logging import HistoryBuffer, print_log +from mmengine.logging import print_log from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of @@ -363,7 +363,7 @@ def __init__(self, logger='current', level=logging.WARNING) self.fp16 = fp16 - self.val_loss: Dict[str, HistoryBuffer] = dict() + self.val_loss: Dict[str, list] = dict() def run(self) -> dict: """Launch validation.""" @@ -378,7 +378,7 @@ def run(self) -> dict: # get val loss and save to metrics val_loss = 0 for loss_name, loss_value in self.val_loss.items(): - avg_loss = loss_value.mean() + avg_loss = sum(loss_value) / len(loss_value) metrics[loss_name] = avg_loss if 'loss' in loss_name: val_loss += avg_loss # type: ignore @@ -408,19 +408,13 @@ def run_iter(self, idx, data_batch: Sequence[dict]): else: loss = dict() # get val loss and avoid breaking change - # similar to MessageHub for loss_name, loss_value in loss.items(): if loss_name not in self.val_loss: - self.val_loss[loss_name] = HistoryBuffer() + self.val_loss[loss_name] = [] if isinstance(loss_value, torch.Tensor): - loss_value = loss_value.mean().item() + self.val_loss[loss_name].append(loss_value.item()) elif is_list_of(loss_value, torch.Tensor): - loss_value = sum([v.mean() - for v in loss_value]).item() # type: ignore - else: - raise TypeError( - f'{loss_name} is not a tensor or list of tensors') - self.val_loss[loss_name].update(loss_value) + self.val_loss[loss_name].extend([v.item() for v in loss_value]) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( @@ -466,7 +460,7 @@ def __init__(self, logger='current', level=logging.WARNING) self.fp16 = fp16 - self.test_loss: Dict[str, HistoryBuffer] = dict() + self.test_loss: Dict[str, list] = dict() def run(self) -> dict: """Launch test.""" @@ -481,7 +475,7 @@ def run(self) -> dict: # get test loss and save to metrics test_loss = 0 for loss_name, loss_value in self.test_loss.items(): - avg_loss = loss_value.mean() + avg_loss = sum(loss_value) / len(loss_value) metrics[loss_name] = avg_loss if 'loss' in loss_name: test_loss += avg_loss # type: ignore @@ -510,19 +504,14 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: else: loss = dict() # get val loss and avoid breaking change - # similar to MessageHub for loss_name, loss_value in loss.items(): if loss_name not in self.test_loss: - self.test_loss[loss_name] = HistoryBuffer() + self.test_loss[loss_name] = [] if isinstance(loss_value, torch.Tensor): - loss_value = loss_value.mean().item() + self.test_loss[loss_name].append(loss_value.item()) elif is_list_of(loss_value, torch.Tensor): - loss_value = sum([v.mean() - for v in loss_value]).item() # type: ignore - else: - raise TypeError( - f'{loss_name} is not a tensor or list of tensors') - self.test_loss[loss_name].update(loss_value) + self.test_loss[loss_name].extend( + [v.item() for v in loss_value]) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook(