Skip to content

Commit

Permalink
Update build_dataset function
Browse files Browse the repository at this point in the history
  • Loading branch information
ccanamero committed Aug 28, 2024
1 parent 3c59c2b commit 24e970e
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions mmcls/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,39 @@ def build_dataset(cfg, default_args=None):
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args)
cp_cfg.pop('type')
dataset = KFoldDataset(**cp_cfg)
elif isinstance(cfg.get('ann_file'), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)

return dataset

def _concat_dataset(cfg, default_args=None):
from .dataset_wrappers import ConcatDataset
ann_files = cfg['ann_file']
img_prefixes = cfg.get('img_prefix', None)
seg_prefixes = cfg.get('seg_prefix', None)
proposal_files = cfg.get('proposal_file', None)
separate_eval = cfg.get('separate_eval', True)

datasets = []
num_dset = len(ann_files)
for i in range(num_dset):
data_cfg = copy.deepcopy(cfg)
# pop 'separate_eval' since it is not a valid key for common datasets.
if 'separate_eval' in data_cfg:
data_cfg.pop('separate_eval')
data_cfg['ann_file'] = ann_files[i]
if isinstance(img_prefixes, (list, tuple)):
data_cfg['img_prefix'] = img_prefixes[i]
if isinstance(seg_prefixes, (list, tuple)):
data_cfg['seg_prefix'] = seg_prefixes[i]
if isinstance(proposal_files, (list, tuple)):
data_cfg['proposal_file'] = proposal_files[i]
datasets.append(build_dataset(data_cfg, default_args))

return ConcatDataset(datasets, separate_eval)


def build_dataloader(dataset,
samples_per_gpu,
Expand Down

0 comments on commit 24e970e

Please sign in to comment.