From b8a31671a429a973c8d3e3de7270090cf34c6f0e Mon Sep 17 00:00:00 2001 From: Shu Liqiang <44634649+ShuRaymond@users.noreply.github.com> Date: Sun, 8 Oct 2023 16:43:44 +0800 Subject: [PATCH] [Feature] Runner supports setting the number of iterations for per epoch (#1292) --- docs/en/common_usage/debug_tricks.md | 50 +++++++++++++++++ docs/zh_cn/common_usage/debug_tricks.md | 50 +++++++++++++++++ mmengine/runner/_flexible_runner.py | 18 +----- mmengine/runner/runner.py | 30 +++++++++- mmengine/runner/utils.py | 18 ++++++ tests/test_runner/test_runner.py | 73 +++++++++++++++++++++++++ 6 files changed, 219 insertions(+), 20 deletions(-) diff --git a/docs/en/common_usage/debug_tricks.md b/docs/en/common_usage/debug_tricks.md index bbc63f693f..641077f260 100644 --- a/docs/en/common_usage/debug_tricks.md +++ b/docs/en/common_usage/debug_tricks.md @@ -50,6 +50,56 @@ As we can see, the number of iterations has changed to `313`. Compared to before 02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814 ``` +## Training for a fixed number of iterations (epoch-based training) + +During the process of debugging code, sometimes it is necessary to train for several epochs, such as debugging the validation process or checking whether the checkpoint saving meets expectations. However, if the dataset is too large, it may take a long time to complete one epoch. In such cases, you can configure the `num_batch_per_epoch` parameter of the dataloader. + +```{note} +The `num_batch_per_epoch` parameter is not a native parameter of PyTorch dataloaders but an additional parameter added by MMEngine to achieve this functionality. +``` + +Let's take the model defined in [5 minutes to get started with MMEngine](../get_started/15_minutes.md) as an example. By setting `num_batch_per_epoch=5` in both `train_dataloader` and `val_dataloader`, you can ensure that one epoch consists of only 5 iterations. + +```python +train_dataloader = dict( + batch_size=32, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + num_batch_per_epoch=5) +val_dataloader = dict( + batch_size=32, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate'), + num_batch_per_epoch=5) +runner = Runner( + model=MMResNet50(), + work_dir='./work_dir', + train_dataloader=train_dataloader, + optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), + train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), + val_dataloader=val_dataloader, + val_cfg=dict(), + val_evaluator=dict(type=Accuracy), + launcher=args.launcher, +) +runner.train() +``` + +As we can see, the number of iterations has been reduced to 5. Compared to the original setting, this allows you to complete one epoch more quickly. 
+ +``` +08/18 20:27:22 - mmengine - INFO - Epoch(train) [1][5/5] lr: 1.0000e-03 eta: 0:00:02 time: 0.4566 data_time: 0.0074 memory: 477 loss: 6.7576 +08/18 20:27:22 - mmengine - INFO - Saving checkpoint at 1 epochs +08/18 20:27:22 - mmengine - WARNING - `save_param_scheduler` is True but `self.param_schedulers` is None, so skip saving parameter schedulers +08/18 20:27:23 - mmengine - INFO - Epoch(val) [1][5/5] accuracy: 7.5000 data_time: 0.0044 time: 0.0146 +08/18 20:27:23 - mmengine - INFO - Exp name: 20230818_202715 +08/18 20:27:23 - mmengine - INFO - Epoch(train) [2][5/5] lr: 1.0000e-03 eta: 0:00:00 time: 0.2501 data_time: 0.0077 memory: 477 loss: 5.3044 +08/18 20:27:23 - mmengine - INFO - Saving checkpoint at 2 epochs +08/18 20:27:24 - mmengine - INFO - Epoch(val) [2][5/5] accuracy: 12.5000 data_time: 0.0058 time: 0.0175 +``` + ## Find Unused Parameters When using multiple GPUs training, if model's parameters are involved in forward computation but are not used in producing loss, the program may throw the following error: diff --git a/docs/zh_cn/common_usage/debug_tricks.md b/docs/zh_cn/common_usage/debug_tricks.md index 734e967cc6..a3f53432b6 100644 --- a/docs/zh_cn/common_usage/debug_tricks.md +++ b/docs/zh_cn/common_usage/debug_tricks.md @@ -50,6 +50,56 @@ python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py 02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814 ``` +## 固定训练的迭代次数(基于 epoch 的训练) + +在调试代码的过程中,有时需要训练几个 epoch,例如调试验证过程或者权重的保存是否符合期望。然而如果数据集太大,需要花费较长时间才能训完一个 epoch,在这种情况下,可以配置 dataloader 的 `num_batch_per_epoch` 参数。 + +```{note} +`num_batch_per_epoch` 参数不是 PyTorch 中 dataloader 的原生参数,而是 MMEngine 为了实现此功能而额外添加的参数。 +``` + +我们以[15 分钟上手 MMEngine](../get_started/15_minutes.md) 中定义的模型为例。通过在 `train_dataloader` 和 `val_dataloader` 中设置 `num_batch_per_epoch=5`,便可实现一个 epoch 只迭代 5 次。 + +```python +train_dataloader = dict( + batch_size=32, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + num_batch_per_epoch=5) +val_dataloader = dict( + batch_size=32, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate'), + num_batch_per_epoch=5) +runner = Runner( + model=MMResNet50(), + work_dir='./work_dir', + train_dataloader=train_dataloader, + optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)), + train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), + val_dataloader=val_dataloader, + val_cfg=dict(), + val_evaluator=dict(type=Accuracy), + launcher=args.launcher, +) +runner.train() +``` + +可以看到,迭代次数变成了 `5`,相比原先,这样能够更快跑完一个 epoch。 + +``` +08/18 20:27:22 - mmengine - INFO - Epoch(train) [1][5/5] lr: 1.0000e-03 eta: 0:00:02 time: 0.4566 data_time: 0.0074 memory: 477 loss: 6.7576 +08/18 20:27:22 - mmengine - INFO - Saving checkpoint at 1 epochs +08/18 20:27:22 - mmengine - WARNING - `save_param_scheduler` is True but `self.param_schedulers` is None, so skip saving parameter schedulers +08/18 20:27:23 - mmengine - INFO - Epoch(val) [1][5/5] accuracy: 7.5000 data_time: 0.0044 time: 0.0146 +08/18 20:27:23 - mmengine - INFO - Exp name: 20230818_202715 +08/18 20:27:23 - mmengine - INFO - Epoch(train) [2][5/5] lr: 1.0000e-03 eta: 0:00:00 time: 0.2501 data_time: 0.0077 memory: 477 loss: 5.3044 +08/18 20:27:23 - mmengine - INFO - Saving checkpoint at 2 epochs +08/18 20:27:24 - mmengine - INFO - Epoch(val) [2][5/5] accuracy: 12.5000 data_time: 0.0058 time: 0.0175 +``` 
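Both documentation sections above only show how `num_batch_per_epoch` is used. For reference, the snippet below is a minimal, self-contained sketch of the slicing idea this patch implements inside `Runner.build_dataloader` (via the private `_SlicedDataset` wrapper added later in this diff): the wrapper reports a truncated `__len__`, so the default sampler, and therefore one epoch, only covers `num_batch_per_epoch * batch_size * world_size` samples. `ToyDataset` and the concrete numbers here are made up for illustration and are not part of the patch.

```python
# Illustrative sketch only; not the exact MMEngine implementation.
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):  # hypothetical stand-in for a real dataset
    def __init__(self, n=1000):
        self.data = list(range(n))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class SlicedDataset:
    """Wrap a dataset but report only the first ``length`` samples."""

    def __init__(self, dataset, length):
        self._dataset = dataset
        self._length = length

    def __getattr__(self, name):  # delegate everything else to the wrapped dataset
        return getattr(self._dataset, name)

    def __getitem__(self, idx):
        return self._dataset[idx]

    def __len__(self):
        return self._length


num_batch_per_epoch, batch_size, world_size = 5, 32, 1
sliced = SlicedDataset(
    ToyDataset(), num_batch_per_epoch * batch_size * world_size)
loader = DataLoader(sliced, batch_size=batch_size)
print(len(loader))  # 5 -> one epoch is only 5 iterations
```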
+ ## 检查不参与 loss 计算的参数 使用多卡训练时,当模型的参数参与了前向计算,但没有参与 loss 的计算,程序会抛出下面的错误: diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py index 714ac611b2..8a771b0550 100644 --- a/mmengine/runner/_flexible_runner.py +++ b/mmengine/runner/_flexible_runner.py @@ -31,6 +31,7 @@ from .log_processor import LogProcessor from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority +from .utils import _get_batch_size ConfigType = Union[Dict, Config, ConfigDict] ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, @@ -1646,20 +1647,3 @@ def _log_env(self) -> None: if self.cfg._cfg_dict: self.logger.info(f'Config:\n{self.cfg.pretty_text}') - - -def _get_batch_size(dataloader): - if isinstance(dataloader, dict): - if 'batch_size' in dataloader: - return dataloader['batch_size'] - elif ('batch_sampler' in dataloader - and 'batch_size' in dataloader['batch_sampler']): - return dataloader['batch_sampler']['batch_size'] - else: - raise ValueError('Please set batch_size in `Dataloader` or ' - '`batch_sampler`') - elif isinstance(dataloader, DataLoader): - return dataloader.batch_sampler.batch_size - else: - raise ValueError('dataloader should be a dict or a Dataloader ' - f'instance, but got {type(dataloader)}') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index d66262c559..68716ab253 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -21,8 +21,8 @@ from mmengine.config import Config, ConfigDict from mmengine.dataset import worker_init_fn as default_worker_init_fn from mmengine.device import get_device -from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, - is_distributed, master_only) +from mmengine.dist import (broadcast, get_dist_info, get_rank, get_world_size, + init_dist, is_distributed, master_only) from mmengine.evaluator import Evaluator from mmengine.fileio import FileClient, join_path from mmengine.hooks import Hook @@ -49,7 +49,7 @@ from .log_processor import LogProcessor from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority -from .utils import set_random_seed +from .utils import _get_batch_size, set_random_seed ConfigType = Union[Dict, Config, ConfigDict] ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, @@ -57,6 +57,22 @@ OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] +class _SlicedDataset: + + def __init__(self, dataset, length) -> None: + self._dataset = dataset + self._length = length + + def __getattr__(self, name): + return getattr(self._dataset, name) + + def __getitem__(self, idx): + return self._dataset[idx] + + def __len__(self): + return self._length + + @RUNNERS.register_module() class Runner: """A training helper for PyTorch. 
@@ -1359,6 +1375,14 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], # if `dataset_cfg` is not a valid type dataset = dataset_cfg + num_batch_per_epoch = dataloader_cfg.pop('num_batch_per_epoch', None) + if num_batch_per_epoch is not None: + world_size = get_world_size() + num_samples = ( + num_batch_per_epoch * _get_batch_size(dataloader_cfg) * + world_size) + dataset = _SlicedDataset(dataset, num_samples) + # build sampler sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py index b5f1772db1..d7098c7295 100644 --- a/mmengine/runner/utils.py +++ b/mmengine/runner/utils.py @@ -5,6 +5,7 @@ import numpy as np import torch +from torch.utils.data import DataLoader from mmengine.dist import get_rank, sync_random_seed from mmengine.logging import print_log @@ -84,3 +85,20 @@ def set_random_seed(seed: Optional[int] = None, if digit_version(TORCH_VERSION) >= digit_version('1.10.0'): torch.use_deterministic_algorithms(True) return seed + + +def _get_batch_size(dataloader: dict): + if isinstance(dataloader, dict): + if 'batch_size' in dataloader: + return dataloader['batch_size'] + elif ('batch_sampler' in dataloader + and 'batch_size' in dataloader['batch_sampler']): + return dataloader['batch_sampler']['batch_size'] + else: + raise ValueError('Please set batch_size in `Dataloader` or ' + '`batch_sampler`') + elif isinstance(dataloader, DataLoader): + return dataloader.batch_sampler.batch_size + else: + raise ValueError('dataloader should be a dict or a Dataloader ' + f'instance, but got {type(dataloader)}') diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index c3a799871a..c8a58e9c8a 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -1301,6 +1301,17 @@ def custom_collate(data_batch): dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) + # num_batch_per_epoch is not None + cfg = dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + batch_size=3, + num_workers=2, + num_batch_per_epoch=2) + dataloader = runner.build_dataloader(cfg) + self.assertEqual(len(dataloader.dataset), 6) + def test_build_train_loop(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_build_train_loop' @@ -1812,6 +1823,18 @@ def train_step(self, *args, **kwargs): self.assertIsInstance(runner._val_loop, BaseLoop) self.assertIsInstance(runner._test_loop, dict) + # 15. 
test num_batch_per_epoch + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.experiment_name = 'test_train15' + cfg.train_dataloader['num_batch_per_epoch'] = 2 + cfg.train_cfg = dict( + by_epoch=True, + max_epochs=3, + ) + runner = Runner.from_cfg(cfg) + runner.train() + self.assertEqual(runner.iter, 3 * 2) + @skipIf( SKIP_TEST_COMPILE, reason='torch.compile is not valid, please install PyTorch>=2.0.0') @@ -1899,6 +1922,31 @@ def get_outputs_callback(module, inputs, outputs): self.assertIsInstance(runner._train_loop, dict) self.assertIsInstance(runner._test_loop, dict) + # test num_batch_per_epoch + val_result = 0 + + @HOOKS.register_module(force=True) + class TestIterHook(Hook): + + def __init__(self): + self.val_iter = 0 + + def after_val_iter(self, + runner, + batch_idx, + data_batch=None, + outputs=None): + self.val_iter += 1 + nonlocal val_result + val_result = self.val_iter + + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] + cfg.val_dataloader['num_batch_per_epoch'] = 2 + runner = Runner.from_cfg(cfg) + runner.val() + self.assertEqual(val_result, 2) + @skipIf( SKIP_TEST_COMPILE, reason='torch.compile is not valid, please install PyTorch>=2.0.0') @@ -1979,6 +2027,31 @@ def get_outputs_callback(module, inputs, outputs): self.assertIsInstance(runner._train_loop, dict) self.assertIsInstance(runner._val_loop, dict) + # test num_batch_per_epoch + test_result = 0 + + @HOOKS.register_module(force=True) + class TestIterHook(Hook): + + def __init__(self): + self.test_iter = 0 + + def after_test_iter(self, + runner, + batch_idx, + data_batch=None, + outputs=None): + self.test_iter += 1 + nonlocal test_result + test_result = self.test_iter + + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] + cfg.test_dataloader['num_batch_per_epoch'] = 2 + runner = Runner.from_cfg(cfg) + runner.test() + self.assertEqual(test_result, 2) + @skipIf( SKIP_TEST_COMPILE, reason='torch.compile is not valid, please install PyTorch>=2.0.0')
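A closing note on the numbers the new tests assert. The `build_dataloader` change earlier in this patch slices the dataset to `num_batch_per_epoch * batch_size * world_size` samples, which is where `len(dataloader.dataset) == 6` (2 batches per epoch * batch size 3 * world size 1) and `runner.iter == 3 * 2` (3 epochs * 2 iterations per epoch) come from. The helper below is a hypothetical, illustration-only restatement of that arithmetic; it is not part of the patch.

```python
# Illustration only: restate the sizing logic from build_dataloader above.
def sliced_dataset_length(dataloader_cfg: dict, world_size: int = 1) -> int:
    """num_samples = num_batch_per_epoch * batch_size * world_size."""
    batch_size = dataloader_cfg.get(
        'batch_size',
        dataloader_cfg.get('batch_sampler', {}).get('batch_size'))
    return dataloader_cfg['num_batch_per_epoch'] * batch_size * world_size


# test_build_dataloader: batch_size=3, num_batch_per_epoch=2, one process
print(sliced_dataset_length(dict(batch_size=3, num_batch_per_epoch=2)))  # 6

# test_train case 15: 2 iterations per epoch for 3 epochs -> runner.iter == 6
print(3 * 2)  # 6
```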