[CodeCamp2023-470] Runner supports setting the number of iterations for each epoch #1292

Merged: 56 commits, Oct 8, 2023
Commits
b89259c
This change adds the num_batch_per_epoch feature
ShuRaymond Aug 4, 2023
be7230b
[Feature]Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
75d8dff
[Feature] Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
d7df66a
[Feature] Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
b3f0032
[Feature] Add num_batch_per_epoch
ShuRaymond Aug 4, 2023
40c5252
add tests and edit md
ShuRaymond Aug 7, 2023
e025c53
fix bugs
ShuRaymond Aug 7, 2023
ea3db4f
fix bugs
ShuRaymond Aug 7, 2023
3a06ed8
fix bugs
ShuRaymond Aug 7, 2023
5c444c7
fix bugs
ShuRaymond Aug 7, 2023
2f40dc0
fix bugs
ShuRaymond Aug 7, 2023
8ef9a26
fix bugs
ShuRaymond Aug 7, 2023
a446035
add tests
ShuRaymond Aug 8, 2023
f9b27e4
add tests
ShuRaymond Aug 8, 2023
d598fe1
fix bugs
ShuRaymond Aug 8, 2023
7ee46ed
fix bugs
ShuRaymond Aug 8, 2023
f88b104
fix bugs
ShuRaymond Aug 8, 2023
8bb15b6
modify metrics
ShuRaymond Aug 9, 2023
e0aa78a
modify docstring
ShuRaymond Aug 9, 2023
143ab59
modify unit tests
ShuRaymond Aug 9, 2023
4ed72c9
modify unit tests
ShuRaymond Aug 9, 2023
bd035d7
modify unit tests
ShuRaymond Aug 9, 2023
05087cd
modify unit tests
ShuRaymond Aug 9, 2023
d6e220b
modify unit tests
ShuRaymond Aug 9, 2023
3311ebe
modify unit tests
ShuRaymond Aug 9, 2023
b4d45a3
rerun ci
ShuRaymond Aug 13, 2023
9ea37fc
rerun ci
ShuRaymond Aug 13, 2023
90a23bc
change method to support num_batch_per_epoch
ShuRaymond Aug 18, 2023
22345a9
delete invalid tests
ShuRaymond Aug 18, 2023
f54f1c1
delete invalid tests
ShuRaymond Aug 18, 2023
5b4b2b4
delete invalid tests
ShuRaymond Aug 18, 2023
4d3e2f0
delete invalid tests
ShuRaymond Aug 18, 2023
8ba88da
update the documentation
ShuRaymond Aug 18, 2023
96ec4d6
update the documentation
ShuRaymond Aug 19, 2023
5c57054
fix
ShuRaymond Aug 19, 2023
0050ff7
Modify the variable name
ShuRaymond Aug 23, 2023
77bcad1
solve the conflicts
ShuRaymond Aug 27, 2023
d1a5456
Merge branch 'main' into dev
ShuRaymond Aug 27, 2023
36519f6
Update debug_tricks.md
ShuRaymond Aug 27, 2023
3fa23ec
Update debug_tricks.md
ShuRaymond Aug 27, 2023
deef4ad
modify the doc and runner.py
ShuRaymond Aug 30, 2023
7e31da4
modify the doc and runner.py
ShuRaymond Aug 30, 2023
988d790
Merge remote-tracking branch 'origin/dev' into dev
ShuRaymond Aug 30, 2023
adb92ef
Merge branch 'open-mmlab:main' into dev
ShuRaymond Aug 30, 2023
f01570b
modify the doc and runner.py
ShuRaymond Aug 30, 2023
de6fd78
modify the doc and runner.py
ShuRaymond Aug 30, 2023
53bc8e0
modify the doc and runner.py
ShuRaymond Aug 30, 2023
ac6e046
Merge remote-tracking branch 'origin/dev' into dev
ShuRaymond Aug 30, 2023
2881405
modify the doc and runner.py
ShuRaymond Aug 30, 2023
4eb168f
Update debug_tricks.md
zhouzaida Sep 1, 2023
9540f21
Update distributed_training.py
zhouzaida Sep 1, 2023
15f9d85
Update debug_tricks.md
zhouzaida Sep 1, 2023
379600b
Update tests/test_runner/test_runner.py
zhouzaida Sep 1, 2023
36ee77d
Minor refine
HAOCHENYE Oct 7, 2023
9ccbb3f
Merge remote-tracking branch 'origin/main' into dev
HAOCHENYE Oct 7, 2023
686ebfb
Merge remote-tracking branch 'origin/main' into dev
HAOCHENYE Oct 8, 2023
50 changes: 50 additions & 0 deletions docs/en/common_usage/debug_tricks.md
@@ -50,6 +50,56 @@ As we can see, the number of iterations has changed to `313`. Compared to before
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Training for a fixed number of iterations (epoch-based training)

While debugging, it is sometimes necessary to train for a few full epochs, for example to debug the validation process or to check that checkpoints are saved as expected. However, if the dataset is too large, completing even one epoch may take a long time. In such cases, you can configure the `num_batch_per_epoch` parameter of the dataloader to cap the number of iterations per epoch.

```{note}
The `num_batch_per_epoch` parameter is not a native parameter of PyTorch dataloaders but an additional parameter added by MMEngine to achieve this functionality.
```
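
As a quick illustration of the note above: the key is consumed by MMEngine's `Runner.build_dataloader` before the actual `torch.utils.data.DataLoader` is constructed, so passing it straight to a native DataLoader fails. A small standalone sketch (not MMEngine code):

```python
from torch.utils.data import DataLoader

try:
    # The native PyTorch DataLoader does not know this keyword ...
    DataLoader(list(range(10)), batch_size=2, num_batch_per_epoch=5)
except TypeError as err:
    print(err)  # unexpected keyword argument 'num_batch_per_epoch'

# ... so num_batch_per_epoch only works inside an MMEngine dataloader
# config dict, where the Runner pops it before building the DataLoader.
```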

Let's take the model defined in [15 minutes to get started with MMEngine](../get_started/15_minutes.md) as an example. By setting `num_batch_per_epoch=5` in both `train_dataloader` and `val_dataloader`, you can ensure that one epoch consists of only 5 iterations.

```python
train_dataloader = dict(
    batch_size=32,
    dataset=train_set,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    num_batch_per_epoch=5)
val_dataloader = dict(
    batch_size=32,
    dataset=valid_set,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'),
    num_batch_per_epoch=5)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
)
runner.train()
```

As we can see, the number of iterations has been reduced to 5. Compared to the original setting, this allows you to complete one epoch more quickly.

```
08/18 20:27:22 - mmengine - INFO - Epoch(train) [1][5/5] lr: 1.0000e-03 eta: 0:00:02 time: 0.4566 data_time: 0.0074 memory: 477 loss: 6.7576
08/18 20:27:22 - mmengine - INFO - Saving checkpoint at 1 epochs
08/18 20:27:22 - mmengine - WARNING - `save_param_scheduler` is True but `self.param_schedulers` is None, so skip saving parameter schedulers
08/18 20:27:23 - mmengine - INFO - Epoch(val) [1][5/5] accuracy: 7.5000 data_time: 0.0044 time: 0.0146
08/18 20:27:23 - mmengine - INFO - Exp name: 20230818_202715
08/18 20:27:23 - mmengine - INFO - Epoch(train) [2][5/5] lr: 1.0000e-03 eta: 0:00:00 time: 0.2501 data_time: 0.0077 memory: 477 loss: 5.3044
08/18 20:27:23 - mmengine - INFO - Saving checkpoint at 2 epochs
08/18 20:27:24 - mmengine - INFO - Epoch(val) [2][5/5] accuracy: 12.5000 data_time: 0.0058 time: 0.0175
```
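
Under the hood (as the `build_dataloader` change to `mmengine/runner/runner.py` later in this diff shows), MMEngine pops `num_batch_per_epoch` from the dataloader config and wraps the dataset in a thin `_SlicedDataset` whose reported length becomes `num_batch_per_epoch * batch_size * world_size`, so samplers and dataloaders simply see a shorter dataset. A minimal standalone sketch of the same idea (the `SlicedDataset` name and toy data below are illustrative, not MMEngine API):

```python
from torch.utils.data import DataLoader, Dataset


class SlicedDataset(Dataset):
    """Expose only the first `length` samples of a wrapped dataset (sketch)."""

    def __init__(self, dataset, length):
        self._dataset = dataset
        self._length = length

    def __getitem__(self, idx):
        return self._dataset[idx]

    def __len__(self):
        return self._length


# Toy dataset with 50000 samples; keep only 5 batches of 32 per epoch
# (world_size is 1 in this single-process sketch).
full_dataset = list(range(50000))
sliced = SlicedDataset(full_dataset, length=5 * 32)
loader = DataLoader(sliced, batch_size=32)
print(len(loader))  # 5 -> one "epoch" now takes 5 iterations
```

Because the sliced length is also multiplied by `world_size`, each rank still gets `num_batch_per_epoch` batches per epoch when a distributed sampler splits the dataset.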

## Find Unused Parameters

When training with multiple GPUs, if some of the model's parameters are involved in the forward computation but are not used to produce the loss, the program may throw the following error:
50 changes: 50 additions & 0 deletions docs/zh_cn/common_usage/debug_tricks.md
@@ -50,6 +50,56 @@ python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Training for a fixed number of iterations (epoch-based training)

While debugging, it is sometimes necessary to train for a few epochs, for example to debug the validation process or to check whether checkpoints are saved as expected. However, if the dataset is too large, it may take a long time to finish one epoch; in that case, you can configure the dataloader's `num_batch_per_epoch` parameter.

```{note}
The `num_batch_per_epoch` parameter is not a native parameter of PyTorch dataloaders but an extra parameter added by MMEngine to implement this functionality.
```

Let's take the model defined in [15 minutes to get started with MMEngine](../get_started/15_minutes.md) as an example. By setting `num_batch_per_epoch=5` in both `train_dataloader` and `val_dataloader`, each epoch runs for only 5 iterations.

```python
train_dataloader = dict(
    batch_size=32,
    dataset=train_set,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    num_batch_per_epoch=5)
val_dataloader = dict(
    batch_size=32,
    dataset=valid_set,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'),
    num_batch_per_epoch=5)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
)
runner.train()
```

As we can see, the number of iterations has become `5`; compared with before, this finishes one epoch much more quickly.

```
08/18 20:27:22 - mmengine - INFO - Epoch(train) [1][5/5] lr: 1.0000e-03 eta: 0:00:02 time: 0.4566 data_time: 0.0074 memory: 477 loss: 6.7576
08/18 20:27:22 - mmengine - INFO - Saving checkpoint at 1 epochs
08/18 20:27:22 - mmengine - WARNING - `save_param_scheduler` is True but `self.param_schedulers` is None, so skip saving parameter schedulers
08/18 20:27:23 - mmengine - INFO - Epoch(val) [1][5/5] accuracy: 7.5000 data_time: 0.0044 time: 0.0146
08/18 20:27:23 - mmengine - INFO - Exp name: 20230818_202715
08/18 20:27:23 - mmengine - INFO - Epoch(train) [2][5/5] lr: 1.0000e-03 eta: 0:00:00 time: 0.2501 data_time: 0.0077 memory: 477 loss: 5.3044
08/18 20:27:23 - mmengine - INFO - Saving checkpoint at 2 epochs
08/18 20:27:24 - mmengine - INFO - Epoch(val) [2][5/5] accuracy: 12.5000 data_time: 0.0058 time: 0.0175
```

## Find Unused Parameters

When training with multiple GPUs, if some of the model's parameters are involved in the forward computation but are not used to compute the loss, the program may throw the following error:
18 changes: 1 addition & 17 deletions mmengine/runner/_flexible_runner.py
@@ -31,6 +31,7 @@
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
from .utils import _get_batch_size

ConfigType = Union[Dict, Config, ConfigDict]
ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
@@ -1646,20 +1647,3 @@ def _log_env(self) -> None:

        if self.cfg._cfg_dict:
            self.logger.info(f'Config:\n{self.cfg.pretty_text}')


def _get_batch_size(dataloader):
    if isinstance(dataloader, dict):
        if 'batch_size' in dataloader:
            return dataloader['batch_size']
        elif ('batch_sampler' in dataloader
              and 'batch_size' in dataloader['batch_sampler']):
            return dataloader['batch_sampler']['batch_size']
        else:
            raise ValueError('Please set batch_size in `Dataloader` or '
                             '`batch_sampler`')
    elif isinstance(dataloader, DataLoader):
        return dataloader.batch_sampler.batch_size
    else:
        raise ValueError('dataloader should be a dict or a Dataloader '
                         f'instance, but got {type(dataloader)}')
30 changes: 27 additions & 3 deletions mmengine/runner/runner.py
@@ -21,8 +21,8 @@
from mmengine.config import Config, ConfigDict
from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                           is_distributed, master_only)
from mmengine.dist import (broadcast, get_dist_info, get_rank, get_world_size,
                           init_dist, is_distributed, master_only)
from mmengine.evaluator import Evaluator
from mmengine.fileio import FileClient, join_path
from mmengine.hooks import Hook
@@ -49,14 +49,30 @@
from .log_processor import LogProcessor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
from .utils import set_random_seed
from .utils import _get_batch_size, set_random_seed

ConfigType = Union[Dict, Config, ConfigDict]
ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
                                                        List[_ParamScheduler]]]
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]


class _SlicedDataset:

    def __init__(self, dataset, length) -> None:
        self._dataset = dataset
        self._length = length

    def __getattr__(self, name):
        return getattr(self._dataset, name)

    def __getitem__(self, idx):
        return self._dataset[idx]

    def __len__(self):
        return self._length


@RUNNERS.register_module()
class Runner:
"""A training helper for PyTorch.
@@ -1359,6 +1375,14 @@ def build_dataloader(dataloader: Union[DataLoader, Dict],
            # if `dataset_cfg` is not a valid type
            dataset = dataset_cfg

        num_batch_per_epoch = dataloader_cfg.pop('num_batch_per_epoch', None)
        if num_batch_per_epoch is not None:
            world_size = get_world_size()
            num_samples = (
                num_batch_per_epoch * _get_batch_size(dataloader_cfg) *
                world_size)
            dataset = _SlicedDataset(dataset, num_samples)

        # build sampler
        sampler_cfg = dataloader_cfg.pop('sampler')
        if isinstance(sampler_cfg, dict):
18 changes: 18 additions & 0 deletions mmengine/runner/utils.py
@@ -5,6 +5,7 @@

import numpy as np
import torch
from torch.utils.data import DataLoader

from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
@@ -84,3 +85,20 @@ def set_random_seed(seed: Optional[int] = None,
        if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
            torch.use_deterministic_algorithms(True)
    return seed


def _get_batch_size(dataloader: dict):
    if isinstance(dataloader, dict):
        if 'batch_size' in dataloader:
            return dataloader['batch_size']
        elif ('batch_sampler' in dataloader
              and 'batch_size' in dataloader['batch_sampler']):
            return dataloader['batch_sampler']['batch_size']
        else:
            raise ValueError('Please set batch_size in `Dataloader` or '
                             '`batch_sampler`')
    elif isinstance(dataloader, DataLoader):
        return dataloader.batch_sampler.batch_size
    else:
        raise ValueError('dataloader should be a dict or a Dataloader '
                         f'instance, but got {type(dataloader)}')
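
For context, a small usage sketch of the helper above. It assumes the post-merge import path `mmengine.runner.utils`; note that `_get_batch_size` is a private helper, so this is purely illustrative:

```python
from torch.utils.data import DataLoader

from mmengine.runner.utils import _get_batch_size

# dict-style dataloader config: batch_size read directly ...
print(_get_batch_size(dict(batch_size=8)))                          # 8
# ... or from a nested batch_sampler config
print(_get_batch_size(dict(batch_sampler=dict(batch_size=4))))      # 4
# already-built DataLoader: read from its batch_sampler
print(_get_batch_size(DataLoader(list(range(10)), batch_size=2)))   # 2
```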
73 changes: 73 additions & 0 deletions tests/test_runner/test_runner.py
@@ -1301,6 +1301,17 @@ def custom_collate(data_batch):
        dataloader = runner.build_dataloader(cfg)
        self.assertIsInstance(dataloader.collate_fn, partial)

        # num_batch_per_epoch is not None
        cfg = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='DefaultSampler', shuffle=True),
            collate_fn=dict(type='default_collate'),
            batch_size=3,
            num_workers=2,
            num_batch_per_epoch=2)
        dataloader = runner.build_dataloader(cfg)
        self.assertEqual(len(dataloader.dataset), 6)
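        # Expected length: num_batch_per_epoch (2) * batch_size (3) * world_size
        # (1 in this single-process test) == 6 sliced samples.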

    def test_build_train_loop(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_train_loop'
@@ -1812,6 +1823,18 @@ def train_step(self, *args, **kwargs):
        self.assertIsInstance(runner._val_loop, BaseLoop)
        self.assertIsInstance(runner._test_loop, dict)

        # 15. test num_batch_per_epoch
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train15'
        cfg.train_dataloader['num_batch_per_epoch'] = 2
        cfg.train_cfg = dict(
            by_epoch=True,
            max_epochs=3,
        )
        runner = Runner.from_cfg(cfg)
        runner.train()
        self.assertEqual(runner.iter, 3 * 2)

    @skipIf(
        SKIP_TEST_COMPILE,
        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1899,6 +1922,31 @@ def get_outputs_callback(module, inputs, outputs):
        self.assertIsInstance(runner._train_loop, dict)
        self.assertIsInstance(runner._test_loop, dict)

        # test num_batch_per_epoch
        val_result = 0

        @HOOKS.register_module(force=True)
        class TestIterHook(Hook):

            def __init__(self):
                self.val_iter = 0

            def after_val_iter(self,
                               runner,
                               batch_idx,
                               data_batch=None,
                               outputs=None):
                self.val_iter += 1
                nonlocal val_result
                val_result = self.val_iter

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
        cfg.val_dataloader['num_batch_per_epoch'] = 2
        runner = Runner.from_cfg(cfg)
        runner.val()
        self.assertEqual(val_result, 2)

    @skipIf(
        SKIP_TEST_COMPILE,
        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1979,6 +2027,31 @@ def get_outputs_callback(module, inputs, outputs):
        self.assertIsInstance(runner._train_loop, dict)
        self.assertIsInstance(runner._val_loop, dict)

        # test num_batch_per_epoch
        test_result = 0

        @HOOKS.register_module(force=True)
        class TestIterHook(Hook):

            def __init__(self):
                self.test_iter = 0

            def after_test_iter(self,
                                runner,
                                batch_idx,
                                data_batch=None,
                                outputs=None):
                self.test_iter += 1
                nonlocal test_result
                test_result = self.test_iter

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
        cfg.test_dataloader['num_batch_per_epoch'] = 2
        runner = Runner.from_cfg(cfg)
        runner.test()
        self.assertEqual(test_result, 2)

    @skipIf(
        SKIP_TEST_COMPILE,
        reason='torch.compile is not valid, please install PyTorch>=2.0.0')