forked from xinntao/EDVR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
__init__.py
126 lines (110 loc) · 4.52 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
__all__ = ['create_dataset', 'create_dataloader']

# Dataset modules are discovered automatically: every file under this
# package's folder whose name ends with '_dataset.py' is treated as a
# dataset module and referred to by its base name without the extension.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(entry))[0]
    for entry in scandir(data_folder)
    if entry.endswith('_dataset.py')
]

# Import each discovered module once at package import time so that
# create_dataset() can look dataset classes up by name.
_dataset_modules = [
    importlib.import_module(f'basicsr.data.{file_name}')
    for file_name in dataset_filenames
]
def create_dataset(dataset_opt):
    """Create a dataset instance from its configuration.

    The dataset class is looked up by name among the modules collected in
    ``_dataset_modules``; the first module exposing an attribute named
    ``dataset_opt['type']`` wins.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name (used in the log message).
            type (str): Dataset type, i.e. the name of the dataset class.

    Returns:
        The object produced by calling the dataset class with
        ``dataset_opt``.

    Raises:
        ValueError: If no imported dataset module defines the requested
            type (including the case where no dataset module was found).
    """
    dataset_type = dataset_opt['type']

    # dynamic instantiation
    # Initialise to None so that an empty module list falls through to the
    # ValueError below instead of raising UnboundLocalError.
    dataset_cls = None
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset
def create_dataloader(dataset,
                      dataset_opt,
                      num_gpu=1,
                      dist=False,
                      sampler=None,
                      seed=None):
    """Build a dataloader for the given dataset and phase.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. Keys read here:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU
                (train phase only).
            batch_size_per_gpu (int): Training batch size for each GPU
                (train phase only).
            pin_memory (bool, optional): Forwarded to the dataloader.
                Default: False.
            prefetch_mode (str, optional): 'cpu' selects
                PrefetchDataLoader; any other value (including None and
                'cuda') yields a plain torch DataLoader.
            num_prefetch_queue (int, optional): Queue size used with
                prefetch_mode 'cpu'. Default: 1.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Returns:
        PrefetchDataLoader when prefetch_mode is 'cpu', otherwise
        torch.utils.data.DataLoader.

    Raises:
        ValueError: If the phase is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase not in ('train', 'val', 'test'):
        raise ValueError(f'Wrong dataset phase: {phase}. '
                         "Supported ones are 'train', 'val' and 'test'.")

    if phase == 'train':
        if not dist:
            # non-distributed training: the per-GPU settings are scaled by
            # the GPU count (a count of 0 is treated as 1)
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        else:
            # distributed training: per-GPU settings are used unchanged
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']

        # Workers are only seeded explicitly when a seed was provided.
        init_fn = None
        if seed is not None:
            init_fn = partial(
                worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

        dataloader_args = {
            'dataset': dataset,
            'batch_size': batch_size,
            # shuffle only when no sampler is supplied
            'shuffle': sampler is None,
            'num_workers': num_workers,
            'sampler': sampler,
            'drop_last': True,
            'worker_init_fn': init_fn,
        }
    else:
        # validation / test: one sample at a time, in order, loaded in the
        # main process
        dataloader_args = {
            'dataset': dataset,
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 0,
        }

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)

    # CPUPrefetcher
    num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
    logger = get_root_logger()
    logger.info(f'Use {prefetch_mode} prefetch dataloader: '
                f'num_prefetch_queue = {num_prefetch_queue}')
    return PrefetchDataLoader(
        num_prefetch_queue=num_prefetch_queue, **dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy's and Python's global RNGs for a dataloader worker.

    The seed used is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): Index of the worker.
        num_workers (int): Number of workers; multiplied by ``rank``.
        rank (int): Rank value; multiplied by ``num_workers``.
        seed (int): Base seed.
    """
    derived_seed = num_workers * rank + worker_id + seed
    # Same order as before: NumPy first, then the stdlib generator.
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)