Skip to content

Commit

Permalink
[Fix] Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Feb 22, 2024
1 parent 0e7fcdc commit 850ae69
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
4 changes: 2 additions & 2 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def train_step(self, data: Union[dict, tuple, list],
optim_wrapper.update_params(parsed_losses)
return log_vars

def val_step(self, data: Union[tuple, dict, list]) -> Union[tuple, list]:
def val_step(self, data: Union[tuple, dict, list]) -> list:
"""Gets the predictions of given data.
Calls ``self.data_preprocessor(data, False)`` and
Expand All @@ -132,7 +132,7 @@ def val_step(self, data: Union[tuple, dict, list]) -> Union[tuple, list]:
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore

def test_step(self, data: Union[dict, tuple, list]) -> Union[tuple, list]:
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
Args:
Expand Down
27 changes: 15 additions & 12 deletions mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
Expand Down Expand Up @@ -399,12 +400,13 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
'before_val_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
results = self.runner.model.test_step(data_batch)
outputs, loss = list(), dict()
if isinstance(results, tuple):
outputs, loss = results
elif isinstance(results, list):
outputs, loss = results, dict()
outputs = self.runner.model.test_step(data_batch)
if isinstance(outputs[-1],
BaseDataElement) and outputs.keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
# get val loss and avoid breaking change
for loss_name, loss_value in loss.items():
if loss_name not in self.val_loss:
Expand Down Expand Up @@ -494,12 +496,13 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
results = self.runner.model.test_step(data_batch)
outputs, loss = list(), dict()
if isinstance(results, tuple):
outputs, loss = results
elif isinstance(results, list):
outputs, loss = results, dict()
outputs = self.runner.model.test_step(data_batch)
if isinstance(outputs[-1],
BaseDataElement) and outputs.keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
# get val loss and avoid breaking change
for loss_name, loss_value in loss.items():
if loss_name not in self.test_loss:
Expand Down

0 comments on commit 850ae69

Please sign in to comment.