diff --git a/mmcls/datasets/builder.py b/mmcls/datasets/builder.py index 1b626b4a12..9db8a0a003 100644 --- a/mmcls/datasets/builder.py +++ b/mmcls/datasets/builder.py @@ -51,11 +51,39 @@ def build_dataset(cfg, default_args=None): cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args) cp_cfg.pop('type') dataset = KFoldDataset(**cp_cfg) + elif isinstance(cfg.get('ann_file'), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) else: dataset = build_from_cfg(cfg, DATASETS, default_args) return dataset +def _concat_dataset(cfg, default_args=None): + from .dataset_wrappers import ConcatDataset + ann_files = cfg['ann_file'] + img_prefixes = cfg.get('img_prefix', None) + seg_prefixes = cfg.get('seg_prefix', None) + proposal_files = cfg.get('proposal_file', None) + separate_eval = cfg.get('separate_eval', True) + + datasets = [] + num_dset = len(ann_files) + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + # pop 'separate_eval' since it is not a valid key for common datasets. + if 'separate_eval' in data_cfg: + data_cfg.pop('separate_eval') + data_cfg['ann_file'] = ann_files[i] + if isinstance(img_prefixes, (list, tuple)): + data_cfg['img_prefix'] = img_prefixes[i] + if isinstance(seg_prefixes, (list, tuple)): + data_cfg['seg_prefix'] = seg_prefixes[i] + if isinstance(proposal_files, (list, tuple)): + data_cfg['proposal_file'] = proposal_files[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets, separate_eval) + def build_dataloader(dataset, samples_per_gpu,