diff --git a/examples/distributed_training.py b/examples/distributed_training.py index 3c702eb83e..7030a9f4d8 100644 --- a/examples/distributed_training.py +++ b/examples/distributed_training.py @@ -77,14 +77,12 @@ def main(): batch_size=32, dataset=train_set, sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate'), - num_batch_per_epoch=5) + collate_fn=dict(type='default_collate')) val_dataloader = dict( batch_size=32, dataset=valid_set, sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate'), - num_batch_per_epoch=5) + collate_fn=dict(type='default_collate')) runner = Runner( model=MMResNet50(), work_dir='./work_dir',