[Feature] Support calculating loss during validation #1503

Merged: 16 commits, May 17, 2024

mmengine/runner/loops.py: 82 changes (81 additions, 1 deletion)
@@ -8,8 +8,10 @@
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
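
Besides print_log, the imports now pull in HistoryBuffer (a running scalar log used below to accumulate per-iteration losses), BaseDataElement and is_list_of. A minimal sketch of the two HistoryBuffer calls this diff relies on, update() to append a scalar and mean() to average everything recorded so far:

from mmengine.logging import HistoryBuffer

buf = HistoryBuffer()
for batch_loss in (0.9, 0.7, 0.5):  # one scalar per validation iteration
    buf.update(batch_loss)          # append a value with the default count of 1
print(buf.mean())                   # average over all recorded values: 0.7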
@@ -361,17 +363,26 @@ def __init__(self,
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        self.val_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        # clear val loss
        self.val_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.val_loss:
            loss_dict = _parse_losses(self.val_loss, 'val')
            metrics.update(loss_dict)

        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics
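
With this change, the dict returned by ValLoop.run() carries the averaged loss terms alongside the evaluator metrics whenever the model reports losses during val_step() (see _update_losses() further down). A rough sketch of the result; the metric and loss names here are illustrative, not from this PR:

metrics = runner.val_loop.run()
# e.g. {'accuracy': 92.1,   # evaluator metrics, unchanged
#       'loss_cls': 0.31,   # mean of each reported loss term over the val set
#       'loss_bbox': 0.12,
#       'val_loss': 0.43}   # sum of the averaged terms whose names contain 'loss'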
@@ -389,6 +400,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)

        outputs, self.val_loss = _update_losses(outputs, self.val_loss)

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_val_iter',
@@ -433,17 +447,26 @@ def __init__(self,
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16
        self.test_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch test."""
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()

        # clear test loss
        self.test_loss.clear()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        if self.test_loss:
            loss_dict = _parse_losses(self.test_loss, 'test')
            metrics.update(loss_dict)

        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics
@@ -460,9 +483,66 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        # predictions should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.test_step(data_batch)

        outputs, self.test_loss = _update_losses(outputs, self.test_loss)

        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
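
Both loops only record losses when the model opts in: _update_losses() below expects val_step()/test_step() to append one extra BaseDataElement whose only key is 'loss', holding the raw loss dict. A hypothetical model sketch under that assumption (the class and its loss names are illustrative and not part of this PR):

from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement


class ToyDetector(BaseModel):  # forward() omitted for brevity

    def val_step(self, data):
        data = self.data_preprocessor(data, False)
        predictions = self._run_forward(data, mode='predict')  # list of data samples
        losses = self._run_forward(data, mode='loss')  # e.g. {'loss_cls': tensor(...)}
        # Append the losses in exactly the structure _update_losses() looks for.
        return predictions + [BaseDataElement(loss=losses)]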


def _parse_losses(losses: Dict[str, HistoryBuffer],
                  stage: str) -> Dict[str, float]:
    """Parses the raw losses of the network.

    Args:
        losses (dict): Raw losses of the network.
        stage (str): The stage of loss, e.g., 'val' or 'test'.

    Returns:
        dict[str, float]: The key is the loss name, and the value is the
        average loss.
    """
    all_loss = 0
    loss_dict: Dict[str, float] = dict()

    for loss_name, loss_value in losses.items():
        # Average each recorded loss term over the whole dataset.
        avg_loss = loss_value.mean()
        loss_dict[loss_name] = avg_loss
        # Terms whose name contains 'loss' also feed the overall '{stage}_loss'.
        if 'loss' in loss_name:
            all_loss += avg_loss

    loss_dict[f'{stage}_loss'] = all_loss
    return loss_dict
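
A rough usage sketch of _parse_losses(); the loss names and values are illustrative:

buf_cls = HistoryBuffer()
buf_cls.update(0.4)
buf_cls.update(0.2)   # mean 0.3
buf_bbox = HistoryBuffer()
buf_bbox.update(0.1)
buf_bbox.update(0.3)  # mean 0.2
print(_parse_losses({'loss_cls': buf_cls, 'loss_bbox': buf_bbox}, 'val'))
# {'loss_cls': 0.3, 'loss_bbox': 0.2, 'val_loss': 0.5}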


def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
    """Update and record the losses of the network.

    Args:
        outputs (list): The outputs of the network.
        losses (dict): The losses of the network.

    Returns:
        list: The updated outputs of the network.
        dict: The updated losses of the network.
    """
    # The model signals its losses by appending a BaseDataElement whose only
    # key is 'loss' to the end of ``outputs``; strip it so that only the
    # predictions reach the evaluator.
    if isinstance(outputs[-1],
                  BaseDataElement) and outputs[-1].keys() == ['loss']:
        loss = outputs[-1].loss  # type: ignore
        outputs = outputs[:-1]
    else:
        loss = dict()

    # Record every loss term in a per-name HistoryBuffer so it can later be
    # averaged over the whole dataset by ``_parse_losses``.
    for loss_name, loss_value in loss.items():
        if loss_name not in losses:
            losses[loss_name] = HistoryBuffer()
        if isinstance(loss_value, torch.Tensor):
            losses[loss_name].update(loss_value.item())
        elif is_list_of(loss_value, torch.Tensor):
            for loss_value_i in loss_value:
                losses[loss_name].update(loss_value_i.item())
    return outputs, losses
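
And a rough usage sketch of _update_losses(), again with illustrative names:

import torch

preds = [BaseDataElement(), BaseDataElement()]  # stand-ins for prediction samples
tail = BaseDataElement(loss=dict(loss_cls=torch.tensor(0.25)))
outputs, losses = _update_losses(preds + [tail], {})
assert len(outputs) == 2                              # the trailing loss element is stripped
assert abs(losses['loss_cls'].mean() - 0.25) < 1e-6   # recorded in a HistoryBuffer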