-torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
- int act, int grad, float alpha, float scale);
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
- int act, int grad, float alpha, float scale) {
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
CHECK_CUDA(input);
CHECK_CUDA(bias);
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
index 66a98b6..92bbf4b 100644
--- a/basicsr/models/sr_model.py
+++ b/basicsr/models/sr_model.py
@@ -1,13 +1,13 @@
import importlib
-import mmcv
import torch
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
+from tqdm import tqdm
from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
-from basicsr.utils import ProgressBar, get_root_logger, tensor2img
+from basicsr.utils import get_root_logger, imwrite, tensor2img
loss_module = importlib.import_module('basicsr.models.losses')
metric_module = importlib.import_module('basicsr.metrics')
@@ -25,10 +25,10 @@ def __init__(self, opt):
self.print_network(self.net_g)
# load pretrained models
- load_path = self.opt['path'].get('pretrain_model_g', None)
+ load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g, load_path,
- self.opt['path']['strict_load'])
+ self.opt['path'].get('strict_load_g', True))
if self.is_train:
self.init_training_settings()
@@ -131,7 +131,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger,
metric: 0
for metric in self.opt['val']['metrics'].keys()
}
- pbar = ProgressBar(len(dataloader))
+ pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
@@ -163,7 +163,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger,
save_img_path = osp.join(
self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
- mmcv.imwrite(sr_img, save_img_path)
+ imwrite(sr_img, save_img_path)
if with_metrics:
# calculate metrics
@@ -172,7 +172,9 @@ def nondist_validation(self, dataloader, current_iter, tb_logger,
metric_type = opt_.pop('type')
self.metric_results[name] += getattr(
metric_module, metric_type)(sr_img, gt_img, **opt_)
- pbar.update(f'Test {img_name}')
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
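
Note: the ProgressBar-to-tqdm change above splits the old single update call into a counter (update) and a status text (set_description). A minimal sketch of the same pattern, assuming a dataloader with a length:

from tqdm import tqdm

def validate(dataloader):
    # mirrors the pattern used in nondist_validation above
    pbar = tqdm(total=len(dataloader), unit='image')
    for idx, val_data in enumerate(dataloader):
        img_name = f'img_{idx:03d}'
        # ... run the model and compute metrics here ...
        pbar.update(1)                            # advance the bar by one image
        pbar.set_description(f'Test {img_name}')  # show the current image name
    pbar.close()                                  # close once the loop is done
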
diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py
index d927773..7d08d7b 100644
--- a/basicsr/models/srgan_model.py
+++ b/basicsr/models/srgan_model.py
@@ -21,10 +21,10 @@ def init_training_settings(self):
self.print_network(self.net_d)
# load pretrained models
- load_path = self.opt['path'].get('pretrain_model_d', None)
+ load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
self.load_network(self.net_d, load_path,
- self.opt['path']['strict_load'])
+ self.opt['path'].get('strict_load_d', True))
self.net_g.train()
self.net_d.train()
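
Note: the same renaming (pretrain_model_* to pretrain_network_*, plus per-network strict_load_g/strict_load_d flags defaulting to True) is applied across the models. A small sketch of the lookup pattern, using a hypothetical option dict:

# hypothetical option dict after the key renaming
opt = {'path': {'pretrain_network_d': 'experiments/pretrained/net_d.pth'}}

load_path = opt['path'].get('pretrain_network_d', None)  # None when no pretrained model is given
strict = opt['path'].get('strict_load_d', True)          # strict loading unless disabled per network
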
diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py
index 7cf7aec..c1ac6cf 100644
--- a/basicsr/models/stylegan2_model.py
+++ b/basicsr/models/stylegan2_model.py
@@ -1,6 +1,6 @@
+import cv2
import importlib
import math
-import mmcv
import numpy as np
import random
import torch
@@ -11,7 +11,7 @@
from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.models.losses.losses import g_path_regularize, r1_penalty
-from basicsr.utils import tensor2img
+from basicsr.utils import imwrite, tensor2img
loss_module = importlib.import_module('basicsr.models.losses')
@@ -27,11 +27,12 @@ def __init__(self, opt):
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained model
- load_path = self.opt['path'].get('pretrain_model_g', None)
+ load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path,
- self.opt['path']['strict_load'], param_key)
+ self.opt['path'].get('strict_load_g',
+ True), param_key)
# latent dimension: self.num_style_feat
self.num_style_feat = opt['network_g']['num_style_feat']
@@ -51,10 +52,10 @@ def init_training_settings(self):
self.print_network(self.net_d)
# load pretrained model
- load_path = self.opt['path'].get('pretrain_model_d', None)
+ load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
self.load_network(self.net_d, load_path,
- self.opt['path']['strict_load'])
+ self.opt['path'].get('strict_load_d', True))
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema only used for testing on one GPU and saving, do not need to
@@ -62,10 +63,11 @@ def init_training_settings(self):
self.net_g_ema = define_network(deepcopy(self.opt['network_g'])).to(
self.device)
# load pretrained model
- load_path = self.opt['path'].get('pretrain_model_g', None)
+ load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path,
- self.opt['path']['strict_load'], 'params_ema')
+ self.opt['path'].get('strict_load_g',
+ True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
@@ -311,10 +313,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger,
else:
save_img_path = osp.join(self.opt['path']['visualization'], 'test',
f'test_{self.opt["name"]}.png')
- mmcv.imwrite(result, save_img_path)
+ imwrite(result, save_img_path)
# add sample images to tb_logger
result = (result / 255.).astype(np.float32)
- result = mmcv.bgr2rgb(result)
+ result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
if tb_logger is not None:
tb_logger.add_image(
'samples', result, global_step=current_iter, dataformats='HWC')
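
Note: mmcv.bgr2rgb is replaced by a direct cv2.cvtColor call. A minimal sketch of the equivalence, assuming a float32 HWC image in BGR order like the sample above:

import cv2
import numpy as np

result = np.random.rand(64, 64, 3).astype(np.float32)  # stand-in for the sample image above
# COLOR_BGR2RGB swaps the first and last channels, which is what mmcv.bgr2rgb did
result_rgb = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
assert np.allclose(result_rgb[..., 0], result[..., 2])
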
diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py
index 6e70eed..c8e8d26 100644
--- a/basicsr/models/video_base_model.py
+++ b/basicsr/models/video_base_model.py
@@ -1,14 +1,14 @@
import importlib
-import mmcv
import torch
from collections import Counter
from copy import deepcopy
-from mmcv.runner import get_dist_info
from os import path as osp
from torch import distributed as dist
+from tqdm import tqdm
from basicsr.models.sr_model import SRModel
-from basicsr.utils import ProgressBar, get_root_logger, tensor2img
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.dist_util import get_dist_info
metric_module = importlib.import_module('basicsr.metrics')
@@ -34,13 +34,13 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
len(self.opt['val']['metrics']),
dtype=torch.float32,
device='cuda')
-
rank, world_size = get_dist_info()
- for _, tensor in self.metric_results.items():
- tensor.zero_()
+ if with_metrics:
+ for _, tensor in self.metric_results.items():
+ tensor.zero_()
# record all frames (border and center frames)
if rank == 0:
- pbar = ProgressBar(len(dataset))
+ pbar = tqdm(total=len(dataset), unit='frame')
for idx in range(rank, len(dataset), world_size):
val_data = dataset[idx]
val_data['lq'].unsqueeze_(0)
@@ -83,7 +83,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
save_img_path = osp.join(
self.opt['path']['visualization'], dataset_name,
folder, f'{img_name}_{self.opt["name"]}.png')
- mmcv.imwrite(result_img, save_img_path)
+ imwrite(result_img, save_img_path)
if with_metrics:
# calculate metrics
@@ -98,8 +98,12 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# progress bar
if rank == 0:
for _ in range(world_size):
- pbar.update(f'Test {folder} - '
- f'{int(frame_idx) + world_size}/{max_idx}')
+ pbar.update(1)
+ pbar.set_description(
+ f'Test {folder}:'
+ f'{int(frame_idx) + world_size}/{max_idx}')
+ if rank == 0:
+ pbar.close()
if with_metrics:
if self.opt['dist']:
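
Note: dist_validation keeps the rank-strided partitioning (each process handles indices rank, rank + world_size, rank + 2 * world_size, ...), and only rank 0 owns the tqdm bar. A small sketch of that partitioning with hypothetical sizes:

def indices_for_rank(num_samples, rank, world_size):
    """Sample indices handled by one process in the strided scheme above."""
    return list(range(rank, num_samples, world_size))

# e.g. 10 frames over 4 GPUs
print(indices_for_rank(10, rank=0, world_size=4))  # [0, 4, 8]
print(indices_for_rank(10, rank=1, world_size=4))  # [1, 5, 9]
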
diff --git a/basicsr/models/video_gan_model.py b/basicsr/models/video_gan_model.py
index 94ccf4b..290434b 100644
--- a/basicsr/models/video_gan_model.py
+++ b/basicsr/models/video_gan_model.py
@@ -1,142 +1,15 @@
-import importlib
-import torch
-from collections import OrderedDict
-from copy import deepcopy
-
-from basicsr.models.archs import define_network
+from basicsr.models.srgan_model import SRGANModel
from basicsr.models.video_base_model import VideoBaseModel
-loss_module = importlib.import_module('basicsr.models.losses')
-
-
-class VideoGANModel(VideoBaseModel):
- """Video GAN model."""
-
- def init_training_settings(self):
- train_opt = self.opt['train']
-
- # define network net_d
- self.net_d = define_network(deepcopy(self.opt['network_d']))
- self.net_d = self.model_to_device(self.net_d)
- self.print_network(self.net_d)
-
- # load pretrained models
- load_path = self.opt['path'].get('pretrain_model_d', None)
- if load_path is not None:
- self.load_network(self.net_d, load_path,
- self.opt['path']['strict_load'])
-
- self.net_g.train()
- self.net_d.train()
-
- # define losses
- if train_opt.get('pixel_opt'):
- pixel_type = train_opt['pixel_opt'].pop('type')
- cri_pix_cls = getattr(loss_module, pixel_type)
- self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
- self.device)
- else:
- self.cri_pix = None
-
- if train_opt.get('perceptual_opt'):
- percep_type = train_opt['perceptual_opt'].pop('type')
- cri_perceptual_cls = getattr(loss_module, percep_type)
- self.cri_perceptual = cri_perceptual_cls(
- **train_opt['perceptual_opt']).to(self.device)
- else:
- self.cri_perceptual = None
-
- if train_opt.get('gan_opt'):
- gan_type = train_opt['gan_opt'].pop('type')
- cri_gan_cls = getattr(loss_module, gan_type)
- self.cri_gan = cri_gan_cls(**train_opt['gan_opt']).to(self.device)
-
- self.net_d_iters = train_opt.get('net_d_iters', 1)
- self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
-
- # set up optimizers and schedulers
- self.setup_optimizers()
- self.setup_schedulers()
-
- def setup_optimizers(self):
- train_opt = self.opt['train']
- # optimizer g
- optim_type = train_opt['optim_g'].pop('type')
- if optim_type == 'Adam':
- self.optimizer_g = torch.optim.Adam(self.net_g.parameters(),
- **train_opt['optim_g'])
- else:
- raise NotImplementedError(
- f'optimizer {optim_type} is not supperted yet.')
- self.optimizers.append(self.optimizer_g)
- # optimizer d
- optim_type = train_opt['optim_d'].pop('type')
- if optim_type == 'Adam':
- self.optimizer_d = torch.optim.Adam(self.net_d.parameters(),
- **train_opt['optim_d'])
- else:
- raise NotImplementedError(
- f'optimizer {optim_type} is not supperted yet.')
- self.optimizers.append(self.optimizer_d)
-
- def optimize_parameters(self, current_iter):
- # optimize net_g
- for p in self.net_d.parameters():
- p.requires_grad = False
-
- self.optimizer_g.zero_grad()
- self.output = self.net_g(self.lq)
-
- l_g_total = 0
- loss_dict = OrderedDict()
- if (current_iter % self.net_d_iters == 0
- and current_iter > self.net_d_init_iters):
- # pixel loss
- if self.cri_pix:
- l_g_pix = self.cri_pix(self.output, self.gt)
- l_g_total += l_g_pix
- loss_dict['l_g_pix'] = l_g_pix
- # perceptual loss
- if self.cri_perceptual:
- l_g_percep, l_g_style = self.cri_perceptual(
- self.output, self.gt)
- if l_g_percep is not None:
- l_g_total += l_g_percep
- loss_dict['l_g_percep'] = l_g_percep
- if l_g_style is not None:
- l_g_total += l_g_style
- loss_dict['l_g_style'] = l_g_style
- # gan loss
- fake_g_pred = self.net_d(self.output)
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
- l_g_total += l_g_gan
- loss_dict['l_g_gan'] = l_g_gan
-
- l_g_total.backward()
- self.optimizer_g.step()
-
- # optimize net_d
- for p in self.net_d.parameters():
- p.requires_grad = True
-
- self.optimizer_d.zero_grad()
- # real
- real_d_pred = self.net_d(self.gt)
- l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
- loss_dict['l_d_real'] = l_d_real
- loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
- l_d_real.backward()
- # fake
- fake_d_pred = self.net_d(self.output.detach())
- l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
- loss_dict['l_d_fake'] = l_d_fake
- loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
- l_d_fake.backward()
- self.optimizer_d.step()
- self.log_dict = self.reduce_loss_dict(loss_dict)
+class VideoGANModel(SRGANModel, VideoBaseModel):
+ """Video GAN model.
- def save(self, epoch, current_iter):
- self.save_network(self.net_g, 'net_g', current_iter)
- self.save_network(self.net_d, 'net_d', current_iter)
- self.save_training_state(epoch, current_iter)
+ Use multiple inheritance.
+ It will first use the functions of SRGANModel:
+ init_training_settings
+ setup_optimizers
+ optimize_parameters
+ save
+    Then it falls back to the functions of VideoBaseModel.
+ """
diff --git a/basicsr/test.py b/basicsr/test.py
index 7bdae15..622df4e 100644
--- a/basicsr/test.py
+++ b/basicsr/test.py
@@ -1,46 +1,23 @@
-import argparse
import logging
-import random
import torch
-from mmcv.runner import get_dist_info, get_time_str, init_dist
from os import path as osp
from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
-from basicsr.utils import (get_env_info, get_root_logger, make_exp_dirs,
- set_random_seed)
-from basicsr.utils.options import dict2str, parse
+from basicsr.train import parse_options
+from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
+ make_exp_dirs)
+from basicsr.utils.options import dict2str
def main():
- # options
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '-opt', type=str, required=True, help='Path to option YAML file.')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', type=int, default=0)
- args = parser.parse_args()
- opt = parse(args.opt, is_train=False)
+    # parse options, set distributed settings, set random seed
+ opt = parse_options(is_train=False)
- # distributed testing settings
- if args.launcher == 'none': # non-distributed testing
- opt['dist'] = False
- print('Disable distributed testing.', flush=True)
- else:
- opt['dist'] = True
- if args.launcher == 'slurm' and 'dist_params' in opt:
- init_dist(args.launcher, **opt['dist_params'])
- else:
- init_dist(args.launcher)
-
- rank, world_size = get_dist_info()
- opt['rank'] = rank
- opt['world_size'] = world_size
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+ # mkdir and initialize loggers
make_exp_dirs(opt)
log_file = osp.join(opt['path']['log'],
f"test_{opt['name']}_{get_time_str()}.log")
@@ -49,17 +26,6 @@ def main():
logger.info(get_env_info())
logger.info(dict2str(opt))
- # random seed
- seed = opt['manual_seed']
- if seed is None:
- seed = random.randint(1, 10000)
- opt['manual_seed'] = seed
- logger.info(f'Random seed: {seed}')
- set_random_seed(seed + rank)
-
- torch.backends.cudnn.benchmark = True
- # torch.backends.cudnn.deterministic = True
-
# create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
@@ -70,7 +36,7 @@ def main():
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=None,
- seed=seed)
+ seed=opt['manual_seed'])
logger.info(
f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
test_loaders.append(test_loader)
diff --git a/basicsr/train.py b/basicsr/train.py
index 0d769c8..02a460f 100644
--- a/basicsr/train.py
+++ b/basicsr/train.py
@@ -5,7 +5,6 @@
import random
import time
import torch
-from mmcv.runner import get_dist_info, get_time_str, init_dist
from os import path as osp
from basicsr.data import create_dataloader, create_dataset
@@ -13,13 +12,14 @@
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import create_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info,
- get_root_logger, init_tb_logger, init_wandb_logger,
- make_exp_dirs, mkdir_and_rename, set_random_seed)
+ get_root_logger, get_time_str, init_tb_logger,
+ init_wandb_logger, make_exp_dirs, mkdir_and_rename,
+ set_random_seed)
+from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse
-def main():
- # options
+def parse_options(is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument(
'-opt', type=str, required=True, help='Path to option YAML file.')
@@ -30,12 +30,12 @@ def main():
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
- opt = parse(args.opt, is_train=True)
+ opt = parse(args.opt, is_train=is_train)
- # distributed training settings
- if args.launcher == 'none': # non-distributed training
+ # distributed settings
+ if args.launcher == 'none':
opt['dist'] = False
- print('Disable distributed training.', flush=True)
+ print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
@@ -43,68 +43,55 @@ def main():
else:
init_dist(args.launcher)
- rank, world_size = get_dist_info()
- opt['rank'] = rank
- opt['world_size'] = world_size
+ opt['rank'], opt['world_size'] = get_dist_info()
- # load resume states if exists
- if opt['path'].get('resume_state'):
- device_id = torch.cuda.current_device()
- resume_state = torch.load(
- opt['path']['resume_state'],
- map_location=lambda storage, loc: storage.cuda(device_id))
- else:
- resume_state = None
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
- # mkdir and loggers
- if resume_state is None:
- make_exp_dirs(opt)
+ return opt
+
+
+def init_loggers(opt):
log_file = osp.join(opt['path']['log'],
f"train_{opt['name']}_{get_time_str()}.log")
logger = get_root_logger(
logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info())
logger.info(dict2str(opt))
+
# initialize tensorboard logger and wandb logger
tb_logger = None
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
- log_dir = './tb_logger/' + opt['name']
- if resume_state is None and opt['rank'] == 0:
- mkdir_and_rename(log_dir)
- tb_logger = init_tb_logger(log_dir=log_dir)
+ tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
if (opt['logger'].get('wandb')
is not None) and (opt['logger']['wandb'].get('project')
is not None) and ('debug' not in opt['name']):
assert opt['logger'].get('use_tb_logger') is True, (
'should turn on tensorboard when using wandb')
init_wandb_logger(opt)
+ return logger, tb_logger
- # random seed
- seed = opt['manual_seed']
- if seed is None:
- seed = random.randint(1, 10000)
- opt['manual_seed'] = seed
- logger.info(f'Random seed: {seed}')
- set_random_seed(seed + rank)
-
- torch.backends.cudnn.benchmark = True
- # torch.backends.cudnn.deterministic = True
+def create_train_val_dataloader(opt, logger):
# create train and val dataloaders
train_loader, val_loader = None, None
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
train_set = create_dataset(dataset_opt)
- train_sampler = EnlargedSampler(train_set, world_size, rank,
- dataset_enlarge_ratio)
+ train_sampler = EnlargedSampler(train_set, opt['world_size'],
+ opt['rank'], dataset_enlarge_ratio)
train_loader = create_dataloader(
train_set,
dataset_opt,
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=train_sampler,
- seed=seed)
+ seed=opt['manual_seed'])
num_iter_per_epoch = math.ceil(
len(train_set) * dataset_enlarge_ratio /
@@ -119,6 +106,7 @@ def main():
f'\n\tWorld size (gpu number): {opt["world_size"]}'
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+
elif phase == 'val':
val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(
@@ -127,27 +115,57 @@ def main():
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=None,
- seed=seed)
+ seed=opt['manual_seed'])
logger.info(
f'Number of val images/folders in {dataset_opt["name"]}: '
f'{len(val_set)}')
else:
raise ValueError(f'Dataset phase {phase} is not recognized.')
- assert train_loader is not None
- # create model
- if resume_state:
- check_resume(opt, resume_state['iter']) # modify pretrain_model paths
- model = create_model(opt)
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
+
+
+def main():
+    # parse options, set distributed settings, set random seed
+ opt = parse_options(is_train=True)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # load resume states if necessary
+ if opt['path'].get('resume_state'):
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(
+ opt['path']['resume_state'],
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ resume_state = None
+
+ # mkdir for experiments and logger
+ if resume_state is None:
+ make_exp_dirs(opt)
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
+ 'name'] and opt['rank'] == 0:
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
+
+ # initialize loggers
+ logger, tb_logger = init_loggers(opt)
+
+ # create train and validation dataloaders
+ result = create_train_val_dataloader(opt, logger)
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
- # resume training
- if resume_state:
+ # create model
+ if resume_state: # resume training
+ check_resume(opt, resume_state['iter'])
+ model = create_model(opt)
+ model.resume_training(resume_state) # handle optimizers and schedulers
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
f"iter: {resume_state['iter']}.")
start_epoch = resume_state['epoch']
current_iter = resume_state['iter']
- model.resume_training(resume_state) # handle optimizers and schedulers
else:
+ model = create_model(opt)
start_epoch = 0
current_iter = 0
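
Note: main() is now assembled from parse_options, init_loggers and create_train_val_dataloader. A compressed sketch of the resulting control flow, with the training loop omitted (names as defined in the diff above):

import torch

from basicsr.models import create_model
from basicsr.train import parse_options, init_loggers, create_train_val_dataloader
from basicsr.utils import check_resume, make_exp_dirs

def main_sketch():
    opt = parse_options(is_train=True)            # options + dist init + random seed
    torch.backends.cudnn.benchmark = True

    resume_path = opt['path'].get('resume_state')
    resume_state = torch.load(resume_path) if resume_path else None

    if resume_state is None:
        make_exp_dirs(opt)                        # fresh run: create experiment dirs
    logger, tb_logger = init_loggers(opt)

    (train_loader, train_sampler, val_loader,
     total_epochs, total_iters) = create_train_val_dataloader(opt, logger)

    if resume_state:                              # resume training
        check_resume(opt, resume_state['iter'])
        model = create_model(opt)
        model.resume_training(resume_state)       # restores optimizers and schedulers
        start_epoch, current_iter = resume_state['epoch'], resume_state['iter']
    else:
        model = create_model(opt)
        start_epoch, current_iter = 0, 0
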
diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py
index 95f7a50..2b91571 100644
--- a/basicsr/utils/__init__.py
+++ b/basicsr/utils/__init__.py
@@ -1,12 +1,31 @@
from .file_client import FileClient
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import (MessageLogger, get_env_info, get_root_logger,
init_tb_logger, init_wandb_logger)
-from .util import (ProgressBar, check_resume, crop_border, make_exp_dirs,
- mkdir_and_rename, set_random_seed, tensor2img)
+from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
+ scandir, set_random_seed, sizeof_fmt)
__all__ = [
- 'FileClient', 'MessageLogger', 'get_root_logger', 'make_exp_dirs',
- 'init_tb_logger', 'init_wandb_logger', 'set_random_seed', 'ProgressBar',
- 'tensor2img', 'crop_border', 'check_resume', 'mkdir_and_rename',
- 'get_env_info'
+ # file_client.py
+ 'FileClient',
+ # img_util.py
+ 'img2tensor',
+ 'tensor2img',
+ 'imfrombytes',
+ 'imwrite',
+ 'crop_border',
+ # logger.py
+ 'MessageLogger',
+ 'init_tb_logger',
+ 'init_wandb_logger',
+ 'get_root_logger',
+ 'get_env_info',
+ # misc.py
+ 'set_random_seed',
+ 'get_time_str',
+ 'mkdir_and_rename',
+ 'make_exp_dirs',
+ 'scandir',
+ 'check_resume',
+ 'sizeof_fmt'
]
diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py
new file mode 100644
index 0000000..43cf4cd
--- /dev/null
+++ b/basicsr/utils/dist_util.py
@@ -0,0 +1,83 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
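
Note: master_only wraps a function so that it runs on rank 0 only and returns None on other ranks, which is how logging and checkpoint writing stay single-process. A small usage sketch with a hypothetical save function:

import torch

from basicsr.utils.dist_util import get_dist_info, master_only

@master_only
def save_checkpoint(state, path):
    # executes on rank 0 only; other ranks return None without touching the disk
    torch.save(state, path)

rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is not initialized
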
diff --git a/basicsr/utils/download.py b/basicsr/utils/download_util.py
similarity index 73%
rename from basicsr/utils/download.py
rename to basicsr/utils/download_util.py
index e03516c..64a0016 100644
--- a/basicsr/utils/download.py
+++ b/basicsr/utils/download_util.py
@@ -1,7 +1,8 @@
import math
import requests
+from tqdm import tqdm
-from basicsr.utils import ProgressBar
+from .misc import sizeof_fmt
def download_file_from_google_drive(file_id, save_path):
@@ -49,7 +50,8 @@ def save_response_content(response,
file_size=None,
chunk_size=32768):
if file_size is not None:
- pbar = ProgressBar(math.ceil(file_size / chunk_size))
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
readable_file_size = sizeof_fmt(file_size)
else:
pbar = None
@@ -59,24 +61,10 @@ def save_response_content(response,
for chunk in response.iter_content(chunk_size):
downloaded_size += chunk_size
if pbar is not None:
- pbar.update(f'Downloading {sizeof_fmt(downloaded_size)} '
- f'/ {readable_file_size}')
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
+ f'/ {readable_file_size}')
if chunk: # filter out keep-alive new chunks
f.write(chunk)
-
-
-def sizeof_fmt(size, suffix='B'):
- """Get human readable file size.
-
- Args:
- size (int): File size.
- suffix (str): Suffix. Default: 'B'.
-
- Return:
- str: Formated file siz.
- """
- for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
- if abs(size) < 1024.0:
- return f'{size:3.1f} {unit}{suffix}'
- size /= 1024.0
- return f'{size:3.1f} Y{suffix}'
+ if pbar is not None:
+ pbar.close()
diff --git a/basicsr/utils/face_util.py b/basicsr/utils/face_util.py
new file mode 100644
index 0000000..33fe178
--- /dev/null
+++ b/basicsr/utils/face_util.py
@@ -0,0 +1,217 @@
+import cv2
+import numpy as np
+import os
+import torch
+from skimage import transform as trans
+
+from basicsr.utils import imwrite
+
+try:
+ import dlib
+except ImportError:
+    print('Please install dlib before testing face restoration. '
+          'Reference: https://github.com/davisking/dlib')
+
+
+class FaceRestorationHelper(object):
+ """Helper for the face restoration pipeline."""
+
+ def __init__(self, upscale_factor, face_size=512):
+ self.upscale_factor = upscale_factor
+ self.face_size = (face_size, face_size)
+
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
+ self.face_template = np.array([[686.77227723, 488.62376238],
+ [586.77227723, 493.59405941],
+ [337.91089109, 488.38613861],
+ [437.95049505, 493.51485149],
+ [513.58415842, 678.5049505]])
+ self.face_template = self.face_template / (1024 // face_size)
+        # for estimating the 2D similarity transformation
+ self.similarity_trans = trans.SimilarityTransform()
+
+ self.all_landmarks_5 = []
+ self.all_landmarks_68 = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.save_png = True
+
+ def init_dlib(self, detection_path, landmark5_path, landmark68_path):
+ """Initialize the dlib detectors and predictors."""
+ self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+ self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+ self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
+
+ def free_dlib_gpu_memory(self):
+ del self.face_detector
+ del self.shape_predictor_5
+ del self.shape_predictor_68
+
+ def read_input_image(self, img_path):
+ # self.input_img is Numpy array, (h, w, c) with RGB order
+ self.input_img = dlib.load_rgb_image(img_path)
+
+ def detect_faces(self,
+ img_path,
+ upsample_num_times=1,
+ only_keep_largest=False):
+ """
+ Args:
+ img_path (str): Image path.
+ upsample_num_times (int): Upsamples the image before running the
+ face detector
+
+ Returns:
+ int: Number of detected faces.
+ """
+ self.read_input_image(img_path)
+ det_faces = self.face_detector(self.input_img, upsample_num_times)
+ if len(det_faces) == 0:
+ print('No face detected. Try to increase upsample_num_times.')
+ else:
+ if only_keep_largest:
+                print('Detected several faces; only keep the largest one.')
+ face_areas = []
+ for i in range(len(det_faces)):
+ face_area = (det_faces[i].rect.right() -
+ det_faces[i].rect.left()) * (
+ det_faces[i].rect.bottom() -
+ det_faces[i].rect.top())
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ self.det_faces = [det_faces[largest_idx]]
+ else:
+ self.det_faces = det_faces
+ return len(self.det_faces)
+
+ def get_face_landmarks_5(self):
+ for face in self.det_faces:
+ shape = self.shape_predictor_5(self.input_img, face.rect)
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
+ self.all_landmarks_5.append(landmark)
+ return len(self.all_landmarks_5)
+
+ def get_face_landmarks_68(self):
+ """Get 68 densemarks for cropped images.
+
+ Should only have one face at most in the cropped image.
+ """
+ num_detected_face = 0
+ for idx, face in enumerate(self.cropped_faces):
+ # face detection
+ det_face = self.face_detector(face, 1) # TODO: can we remove it?
+ if len(det_face) == 0:
+ print(f'Cannot find faces in cropped image with index {idx}.')
+ self.all_landmarks_68.append(None)
+ else:
+ if len(det_face) > 1:
+                    print('Detected several faces in the cropped face. Use the '
+                          'largest one. Note that it will also cause overlap '
+ 'during paste_faces_to_input_image.')
+ face_areas = []
+ for i in range(len(det_face)):
+ face_area = (det_face[i].rect.right() -
+ det_face[i].rect.left()) * (
+ det_face[i].rect.bottom() -
+ det_face[i].rect.top())
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ face_rect = det_face[largest_idx].rect
+ else:
+ face_rect = det_face[0].rect
+ shape = self.shape_predictor_68(face, face_rect)
+ landmark = np.array([[part.x, part.y]
+ for part in shape.parts()])
+ self.all_landmarks_68.append(landmark)
+ num_detected_face += 1
+
+ return num_detected_face
+
+ def warp_crop_faces(self,
+ save_cropped_path=None,
+ save_inverse_affine_path=None):
+ """Get affine matrix, warp and cropped faces.
+
+ Also get inverse affine matrix for post-processing.
+ """
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ self.similarity_trans.estimate(landmark, self.face_template)
+ affine_matrix = self.similarity_trans.params[0:2, :]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
+ self.face_size)
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path, ext = os.path.splitext(save_cropped_path)
+ if self.save_png:
+ save_path = f'{path}_{idx:02d}.png'
+ else:
+ save_path = f'{path}_{idx:02d}{ext}'
+
+ imwrite(
+ cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
+
+ # get inverse affine matrix
+ self.similarity_trans.estimate(self.face_template,
+ landmark * self.upscale_factor)
+ inverse_affine = self.similarity_trans.params[0:2, :]
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+ def add_restored_face(self, face):
+ self.restored_faces.append(face)
+
+ def paste_faces_to_input_image(self, save_path):
+ # operate in the BGR order
+ input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
+ h, w, _ = input_img.shape
+ h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
+ # simply resize the background
+ upsample_img = cv2.resize(input_img, (w_up, h_up))
+ assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
+ 'length of restored_faces and affine_matrices are different.')
+ for restored_face, inverse_affine in zip(self.restored_faces,
+ self.inverse_affine_matrices):
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine,
+ (w_up, h_up))
+ mask = np.ones((*self.face_size, 3), dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask,
+ np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
+ np.uint8))
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) // 3
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(
+ inv_mask_erosion,
+ np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
+ (blur_size + 1, blur_size + 1), 0)
+ upsample_img = inv_soft_mask * inv_restored_remove_border + (
+ 1 - inv_soft_mask) * upsample_img
+ if self.save_png:
+ save_path = save_path.replace('.jpg',
+ '.png').replace('.jpeg', '.png')
+ imwrite(upsample_img.astype(np.uint8), save_path)
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.all_landmarks_68 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
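
Note: FaceRestorationHelper is stateful; as the methods above suggest, the intended call order is detect, landmarks, warp/crop, restore, paste back. A hedged usage sketch (model paths and the restoration network are placeholders):

from basicsr.utils.face_util import FaceRestorationHelper

helper = FaceRestorationHelper(upscale_factor=4, face_size=512)
helper.init_dlib('detection.dat', 'landmark5.dat', 'landmark68.dat')  # hypothetical model paths

num_faces = helper.detect_faces('input.jpg', only_keep_largest=True)
if num_faces > 0:
    helper.get_face_landmarks_5()
    helper.warp_crop_faces(save_cropped_path='cropped/face.png')
    for cropped_face in helper.cropped_faces:
        restored = restore_face(cropped_face)   # placeholder for the actual restoration network
        helper.add_restored_face(restored)
    helper.paste_faces_to_input_image('restored/output.png')
helper.clean_all()                              # reset state before the next image
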
diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py
index 1d8e5cf..066b22f 100644
--- a/basicsr/utils/file_client.py
+++ b/basicsr/utils/file_client.py
@@ -1,113 +1,183 @@
-from mmcv.fileio.file_client import (BaseStorageBackend, CephBackend,
- HardDiskBackend, MemcachedBackend)
-
-
-class LmdbBackend(BaseStorageBackend):
- """Lmdb storage backend.
-
- Args:
- db_paths (str | list[str]): Lmdb database paths.
- client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
- readonly (bool, optional): Lmdb environment parameter. If True,
- disallow any write operations. Default: True.
- lock (bool, optional): Lmdb environment parameter. If False, when
- concurrent access occurs, do not lock the database. Default: False.
- readahead (bool, optional): Lmdb environment parameter. If False,
- disable the OS filesystem readahead mechanism, which may improve
- random read performance when a database is larger than RAM.
- Default: False.
-
- Attributes:
- db_paths (list): Lmdb database path.
- _client (list): A list of several lmdb envs.
- """
-
- def __init__(self,
- db_paths,
- client_keys='default',
- readonly=True,
- lock=False,
- readahead=False,
- **kwargs):
- try:
- import lmdb
- except ImportError:
- raise ImportError('Please install lmdb to enable LmdbBackend.')
-
- if isinstance(client_keys, str):
- client_keys = [client_keys]
-
- if isinstance(db_paths, list):
- self.db_paths = [str(v) for v in db_paths]
- elif isinstance(db_paths, str):
- self.db_paths = [str(db_paths)]
- assert len(client_keys) == len(self.db_paths), (
- 'client_keys and db_paths should have the same length, '
- f'but received {len(client_keys)} and {len(self.db_paths)}.')
-
- self._client = {}
- for client, path in zip(client_keys, self.db_paths):
- self._client[client] = lmdb.open(
- path,
- readonly=readonly,
- lock=lock,
- readahead=readahead,
- **kwargs)
-
- def get(self, filepath, client_key):
- """Get values according to the filepath from one lmdb named client_key.
-
- Args:
- filepath (str | obj:`Path`): Here, filepath is the lmdb key.
- client_key (str): Used for distinguishing differnet lmdb envs.
- """
- filepath = str(filepath)
- assert client_key in self._client, (f'client_key {client_key} is not '
- 'in lmdb clients.')
- client = self._client[client_key]
- with client.begin(write=False) as txn:
- value_buf = txn.get(filepath.encode('ascii'))
- return value_buf
-
- def get_text(self, filepath):
- raise NotImplementedError
-
-
-class FileClient(object):
- """A general file client to access files in different backend.
-
- The client loads a file or text in a specified backend from its path
- and return it as a binary file. it can also register other backend
- accessor with a given name and backend class.
-
- Attributes:
- backend (str): The storage backend type. Options are "disk", "ceph",
- "memcached" and "lmdb".
- client (:obj:`BaseStorageBackend`): The backend object.
- """
-
- _backends = {
- 'disk': HardDiskBackend,
- 'ceph': CephBackend,
- 'memcached': MemcachedBackend,
- 'lmdb': LmdbBackend,
- }
-
- def __init__(self, backend='disk', **kwargs):
- if backend not in self._backends:
- raise ValueError(
- f'Backend {backend} is not supported. Currently supported ones'
- f' are {list(self._backends.keys())}')
- self.backend = backend
- self.client = self._backends[backend](**kwargs)
-
- def get(self, filepath, client_key='default'):
- # client_key is used only for lmdb, where different fileclients have
- # different lmdb environments.
- if self.backend == 'lmdb':
- return self.client.get(filepath, client_key)
- else:
- return self.client.get(filepath)
-
- def get_text(self, filepath):
- return self.client.get_text(filepath)
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+    as text.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError(
+ 'Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+ self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self,
+ db_paths,
+ client_keys='default',
+ readonly=True,
+ lock=False,
+ readahead=False,
+ **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), (
+ 'client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(
+ path,
+ readonly=readonly,
+ lock=lock,
+ readahead=readahead,
+ **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+            client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not '
+ 'in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+ The client loads a file or text in a specified backend from its path
+    and returns it as a binary file. It can also register other backend
+    accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(
+ f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
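
Note: FileClient dispatches to one registered backend; only the lmdb backend uses client_key, so several lmdb environments can sit behind one client. A small usage sketch with hypothetical dataset paths:

from basicsr.utils import FileClient

# disk backend: get() returns the raw bytes of a file
disk_client = FileClient('disk')
img_bytes = disk_client.get('datasets/example/0001.png')

# lmdb backend: one env per client key, the filepath acts as the lmdb key
lmdb_client = FileClient(
    'lmdb',
    db_paths=['datasets/lq.lmdb', 'datasets/gt.lmdb'],
    client_keys=['lq', 'gt'])
lq_bytes = lmdb_client.get('0001', client_key='lq')
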
diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py
new file mode 100644
index 0000000..2b052cc
--- /dev/null
+++ b/basicsr/utils/flow_util.py
@@ -0,0 +1,180 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
+import cv2
+import numpy as np
+import os
+
+
+def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_path (ndarray or str): Flow path.
+ quantize (bool): whether to read quantized pair, if set to True,
+ remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if quantize:
+ assert concat_axis in [0, 1]
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
+ if cat_flow.ndim != 2:
+ raise IOError(f'{flow_path} is not a valid quantized flow file, '
+ f'its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+ else:
+ with open(flow_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_path}, '
+ 'header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
+ otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
+ will be concatenated horizontally into a single image if quantize is True.)
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
+ images. If set to True, remaining args will be passed to
+ :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ flow = flow.astype(np.float32)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        cv2.imwrite(filename, dxdy)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [
+ quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
+ ]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+ dy *= dx.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+ tuple: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+ Returns:
+ tuple: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+ min_val) / levels + min_val
+
+ return dequantized_arr
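
Note: quantize and dequantize form an approximate round trip: values are clipped to [min_val, max_val], binned into 255 levels, and recovered at bin centers, so the reconstruction error is at most half a bin. A small numeric sketch, assuming quantize_flow and dequantize_flow above are in scope:

import numpy as np

flow = np.random.uniform(-0.05, 0.05, size=(4, 4, 2)).astype(np.float32)
dx, dy = quantize_flow(flow, max_val=0.02, norm=False)          # two uint8 maps with 255 levels
restored = dequantize_flow(dx, dy, max_val=0.02, denorm=False)  # float values at bin centers

step = 2 * 0.02 / 255                       # width of one quantization bin
inside = np.abs(flow) <= 0.02               # values outside the range are clipped first
assert np.all(np.abs(restored - flow)[inside] <= step / 2 + 1e-6)
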
diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py
new file mode 100644
index 0000000..152be01
--- /dev/null
+++ b/basicsr/utils/img_util.py
@@ -0,0 +1,165 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or
+ (isinstance(tensor, list)
+ and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(
+ f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(
+ _tensor, nrow=int(math.sqrt(_tensor.size(0))),
+ normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError('Only support 4D, 3D or 2D tensor. '
+ f'But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, will also
+            normalize to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {
+ 'color': cv2.IMREAD_COLOR,
+ 'grayscale': cv2.IMREAD_GRAYSCALE,
+ 'unchanged': cv2.IMREAD_UNCHANGED
+ }
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [
+ v[crop_border:-crop_border, crop_border:-crop_border, ...]
+ for v in imgs
+ ]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border,
+ ...]
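
Note: img2tensor and tensor2img are intended to be near-inverses: BGR HWC arrays become RGB float CHW tensors and back. A small round-trip sketch, assuming inputs already scaled to [0, 1]:

import numpy as np

from basicsr.utils import img2tensor, tensor2img

img = np.random.rand(32, 32, 3).astype(np.float32)           # BGR, [0, 1], HWC
tensor = img2tensor(img, bgr2rgb=True, float32=True)          # RGB, CHW float tensor
print(tensor.shape)                                           # torch.Size([3, 32, 32])

restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))   # back to BGR uint8 HWC
assert restored.shape == (32, 32, 3) and restored.dtype == np.uint8
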
diff --git a/basicsr/utils/lmdb.py b/basicsr/utils/lmdb_util.py
similarity index 93%
rename from basicsr/utils/lmdb.py
rename to basicsr/utils/lmdb_util.py
index 8e3e99d..a81278f 100644
--- a/basicsr/utils/lmdb.py
+++ b/basicsr/utils/lmdb_util.py
@@ -1,11 +1,9 @@
import cv2
import lmdb
-import mmcv
import sys
from multiprocessing import Pool
from os import path as osp
-
-from .util import ProgressBar
+from tqdm import tqdm
def make_lmdb_from_imgs(data_path,
@@ -76,12 +74,13 @@ def make_lmdb_from_imgs(data_path,
dataset = {} # use dict to keep the order for multiprocessing
shapes = {}
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
- pbar = ProgressBar(len(img_path_list))
+ pbar = tqdm(total=len(img_path_list), unit='image')
def callback(arg):
"""get the image data and update pbar."""
key, dataset[key], shapes[key] = arg
- pbar.update('Reading {}'.format(key))
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
pool = Pool(n_thread)
for path, key in zip(img_path_list, keys):
@@ -91,13 +90,14 @@ def callback(arg):
callback=callback)
pool.close()
pool.join()
+ pbar.close()
print(f'Finish reading {len(img_path_list)} images.')
# create lmdb environment
if map_size is None:
# obtain data size for one image
- img = mmcv.imread(
- osp.join(data_path, img_path_list[0]), flag='unchanged')
+ img = cv2.imread(
+ osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
_, img_byte = cv2.imencode(
'.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
data_size_per_img = img_byte.nbytes
@@ -108,11 +108,12 @@ def callback(arg):
env = lmdb.open(lmdb_path, map_size=map_size)
# write data to lmdb
- pbar = ProgressBar(len(img_path_list))
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
txn = env.begin(write=True)
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
- pbar.update(f'Write {key}')
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
key_byte = key.encode('ascii')
if multiprocessing_read:
img_byte = dataset[key]
@@ -128,6 +129,7 @@ def callback(arg):
if idx % batch == 0:
txn.commit()
txn = env.begin(write=True)
+ pbar.close()
txn.commit()
env.close()
txt_file.close()
@@ -148,7 +150,7 @@ def read_img_worker(path, key, compress_level):
tuple[int]: Image shape.
"""
- img = mmcv.imread(path, flag='unchanged')
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img.ndim == 2:
h, w = img.shape
c = 1
diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py
index 6aee50b..48671ed 100644
--- a/basicsr/utils/logger.py
+++ b/basicsr/utils/logger.py
@@ -1,7 +1,8 @@
import datetime
import logging
import time
-from mmcv.runner import get_dist_info, master_only
+
+from .dist_util import get_dist_info, master_only
class MessageLogger():
@@ -153,7 +154,6 @@ def get_env_info():
Currently, only log the software version.
"""
- import mmcv
import torch
import torchvision
@@ -173,6 +173,5 @@ def get_env_info():
msg += ('\nVersion Information: '
f'\n\tBasicSR: {__version__}'
f'\n\tPyTorch: {torch.__version__}'
- f'\n\tTorchVision: {torchvision.__version__}'
- f'\n\tMMCV: {mmcv.__version__}')
+ f'\n\tTorchVision: {torchvision.__version__}')
return msg
diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000..cd96a2c
--- /dev/null
+++ b/basicsr/utils/matlab_functions.py
@@ -0,0 +1,361 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+    """Cubic kernel used by calculate_weights_indices (bicubic, a = -0.5).
+
+    w(x) = 1.5|x|^3 - 2.5|x|^2 + 1 for |x| <= 1;
+    w(x) = -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2 for 1 < |x| <= 2; 0 otherwise.
+    """
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
+ 2) * (((absx > 1) *
+ (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel,
+ kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+ Args:
+ in_length (int): Input length.
+ out_length (int): Output length.
+ scale (float): Scale factor.
+        kernel (str): Kernel type. Only 'cubic' is used here.
+        kernel_width (int): Kernel width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(
+ 0, p - 1, p).view(1, p).expand(out_length, p)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+    """imresize function that reproduces MATLAB's imresize behavior.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ Default: True.
+
+ Returns:
+        Tensor: Output image with shape (c, h, w), [0, 1] range, without
+            rounding. If the input is a Numpy array, a Numpy array with
+            shape (h, w, c) is returned instead.
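+
+    Example (a minimal sketch with a synthetic numpy input):
+        >>> import numpy as np
+        >>> lr = np.random.rand(32, 32, 3).astype(np.float32)
+        >>> imresize(lr, 4).shape
+        (128, 128, 3)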
+ """
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
+ in_h, out_h, scale, kernel, kernel_width, antialiasing)
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
+ in_w, out_w, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
+ 0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :,
+ idx:idx + kernel_width].mv(weights_w[i])
+
+ if numpy_type:
+ out_2 = out_2.numpy().transpose(1, 2, 0)
+ return out_2
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
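+
+    Example (a minimal sketch with a synthetic image):
+        >>> import numpy as np
+        >>> img = np.random.rand(8, 8, 3).astype(np.float32)
+        >>> y = rgb2ycbcr(img, y_only=True)  # float32, shape (8, 8)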
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
+ -222.921, 135.576, -276.836
+ ] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
+ -276.836, 135.576, -222.921
+ ] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, '
+ f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
+ f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py
new file mode 100644
index 0000000..200527c
--- /dev/null
+++ b/basicsr/utils/misc.py
@@ -0,0 +1,139 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+from .logger import get_root_logger
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+    """Make directories. If the path exists, rename it with a timestamp and
+    create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ mkdir_and_rename(path_opt.pop('experiments_root'))
+ else:
+ mkdir_and_rename(path_opt.pop('results_root'))
+ for key, path in path_opt.items():
+ if ('strict_load' not in key) and ('pretrain_network'
+ not in key) and ('resume'
+ not in key):
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+        A generator for all the files of interest, with relative paths.
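+
+    Example (a minimal sketch; the directory path is illustrative):
+        >>> for path in scandir('basicsr/utils', suffix='.py'):
+        ...     print(path)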
+    """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(
+ entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ logger = get_root_logger()
+ if opt['path']['resume_state']:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith('network_')]
+ flag_pretrain = False
+ for network in networks:
+ if opt['path'].get(f'pretrain_{network}') is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ logger.warning(
+ 'pretrain_network path will be ignored during resuming.')
+ # set pretrained model paths
+ for network in networks:
+ name = f'pretrain_{network}'
+ basename = network.replace('network_', '')
+ if opt['path'].get('ignore_resume_networks') is None or (
+ basename not in opt['path']['ignore_resume_networks']):
+ opt['path'][name] = osp.join(
+ opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+ logger.info(f"Set {name} to {opt['path'][name]}")
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+    Returns:
+        str: Formatted file size.
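+
+    Example (a minimal usage sketch):
+        >>> sizeof_fmt(1024**2)
+        '1.0 MB'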
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py
index f7717f0..3670d17 100644
--- a/basicsr/utils/options.py
+++ b/basicsr/utils/options.py
@@ -57,9 +57,10 @@ def parse(opt_path, is_train=True):
dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
# paths
- for key, path in opt['path'].items():
- if path and key != 'strict_load':
- opt['path'][key] = osp.expanduser(path)
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key
+ or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
opt['path']['root'] = osp.abspath(
osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
if is_train:
diff --git a/basicsr/utils/util.py b/basicsr/utils/util.py
deleted file mode 100644
index 7419e7b..0000000
--- a/basicsr/utils/util.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import math
-import mmcv
-import numpy as np
-import os
-import random
-import sys
-import time
-import torch
-from mmcv.runner import get_time_str, master_only
-from os import path as osp
-from shutil import get_terminal_size
-from torchvision.utils import make_grid
-
-from basicsr.utils import get_root_logger
-
-
-def check_resume(opt, resume_iter):
- """Check resume states and pretrain_model paths.
-
- Args:
- opt (dict): Options.
- resume_iter (int): Resume iteration.
- """
- logger = get_root_logger()
- if opt['path']['resume_state']:
- # ignore pretrained model paths
- if opt['path'].get('pretrain_model_g') is not None or opt['path'].get(
- 'pretrain_model_d') is not None:
- logger.warning(
- 'pretrain_model path will be ignored during resuming.')
-
- # set pretrained model paths
- opt['path']['pretrain_model_g'] = osp.join(opt['path']['models'],
- f'net_g_{resume_iter}.pth')
- logger.info(
- f"Set pretrain_model_g to {opt['path']['pretrain_model_g']}")
-
- opt['path']['pretrain_model_d'] = osp.join(opt['path']['models'],
- f'net_d_{resume_iter}.pth')
- logger.info(
- f"Set pretrain_model_d to {opt['path']['pretrain_model_d']}")
-
-
-def mkdir_and_rename(path):
- """mkdirs. If path exists, rename it with timestamp and create a new one.
-
- Args:
- path (str): Folder path.
- """
- if osp.exists(path):
- new_name = path + '_archived_' + get_time_str()
- print(f'Path already exists. Rename it to {new_name}', flush=True)
- os.rename(path, new_name)
- mmcv.mkdir_or_exist(path)
-
-
-@master_only
-def make_exp_dirs(opt):
- """Make dirs for experiments."""
- path_opt = opt['path'].copy()
- if opt['is_train']:
- mkdir_and_rename(path_opt.pop('experiments_root'))
- else:
- mkdir_and_rename(path_opt.pop('results_root'))
- path_opt.pop('strict_load')
- for key, path in path_opt.items():
- if 'pretrain_model' not in key and 'resume' not in key:
- mmcv.mkdir_or_exist(path)
-
-
-def set_random_seed(seed):
- """Set random seeds."""
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
-
-
-def crop_border(imgs, crop_border):
- """Crop borders of images.
-
- Args:
- imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
- crop_border (int): Crop border for each end of height and weight.
-
- Returns:
- list[ndarray]: Cropped images.
- """
- if crop_border == 0:
- return imgs
- else:
- if isinstance(imgs, list):
- return [
- v[crop_border:-crop_border, crop_border:-crop_border, ...]
- for v in imgs
- ]
- else:
- return imgs[crop_border:-crop_border, crop_border:-crop_border,
- ...]
-
-
-def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
- """Convert torch Tensors into image numpy arrays.
-
- After clamping to [min, max], values will be normalized to [0, 1].
-
- Args:
- tensor (Tensor or list[Tensor]): Accept shapes:
- 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
- 2) 3D Tensor of shape (3/1 x H x W);
- 3) 2D Tensor of shape (H x W).
- Tensor channel should be in RGB order.
- out_type (numpy type): output types. If ``np.uint8``, transform outputs
- to uint8 type with range [0, 255]; otherwise, float type with
- range [0, 1]. Default: ``np.uint8``.
- min_max (tuple[int]): min and max values for clamp.
-
- Returns:
- (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
- shape (H x W). The channel order is BGR.
- """
- if not (torch.is_tensor(tensor) or
- (isinstance(tensor, list)
- and all(torch.is_tensor(t) for t in tensor))):
- raise TypeError(
- f'tensor or list of tensors expected, got {type(tensor)}')
-
- if torch.is_tensor(tensor):
- tensor = [tensor]
- result = []
- for _tensor in tensor:
- _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
- _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
-
- n_dim = _tensor.dim()
- if n_dim == 4:
- img_np = make_grid(
- _tensor, nrow=int(math.sqrt(_tensor.size(0))),
- normalize=False).numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :],
- (1, 2, 0)) # HWC, BGR
- elif n_dim == 3:
- img_np = _tensor.numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :],
- (1, 2, 0)) # HWC, BGR
- elif n_dim == 2:
- img_np = _tensor.numpy()
- else:
- raise TypeError('Only support 4D, 3D or 2D tensor. '
- f'But received with dimension: {n_dim}')
- if out_type == np.uint8:
- # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
- img_np = (img_np * 255.0).round()
- img_np = img_np.astype(out_type)
- result.append(img_np)
- if len(result) == 1:
- result = result[0]
- return result
-
-
-class ProgressBar(object):
- """A progress bar that can print the progress.
-
- Modified from:
- https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
- """
-
- def __init__(self, task_num=0, bar_width=50, start=True):
- self.task_num = task_num
- max_bar_width = self._get_max_bar_width()
- self.bar_width = (
- bar_width if bar_width <= max_bar_width else max_bar_width)
- self.completed = 0
- if start:
- self.start()
-
- def _get_max_bar_width(self):
- terminal_width, _ = get_terminal_size()
- max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
- if max_bar_width < 10:
- print(f'terminal width is too small ({terminal_width}), '
- 'please consider widen the terminal for better '
- 'progressbar visualization')
- max_bar_width = 10
- return max_bar_width
-
- def start(self):
- if self.task_num > 0:
- sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
- f'elapsed: 0s, ETA:\nStart...\n')
- else:
- sys.stdout.write('completed: 0, elapsed: 0s')
- sys.stdout.flush()
- self.start_time = time.time()
-
- def update(self, msg='In progress...'):
- self.completed += 1
- elapsed = time.time() - self.start_time + 1e-8
- fps = self.completed / elapsed
- if self.task_num > 0:
- percentage = self.completed / float(self.task_num)
- eta = int(elapsed * (1 - percentage) / percentage + 0.5)
- mark_width = int(self.bar_width * percentage)
- bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
- sys.stdout.write('\033[2F') # cursor up 2 lines
- sys.stdout.write(
- '\033[J'
- ) # clean the output (remove extra chars since last display)
- sys.stdout.write(
- f'[{bar_chars}] {self.completed}/{self.task_num}, '
- f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
- f'ETA: {eta:5}s\n{msg}\n')
- else:
- sys.stdout.write(
- f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
- f' {fps:.1f} tasks/s')
- sys.stdout.flush()
diff --git a/colab/README.md b/colab/README.md
new file mode 100644
index 0000000..0e83739
--- /dev/null
+++ b/colab/README.md
@@ -0,0 +1,13 @@
+# Colab
+
+
+
+To keep the BasicSR repo small, we do not include the original Colab notebooks in this repo; instead, we provide links to Google Colab.
+
+| Face Restoration| |
+| :--- | :---: |
+|DFDNet | [BasicSR_inference_DFDNet.ipynb](https://colab.research.google.com/drive/1RoNDeipp9yPjI3EbpEbUhn66k5Uzg4n8?usp=sharing)|
+| **Super-Resolution**| |
+|ESRGAN |[BasicSR_inference_ESRGAN.ipynb](https://colab.research.google.com/drive/1JQScYICvEC3VqaabLu-lxvq9h7kSV1ML?usp=sharing)|
+| **Deblurring**| |
+| **Denoise**| |
diff --git a/docs/Config.md b/docs/Config.md
index f2a0775..5a3b04f 100644
--- a/docs/Config.md
+++ b/docs/Config.md
@@ -127,11 +127,11 @@ network_g:
#########################################################
path:
# Path for pretrained models, usually end with pth
- pretrain_model_g: ~
+ pretrain_network_g: ~
# Whether to load pretrained models strictly, that is the corresponding parameter names should be the same
- strict_load: true
+ strict_load_g: true
# Path for resume state. Usually in the `experiments/exp_name/training_states` folder
- # This argument will over-write the `pretrain_model_g`
+ # This argument will over-write the `pretrain_network_g`
resume_state: ~
@@ -302,9 +302,9 @@ network_g:
#################################################
path:
## Path for pretrained models, usually end with pth
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
# Whether to load pretrained models strictly, that is the corresponding parameter names should be the same
- strict_load: true
+ strict_load_g: true
##########################################################
# The following are validation settings (Also for testing)
diff --git a/docs/Config_CN.md b/docs/Config_CN.md
index 6fa159d..6517110 100644
--- a/docs/Config_CN.md
+++ b/docs/Config_CN.md
@@ -126,11 +126,11 @@ network_g:
######################################
path:
# 预训练模型的路径, 需要以pth结尾的模型
- pretrain_model_g: ~
+ pretrain_network_g: ~
# 加载预训练模型的时候, 是否需要网络参数的名称严格对应
- strict_load: true
+ strict_load_g: true
# 重启训练的状态路径, 一般在`experiments/exp_name/training_states`目录下
- # 这个设置了, 会覆盖 pretrain_model_g 的设定
+ # 这个设置了, 会覆盖 pretrain_network_g 的设定
resume_state: ~
@@ -299,9 +299,9 @@ network_g:
#############################
path:
# 预训练模型的路径, 需要以pth结尾的模型
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
# 加载预训练模型的时候, 是否需要网络参数的名称严格对应
- strict_load: true
+ strict_load_g: true
##################################
# 以下为Validation (也是测试)的设置
diff --git a/docs/DatasetPreparation.md b/docs/DatasetPreparation.md
index b579f58..207df31 100644
--- a/docs/DatasetPreparation.md
+++ b/docs/DatasetPreparation.md
@@ -24,7 +24,7 @@ At present, there are three types of data storage formats supported:
1. Store in `hard disk` directly in the format of images / video frames.
1. Make [LMDB](https://lmdb.readthedocs.io/en/release/), which could accelerate the IO and decompression speed during training.
-1. [memcached](https://memcached.org/) or [CEPH](https://ceph.io/) are also supported, if they are installed (usually on clusters).
+1. [memcached](https://memcached.org/) is also supported, if it is installed (usually on clusters).
#### How to Use
@@ -115,7 +115,7 @@ For convenience, the binary content stored in LMDB dataset is encoded image by c
**How to Make LMDB**
We provide a script to make LMDB. Before running the script, we need to modify the corresponding parameters accordingly. At present, we support DIV2K, REDS and Vimeo90K datasets; other datasets can also be made in a similar way.
- `python scripts/create_lmdb.py`
+ `python scripts/data_preparation/create_lmdb.py`
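+
+ After the LMDB file is created, each value stores a cv2-encoded PNG that is decoded on the fly by the dataloader. A minimal read sketch (for illustration only; the LMDB path and key below are examples and should be replaced with your own):
+
+ ```python
+ import cv2
+ import lmdb
+ import numpy as np
+
+ env = lmdb.open('datasets/DIV2K/DIV2K_train_HR_sub.lmdb', readonly=True, lock=False)
+ with env.begin(write=False) as txn:
+     img_bytes = txn.get('0001_s001'.encode('ascii'))  # a key listed in meta_info.txt
+ img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
+ ```
+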
#### Data Pre-fetcher
@@ -155,17 +155,17 @@ It is recommended to symlink the dataset root to `datasets` with the command `ln
1. Download the datasets from the [official DIV2K website](https://data.vision.ee.ethz.ch/cvl/DIV2K/).
1. Crop to sub-images: DIV2K has 2K resolution (e.g., 2048 × 1080) images but the training patches are usually small (e.g., 128x128 or 192x192). So there is a waste if reading the whole image but only using a very small part of it. In order to accelerate the IO speed during training, we crop the 2K resolution images to sub-images (here, we crop to 480x480 sub-images).
Note that the size of sub-images is different from the training patch size (`gt_size`) defined in the config file. Specifically, the cropped sub-images with 480x480 are stored. The dataloader will further randomly crop the sub-images to `GT_size x GT_size` patches for training.
- Run the script [extract_subimages.py](../scripts/extract_subimages.py):
+ Run the script [extract_subimages.py](../scripts/data_preparation/extract_subimages.py):
```python
- python scripts/extract_subimages.py
+ python scripts/data_preparation/extract_subimages.py
```
Remember to modify the paths and configurations if you have different settings.
-1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly.
+1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly.
1. Test the dataloader with the script `tests/test_paired_image_dataset.py`.
Remember to modify the paths and configurations accordingly.
-1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/generate_meta_info.py` to generate the meta_info_file.
+1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/data_preparation/generate_meta_info.py` to generate the meta_info_file.
### Common Image SR Datasets
@@ -182,7 +182,7 @@ We provide a list of common image super-resolution datasets.
Classical SR Training |
T91 |
91 images for training |
- Google Drive / Baidu Drive |
+ Google Drive / Baidu Drive |
BSDS200 |
@@ -277,8 +277,8 @@ All the left clips are used for training. Note that it it not required to explic
**Preparation Steps**
1. Download the datasets from the [official website](https://seungjunnah.github.io/Datasets/reds.html).
-1. Regroup the training and validation datasets: `python scripts/regroup_reds_dataset.py`
-1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly.
+1. Regroup the training and validation datasets: `python scripts/data_preparation/regroup_reds_dataset.py`
+1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly.
1. Test the dataloader with the script `tests/test_reds_dataset.py`.
Remember to modify the paths and configurations accordingly.
@@ -289,7 +289,7 @@ Remember to modify the paths and configurations accordingly.
1. Download the dataset: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip).This is the Ground-Truth (GT). There is a `sep_trainlist.txt` file listing the training samples in the download zip file.
1. Generate the low-resolution images (TODO)
The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images.
-1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly.
+1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly.
1. Test the dataloader with the script `tests/test_vimeo90k_dataset.py`.
Remember to modify the paths and configurations accordingly.
@@ -303,5 +303,5 @@ Training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset).
1. Extract tfrecords to images or LMDBs. (TensorFlow is required to read tfrecords). For each resolution, we will create images folder or LMDB files separately.
```bash
- python scripts/extract_images_from_tfrecords.py
+ python scripts/data_preparation/extract_images_from_tfrecords.py
```
diff --git a/docs/DatasetPreparation_CN.md b/docs/DatasetPreparation_CN.md
index 7600256..b3e90a0 100644
--- a/docs/DatasetPreparation_CN.md
+++ b/docs/DatasetPreparation_CN.md
@@ -24,7 +24,7 @@
1. 直接以图像/视频帧的格式存放在硬盘
2. 制作成 [LMDB](https://lmdb.readthedocs.io/en/release/). 训练数据使用这种形式, 一般会加快读取速度.
-3. 若是支持 [Memcached](https://memcached.org/) 或 [Ceph](https://ceph.io/), 则可以使用. 它们一般应用在集群上.
+3. 若是支持 [Memcached](https://memcached.org/), 则可以使用. 它们一般应用在集群上.
#### 如何使用
@@ -116,7 +116,7 @@ DIV2K_train_HR_sub.lmdb
**如何制作**
我们提供了脚本来制作. 在运行脚本前, 需要根据需求修改相应的参数. 目前支持 DIV2K, REDS 和 Vimeo90K 数据集; 其他数据集可仿照进行制作.
- `python scripts/create_lmdb.py`
+ `python scripts/data_preparation/create_lmdb.py`
#### 预读取数据
@@ -155,17 +155,17 @@ DIV2K 数据集被广泛使用在图像复原的任务中.
1. 从[官网](https://data.vision.ee.ethz.ch/cvl/DIV2K)下载数据.
1. Crop to sub-images: 因为 DIV2K 数据集是 2K 分辨率的 (比如: 2048x1080), 而我们在训练的时候往往并不要那么大 (常见的是 128x128 或者 192x192 的训练patch). 因此我们可以先把2K的图片裁剪成有overlap的 480x480 的子图像块. 然后再由 dataloader 从这个 480x480 的子图像块中随机crop出 128x128 或者 192x192 的训练patch.
- 运行脚本 [extract_subimages.py](../scripts/extract_subimages.py):
+ 运行脚本 [extract_subimages.py](../scripts/data_preparation/extract_subimages.py):
```python
- python scripts/extract_subimages.py
+ python scripts/data_preparation/extract_subimages.py
```
使用之前可能需要修改文件里面的路径和配置参数.
**注意**: sub-image 的尺寸和训练patch的尺寸 (`gt_size`) 是不同的. 我们先把2K分辨率的图像 crop 成 sub-images (往往是 480x480), 然后存储起来. 在训练的时候, dataloader会读取这些sub-images, 然后进一步随机裁剪成 `gt_size` x `gt_size`的大小.
-1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_div2k`函数, 并需要修改函数相应的配置和路径.
+1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_div2k`函数, 并需要修改函数相应的配置和路径.
1. 测试: `tests/test_paired_image_dataset.py`, 注意修改函数相应的配置和路径.
-1. [可选] 若需要使用 meta_info_file, 运行 `python scripts/generate_meta_info.py` 来生成 meta_info_file.
+1. [可选] 若需要使用 meta_info_file, 运行 `python scripts/data_preparation/generate_meta_info.py` 来生成 meta_info_file.
### 其他常见图像超分数据集
@@ -182,7 +182,7 @@ DIV2K 数据集被广泛使用在图像复原的任务中.
Classical SR Training |
T91 |
91 images for training |
- Google Drive / Baidu Drive |
+ Google Drive / Baidu Drive |
BSDS200 |
@@ -277,8 +277,8 @@ DIV2K 数据集被广泛使用在图像复原的任务中.
**数据准备步骤**
1. 从[官网](https://seungjunnah.github.io/Datasets/reds.html)下载数据
-1. 整合 training 和 validation 数据: `python scripts/regroup_reds_dataset.py`
-1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_reds`函数, 并需要修改函数相应的配置和路径.
+1. 整合 training 和 validation 数据: `python scripts/data_preparation/regroup_reds_dataset.py`
+1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_reds`函数, 并需要修改函数相应的配置和路径.
1. 测试: `python tests/test_reds_dataset.py`, 注意修改函数相应的配置和路径.
### Vimeo90K
@@ -290,7 +290,7 @@ DIV2K 数据集被广泛使用在图像复原的任务中.
1. 下载数据: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip). 这些是Ground-Truth. 里面有`sep_trainlist.txt`文件来区分训练数据.
1. 生成低分辨率图片. (TODO)
The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images.
-1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_vimeo90k`函数, 并需要修改函数相应的配置和路径.
+1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_vimeo90k`函数, 并需要修改函数相应的配置和路径.
1. 测试: `python tests/test_vimeo90k_dataset.py`, 注意修改函数相应的配置和路径.
## StyleGAN2
@@ -303,5 +303,5 @@ The low-resolution images in the Vimeo90K test dataset are generated with the MA
1. 从 tfrecords 提取到*图片*或者*LMDB*. (需要安装 TensorFlow 来读取 tfrecords). 我们对每一个分辨率的人脸都单独创建文件夹或者LMDB文件.
```bash
- python scripts/extract_images_from_tfrecords.py
+ python scripts/data_preparation/extract_images_from_tfrecords.py
```
diff --git a/docs/DesignConvention.md b/docs/DesignConvention.md
index 35ee55f..10d737a 100644
--- a/docs/DesignConvention.md
+++ b/docs/DesignConvention.md
@@ -34,7 +34,7 @@ Specifically, we implement it through `importlib` and `getattr`. Taking the data
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
- osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder)
+ osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
if v.endswith('_dataset.py')
]
# import all the dataset modules
diff --git a/docs/DesignConvention_CN.md b/docs/DesignConvention_CN.md
index d3c16d3..536d6a6 100644
--- a/docs/DesignConvention_CN.md
+++ b/docs/DesignConvention_CN.md
@@ -36,7 +36,7 @@
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
- osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder)
+ osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
if v.endswith('_dataset.py')
]
# import all the dataset modules
diff --git a/docs/HOWTOs.md b/docs/HOWTOs.md
index a2d6433..ddda95f 100644
--- a/docs/HOWTOs.md
+++ b/docs/HOWTOs.md
@@ -8,23 +8,23 @@
1. Download FFHQ dataset. Recommend to download the tfrecords files from [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset).
1. Extract tfrecords to images or LMDBs (TensorFlow is required to read tfrecords):
- > python scripts/extract_images_from_tfrecords.py
+ > python scripts/data_preparation/extract_images_from_tfrecords.py
1. Modify the config file in `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml`
1. Train with distributed training. More training commands are in [TrainTest.md](TrainTest.md).
> python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ_800k.yml --launcher pytorch
-## How to test StyleGAN2
+## How to run inference with StyleGAN2
1. Download pre-trained models from **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) to the `experiments/pretrained_models` folder.
1. Test.
- > python tests/test_stylegan2.py
+ > python inference/inference_stylegan2.py
1. The results are in the `samples` folder.
-## How to test DFDNet
+## How to run inference with DFDNet
1. Install [dlib](http://dlib.net/), because DFDNet uses dlib to do face recognition and landmark detection. [Installation reference](https://github.com/davisking/dlib).
1. Clone dlib repo: `git clone git@github.com:davisking/dlib.git`
@@ -43,6 +43,6 @@
4. Prepare the testing dataset in the `datasets`, for example, we put images in the `datasets/TestWhole` folder.
5. Test.
- > python tests/test_face_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole
+ > python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole
6. The results are in the `results/DFDNet` folder.
diff --git a/docs/HOWTOs_CN.md b/docs/HOWTOs_CN.md
index aad7f25..df2ab25 100644
--- a/docs/HOWTOs_CN.md
+++ b/docs/HOWTOs_CN.md
@@ -8,7 +8,7 @@
1. 下载 FFHQ 数据集. 推荐从 [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset) 下载 tfrecords 文件.
1. 从tfrecords 提取到*图片*或者*LMDB*. (需要安装 TensorFlow 来读取 tfrecords).
- > python scripts/extract_images_from_tfrecords.py
+ > python scripts/data_preparation/extract_images_from_tfrecords.py
1. 修改配置文件 `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml`
1. 使用分布式训练. 更多训练命令: [TrainTest_CN.md](TrainTest_CN.md)
@@ -20,7 +20,7 @@
1. 从 **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) 下载预训练模型到 `experiments/pretrained_models` 文件夹.
1. 测试.
- > python tests/test_stylegan2.py
+ > python inference/inference_stylegan2.py
1. 结果在 `samples` 文件夹
@@ -43,6 +43,6 @@
4. 准备测试图片到 `datasets`, 比如说我们把测试图片放在 `datasets/TestWhole` 文件夹.
5. 测试.
- > python tests/test_face_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole
+ > python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole
6. 结果在 `results/DFDNet` 文件夹.
diff --git a/docs/Metrics.md b/docs/Metrics.md
new file mode 100644
index 0000000..c4f0cb1
--- /dev/null
+++ b/docs/Metrics.md
@@ -0,0 +1,35 @@
+# Metrics
+
+[English](Metrics.md) **|** [简体中文](Metrics_CN.md)
+
+## PSNR and SSIM
+
+## NIQE
+
+## FID
+
+> FID measures the similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
+> FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
+
+References
+
+- https://github.com/mseitzer/pytorch-fid
+- [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500)
+- [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337)
+
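+In practice, FID is computed from the means and covariances of Inception features extracted from the two image sets. A minimal NumPy/SciPy sketch of the distance itself (for illustration only; it is not the implementation in `basicsr/metrics`, and `feat_real` / `feat_fake` are assumed to be `(N, 2048)` feature arrays):
+
+```python
+import numpy as np
+from scipy import linalg
+
+
+def frechet_distance(feat_real, feat_fake, eps=1e-6):
+    """Frechet distance between two Gaussians fitted to feature sets."""
+    mu1, sigma1 = feat_real.mean(axis=0), np.cov(feat_real, rowvar=False)
+    mu2, sigma2 = feat_fake.mean(axis=0), np.cov(feat_fake, rowvar=False)
+    diff = mu1 - mu2
+    # matrix square root of the covariance product; regularize if singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean, _ = linalg.sqrtm(
+            (sigma1 + offset).dot(sigma2 + offset), disp=False)
+    if np.iscomplexobj(covmean):
+        covmean = covmean.real
+    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
+```
+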
+### Pre-calculated FFHQ inception feature statistics
+
+Usually, we put the downloaded inception feature statistics in `basicsr/metrics`.
+
+:arrow_double_down: Google Drive: [metrics data](https://drive.google.com/drive/folders/13cWIQyHX3iNmZRJ5v8v3kdyrT9RBTAi9?usp=sharing)
+:arrow_double_down: 百度网盘: [评价指标数据](https://pan.baidu.com/s/10mMKXSEgrC5y7m63W5vbMQ)
+
+| File Name | Dataset | Image Shape | Sample Numbers|
+| :------------- | :----------:|:----------:|:----------:|
+| inception_FFHQ_256-0948f50d.pth | FFHQ | 256 x 256 | 50,000 |
+| inception_FFHQ_512-f7b384ab.pth | FFHQ | 512 x 512 | 50,000 |
+| inception_FFHQ_1024-75f195dc.pth | FFHQ | 1024 x 1024 | 50,000 |
+| inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth | FFHQ | 256 x 256 | 50,000 |
+
+- All the FFHQ inception feature statistics are calculated on images resized to 299 x 299.
+- `inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth` is converted from the statistics in [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
diff --git a/docs/Metrics_CN.md b/docs/Metrics_CN.md
new file mode 100644
index 0000000..c5f518c
--- /dev/null
+++ b/docs/Metrics_CN.md
@@ -0,0 +1,36 @@
+# 评价指标
+
+[English](Metrics.md) **|** [简体中文](Metrics_CN.md)
+
+## PSNR and SSIM
+
+## NIQE
+
+## FID
+
+> FID measures the similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks.
+> FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network.
+
+参考
+
+- https://github.com/mseitzer/pytorch-fid
+- [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500)
+- [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337)
+
+### Pre-calculated FFHQ inception feature statistics
+
+通常, 我们把下载的 inception 网络的特征统计数据 (用于计算FID) 放在 `basicsr/metrics`.
+
+
+:arrow_double_down: 百度网盘: [评价指标数据](https://pan.baidu.com/s/10mMKXSEgrC5y7m63W5vbMQ)
+:arrow_double_down: Google Drive: [metrics data](https://drive.google.com/drive/folders/13cWIQyHX3iNmZRJ5v8v3kdyrT9RBTAi9?usp=sharing)
+
+| File Name | Dataset | Image Shape | Sample Numbers|
+| :------------- | :----------:|:----------:|:----------:|
+| inception_FFHQ_256-0948f50d.pth | FFHQ | 256 x 256 | 50,000 |
+| inception_FFHQ_512-f7b384ab.pth | FFHQ | 512 x 512 | 50,000 |
+| inception_FFHQ_1024-75f195dc.pth | FFHQ | 1024 x 1024 | 50,000 |
+| inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth | FFHQ | 256 x 256 | 50,000 |
+
+- All the FFHQ inception feature statistics are calculated on images resized to 299 x 299.
+- `inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth` is converted from the statistics in [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
diff --git a/docs/ModelZoo.md b/docs/ModelZoo.md
index af6579c..4dd25aa 100644
--- a/docs/ModelZoo.md
+++ b/docs/ModelZoo.md
@@ -2,6 +2,11 @@
[English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md)
+:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
+:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
+
+---
+
We provide:
1. Official models converted directly from official released models
@@ -9,7 +14,7 @@ We provide:
You can put the downloaded models in the `experiments/pretrained_models` folder.
-**[Download official pre-trained models]** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g))(https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing))
+**[Download official pre-trained models]** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g))
You can use the scrip to download pre-trained models from Google Drive.
@@ -93,7 +98,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity)
- **L** (Large): # of channels = 128, # of back residual blocks = 40. This setting is used in our competition submission.
- **M** (Moderate): # of channels = 64, # of back residual blocks = 10.
-[Download Models from Google Drive](https://drive.google.com/open?id=1WfROVUqKOBS5gGvQzBfU1DNZ4XwPA3LD)
| Model name |[Test Set] PSNR/SSIM |
|:----------:|:----------:|
@@ -107,7 +111,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity)
1 Y or RGB denotes the evaluation on Y (luminance) or RGB channels.
#### Stage 2 models for the NTIRE19 Competition
-[Download Models from Google Drive](https://drive.google.com/drive/folders/1PMoy1cKlIYWly6zY0tG2Q4YAH7V_HZns?usp=sharing)
| Model name |[Test Set] PSNR/SSIM |
|:----------:|:----------:|
@@ -119,7 +122,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity)
## DUF
The models are converted from the [officially released models](https://github.com/yhjo09/VSR-DUF).
-[Download Models from Google Drive](https://drive.google.com/open?id=1seY9nclMuwk_SpqKQhx1ItTcQShM5R50)
| Model name | [Test Set] PSNR/SSIM1 | Official Results2 |
|:----------:|:----------:|:----------:|
@@ -136,7 +138,6 @@ The models are converted from the [officially released models](https://github.co
## TOF
The models are converted from the [officially released models](https://github.com/anchen1011/toflow).
-[Download Models from Google Drive](https://drive.google.com/open?id=18kJcxPLeNK8e0kYEiwmsnu9wVmhdMFFG)
| Model name | [Test Set] PSNR/SSIM | Official Results1 |
|:----------:|:----------:|:----------:|
diff --git a/docs/ModelZoo_CN.md b/docs/ModelZoo_CN.md
index 9290a91..b192ee1 100644
--- a/docs/ModelZoo_CN.md
+++ b/docs/ModelZoo_CN.md
@@ -2,6 +2,11 @@
[English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md)
+:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
+:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
+
+---
+
我们提供了:
1. 官方的模型, 它们是从官方release的models直接转化过来的
@@ -92,8 +97,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity)
- **L** (Large): # of channels = 128, # of back residual blocks = 40. This setting is used in our competition submission.
- **M** (Moderate): # of channels = 64, # of back residual blocks = 10.
-[Download Models from Google Drive](https://drive.google.com/open?id=1WfROVUqKOBS5gGvQzBfU1DNZ4XwPA3LD)
-
| Model name |[Test Set] PSNR/SSIM |
|:----------:|:----------:|
| EDVR_Vimeo90K_SR_L | [Vid4] (Y1) 27.35/0.8264 [[↓Results]](https://drive.google.com/open?id=14nozpSfe9kC12dVuJ9mspQH5ZqE4mT9K)<br> (RGB) 25.83/0.8077|
@@ -106,7 +109,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity)
1 Y or RGB denotes the evaluation on Y (luminance) or RGB channels.
#### Stage 2 models for the NTIRE19 Competition
-[Download Models from Google Drive](https://drive.google.com/drive/folders/1PMoy1cKlIYWly6zY0tG2Q4YAH7V_HZns?usp=sharing)
| Model name |[Test Set] PSNR/SSIM |
|:----------:|:----------:|
@@ -118,7 +120,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity)
## DUF
The models are converted from the [officially released models](https://github.com/yhjo09/VSR-DUF).
-[Download Models from Google Drive](https://drive.google.com/open?id=1seY9nclMuwk_SpqKQhx1ItTcQShM5R50)
| Model name | [Test Set] PSNR/SSIM1 | Official Results2 |
|:----------:|:----------:|:----------:|
@@ -135,7 +136,6 @@ The models are converted from the [officially released models](https://github.co
## TOF
The models are converted from the [officially released models](https://github.com/anchen1011/toflow).
-[Download Models from Google Drive](https://drive.google.com/open?id=18kJcxPLeNK8e0kYEiwmsnu9wVmhdMFFG)
| Model name | [Test Set] PSNR/SSIM | Official Results1 |
|:----------:|:----------:|:----------:|
diff --git a/inference/inference_dfdnet.py b/inference/inference_dfdnet.py
new file mode 100644
index 0000000..982c524
--- /dev/null
+++ b/inference/inference_dfdnet.py
@@ -0,0 +1,210 @@
+import argparse
+import glob
+import numpy as np
+import os
+import torch
+import torchvision.transforms as transforms
+from skimage import io
+
+from basicsr.models.archs.dfdnet_arch import DFDNet
+from basicsr.utils import imwrite, tensor2img
+from basicsr.utils.face_util import FaceRestorationHelper
+
+
+def get_part_location(landmarks):
+    """Get part locations (left eye, right eye, nose and mouth boxes) from
+    68-point facial landmarks."""
+ map_left_eye = list(np.hstack((range(17, 22), range(36, 42))))
+ map_right_eye = list(np.hstack((range(22, 27), range(42, 48))))
+ map_nose = list(range(29, 36))
+ map_mouth = list(range(48, 68))
+
+ # left eye
+ mean_left_eye = np.mean(landmarks[map_left_eye], 0) # (x, y)
+ half_len_left_eye = np.max((np.max(
+ np.max(landmarks[map_left_eye], 0) -
+ np.min(landmarks[map_left_eye], 0)) / 2, 16)) # A number
+ loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1,
+ mean_left_eye + half_len_left_eye)).astype(int)
+ loc_left_eye = torch.from_numpy(loc_left_eye).unsqueeze(0)
+    # (1, 4); the four numbers form two diagonal corner coordinates
+
+ # right eye
+ mean_right_eye = np.mean(landmarks[map_right_eye], 0)
+ half_len_right_eye = np.max((np.max(
+ np.max(landmarks[map_right_eye], 0) -
+ np.min(landmarks[map_right_eye], 0)) / 2, 16))
+ loc_right_eye = np.hstack(
+ (mean_right_eye - half_len_right_eye + 1,
+ mean_right_eye + half_len_right_eye)).astype(int)
+ loc_right_eye = torch.from_numpy(loc_right_eye).unsqueeze(0)
+ # nose
+ mean_nose = np.mean(landmarks[map_nose], 0)
+ half_len_nose = np.max((np.max(
+ np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2,
+ 16)) # noqa: E126
+ loc_nose = np.hstack(
+ (mean_nose - half_len_nose + 1, mean_nose + half_len_nose)).astype(int)
+ loc_nose = torch.from_numpy(loc_nose).unsqueeze(0)
+ # mouth
+ mean_mouth = np.mean(landmarks[map_mouth], 0)
+ half_len_mouth = np.max((np.max(
+ np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2,
+ 16)) # noqa: E126
+ loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1,
+ mean_mouth + half_len_mouth)).astype(int)
+ loc_mouth = torch.from_numpy(loc_mouth).unsqueeze(0)
+
+ return loc_left_eye, loc_right_eye, loc_nose, loc_mouth
+
+
+if __name__ == '__main__':
+    """We try to align with the official code, but there are still slight
+    differences: 1) we use dlib for 68-point landmark detection; 2) the image
+    packages used are different (especially for reading and writing).
+ """
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--upscale_factor', type=int, default=2)
+ parser.add_argument(
+ '--model_path',
+ type=str,
+ default= # noqa: E251
+ 'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
+ parser.add_argument(
+ '--dict_path',
+ type=str,
+ default= # noqa: E251
+ 'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
+ parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
+ parser.add_argument('--upsample_num_times', type=int, default=1)
+ parser.add_argument('--save_inverse_affine', action='store_true')
+ parser.add_argument('--only_keep_largest', action='store_true')
+    # The official code uses skimage.io to read the cropped images from disk
+    # instead of using the intermediate results in memory (as we do). This
+    # difference leads to slightly different results due to skimage.io. To
+    # align with the official results, set official_adaption to True.
+ parser.add_argument('--official_adaption', type=bool, default=True)
+
+ # The following are the paths for dlib models
+ parser.add_argument(
+ '--detection_path',
+ type=str,
+ default= # noqa: E251
+ 'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat' # noqa: E501
+ )
+ parser.add_argument(
+ '--landmark5_path',
+ type=str,
+ default= # noqa: E251
+ 'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat' # noqa: E501
+ )
+ parser.add_argument(
+ '--landmark68_path',
+ type=str,
+ default= # noqa: E251
+ 'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat' # noqa: E501
+ )
+
+ args = parser.parse_args()
+ if args.test_path.endswith('/'): # solve when path ends with /
+ args.test_path = args.test_path[:-1]
+ result_root = f'results/DFDNet/{os.path.basename(args.test_path)}'
+
+ # set up the DFDNet
+ net = DFDNet(64, dict_path=args.dict_path).to(device)
+ checkpoint = torch.load(
+ args.model_path, map_location=lambda storage, loc: storage)
+ net.load_state_dict(checkpoint['params'])
+ net.eval()
+
+ save_crop_root = os.path.join(result_root, 'cropped_faces')
+ save_inverse_affine_root = os.path.join(result_root, 'inverse_affine')
+ os.makedirs(save_inverse_affine_root, exist_ok=True)
+ save_restore_root = os.path.join(result_root, 'restored_faces')
+ save_final_root = os.path.join(result_root, 'final_results')
+
+ face_helper = FaceRestorationHelper(args.upscale_factor, face_size=512)
+
+ # scan all the jpg and png images
+ for img_path in sorted(
+ glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
+ img_name = os.path.basename(img_path)
+ print(f'Processing {img_name} image ...')
+ save_crop_path = os.path.join(save_crop_root, img_name)
+ if args.save_inverse_affine:
+ save_inverse_affine_path = os.path.join(save_inverse_affine_root,
+ img_name)
+ else:
+ save_inverse_affine_path = None
+
+ face_helper.init_dlib(args.detection_path, args.landmark5_path,
+ args.landmark68_path)
+ # detect faces
+ num_det_faces = face_helper.detect_faces(
+ img_path,
+ upsample_num_times=args.upsample_num_times,
+ only_keep_largest=args.only_keep_largest)
+ # get 5 face landmarks for each face
+ num_landmarks = face_helper.get_face_landmarks_5()
+ print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
+ # warp and crop each face
+ face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path)
+
+ if args.official_adaption:
+ path, ext = os.path.splitext(save_crop_path)
+ paths = sorted(glob.glob(f'{path}_[0-9]*.png'))
+ cropped_faces = [io.imread(p) for p in paths]
+ else:
+ cropped_faces = face_helper.cropped_faces
+
+ # get 68 landmarks for each cropped face
+ num_landmarks = face_helper.get_face_landmarks_68()
+ print(f'\tDetect {num_landmarks} faces for 68 landmarks.')
+
+ face_helper.free_dlib_gpu_memory()
+
+ print('\tFace restoration ...')
+ # face restoration for each cropped face
+ assert len(cropped_faces) == len(face_helper.all_landmarks_68)
+ for idx, (cropped_face, landmarks) in enumerate(
+ zip(cropped_faces, face_helper.all_landmarks_68)):
+ if landmarks is None:
+ print(f'Landmarks is None, skip the cropped face with idx {idx}.')
+ # just copy the cropped faces to the restored faces
+ restored_face = cropped_face
+ else:
+ # prepare data
+ part_locations = get_part_location(landmarks)
+ cropped_face = transforms.ToTensor()(cropped_face)
+ cropped_face = transforms.Normalize(
+ (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)
+ cropped_face = cropped_face.unsqueeze(0).to(device)
+
+ try:
+ with torch.no_grad():
+ output = net(cropped_face, part_locations)
+ restored_face = tensor2img(output, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f'DFDNet inference failed: {e}')
+ restored_face = tensor2img(cropped_face, min_max=(-1, 1))
+
+ path = os.path.splitext(os.path.join(save_restore_root,
+ img_name))[0]
+ save_path = f'{path}_{idx:02d}.png'
+ imwrite(restored_face, save_path)
+ face_helper.add_restored_face(restored_face)
+
+ print('\tGenerate the final result ...')
+ # paste each restored face to the input image
+ face_helper.paste_faces_to_input_image(
+ os.path.join(save_final_root, img_name))
+
+ # clean all the intermediate results to process the next image
+ face_helper.clean_all()
+
+ print(f'\nAll results are saved in {result_root}')
diff --git a/inference/inference_esrgan.py b/inference/inference_esrgan.py
new file mode 100644
index 0000000..8c64966
--- /dev/null
+++ b/inference/inference_esrgan.py
@@ -0,0 +1,55 @@
+import argparse
+import cv2
+import glob
+import numpy as np
+import os
+import torch
+
+from basicsr.models.archs.rrdbnet_arch import RRDBNet
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--model_path',
+ type=str,
+ default= # noqa: E251
+ 'experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth' # noqa: E501
+ )
+ parser.add_argument(
+ '--folder',
+ type=str,
+ default='datasets/Set14/LRbicx4',
+ help='input test image folder')
+ args = parser.parse_args()
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ # set up model
+ model = RRDBNet(
+ num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
+ model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
+ model.eval()
+ model = model.to(device)
+
+ os.makedirs('results/ESRGAN', exist_ok=True)
+ for idx, path in enumerate(
+ sorted(glob.glob(os.path.join(args.folder, '*')))):
+ imgname = os.path.splitext(os.path.basename(path))[0]
+ print('Testing', idx, imgname)
+ # read image
+ img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
+ img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]],
+ (2, 0, 1))).float()
+ img = img.unsqueeze(0).to(device)
+ # inference
+ with torch.no_grad():
+ output = model(img)
+ # save image
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
+ output = (output * 255.0).round().astype(np.uint8)
+ cv2.imwrite(f'results/ESRGAN/{imgname}_ESRGAN.png', output)
+
+
+if __name__ == '__main__':
+ main()
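
For reference, a minimal invocation sketch for the new ESRGAN inference script; the arguments below are simply the defaults defined above, so adjust them to your own model and data locations:

    python inference/inference_esrgan.py \
        --model_path experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth \
        --folder datasets/Set14/LRbicx4

Restored outputs are written to results/ESRGAN/<image_name>_ESRGAN.png.
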
diff --git a/tests/test_stylegan2.py b/inference/inference_stylegan2.py
similarity index 89%
rename from tests/test_stylegan2.py
rename to inference/inference_stylegan2.py
index c166e64..47bbe47 100644
--- a/tests/test_stylegan2.py
+++ b/inference/inference_stylegan2.py
@@ -1,6 +1,6 @@
import argparse
import math
-import mmcv
+import os
import torch
from torchvision import utils
@@ -30,7 +30,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
if __name__ == '__main__':
- device = 'cuda'
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
@@ -43,7 +43,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
'--ckpt',
type=str,
default= # noqa: E251
- 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official-b09c3668.pth' # noqa: E501
+ 'experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-b09c3668.pth' # noqa: E501
)
parser.add_argument('--channel_multiplier', type=int, default=2)
parser.add_argument('--randomize_noise', type=bool, default=True)
@@ -52,7 +52,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
args.latent = 512
args.n_mlp = 8
- mmcv.mkdir_or_exist('samples')
+ os.makedirs('samples', exist_ok=True)
set_random_seed(2020)
g_ema = StyleGAN2Generator(
diff --git a/make.sh b/make.sh
deleted file mode 100644
index 1990c6b..0000000
--- a/make.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#!/usr/bin/env bash
-
-# You may need to modify the following paths before compiling
-CUDA_HOME=/usr/local/cuda \
-CUDNN_INCLUDE_DIR=/usr/local/cuda \
-CUDNN_LIB_DIR=/usr/local/cuda \
-python setup.py develop
diff --git a/options/test/DUF/test_DUF_official.yml b/options/test/DUF/test_DUF_official.yml
index 5d16dd2..d0bc81c 100644
--- a/options/test/DUF/test_DUF_official.yml
+++ b/options/test/DUF/test_DUF_official.yml
@@ -28,8 +28,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/DUF_x4_52L_official-483d2c78.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/DUF/DUF_x4_52L_official-483d2c78.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDSR/test_EDSR_Lx2.yml b/options/test/EDSR/test_EDSR_Lx2.yml
index 05a1398..82dcb49 100644
--- a/options/test/EDSR/test_EDSR_Lx2.yml
+++ b/options/test/EDSR/test_EDSR_Lx2.yml
@@ -43,8 +43,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDSR/test_EDSR_Lx3.yml b/options/test/EDSR/test_EDSR_Lx3.yml
index c7c951c..6053ba6 100644
--- a/options/test/EDSR/test_EDSR_Lx3.yml
+++ b/options/test/EDSR/test_EDSR_Lx3.yml
@@ -43,8 +43,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDSR/test_EDSR_Lx4.yml b/options/test/EDSR/test_EDSR_Lx4.yml
index e9a55e0..37bb209 100644
--- a/options/test/EDSR/test_EDSR_Lx4.yml
+++ b/options/test/EDSR/test_EDSR_Lx4.yml
@@ -43,8 +43,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDSR/test_EDSR_Mx2.yml b/options/test/EDSR/test_EDSR_Mx2.yml
index f18dae7..b6ab304 100644
--- a/options/test/EDSR/test_EDSR_Mx2.yml
+++ b/options/test/EDSR/test_EDSR_Mx2.yml
@@ -43,8 +43,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDSR/test_EDSR_Mx3.yml b/options/test/EDSR/test_EDSR_Mx3.yml
index 612f213..c799603 100644
--- a/options/test/EDSR/test_EDSR_Mx3.yml
+++ b/options/test/EDSR/test_EDSR_Mx3.yml
@@ -43,8 +43,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDSR/test_EDSR_Mx4.yml b/options/test/EDSR/test_EDSR_Mx4.yml
index 0d52ef1..2686861 100644
--- a/options/test/EDSR/test_EDSR_Mx4.yml
+++ b/options/test/EDSR/test_EDSR_Mx4.yml
@@ -43,8 +43,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml
index 1576fc1..6982ab8 100644
--- a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml
+++ b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml
@@ -35,8 +35,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_L_deblur_REDS_official-ca46bd8c.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_deblur_REDS_official-ca46bd8c.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml
index fbb243d..4108a2a 100644
--- a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml
+++ b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml
@@ -35,8 +35,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml
index bd75815..768c173 100644
--- a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml
+++ b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml
@@ -35,8 +35,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_REDS_official-9f5f5039.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_REDS_official-9f5f5039.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml
index 7428355..9929067 100644
--- a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml
+++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml
@@ -34,8 +34,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml
index 21cf0bf..dff07d8 100644
--- a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml
+++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml
@@ -35,8 +35,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml
index ed4ed55..fbe2b1b 100644
--- a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml
+++ b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml
@@ -35,8 +35,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml
index 95271f8..773286d 100644
--- a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml
+++ b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml
@@ -35,8 +35,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/EDVR_M_x4_SR_REDS_official-32075921.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_M_x4_SR_REDS_official-32075921.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/ESRGAN/test_ESRGAN_x4.yml b/options/test/ESRGAN/test_ESRGAN_x4.yml
index 1d23fb8..845789c 100644
--- a/options/test/ESRGAN/test_ESRGAN_x4.yml
+++ b/options/test/ESRGAN/test_ESRGAN_x4.yml
@@ -40,8 +40,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml
index 997381d..d428740 100644
--- a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml
+++ b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml
@@ -29,8 +29,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml
index 7c39a50..9636f22 100644
--- a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml
+++ b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml
@@ -40,8 +40,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/pretrained_models/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/RCAN/test_RCAN.yml b/options/test/RCAN/test_RCAN.yml
index 7734d91..3f22dd2 100644
--- a/options/test/RCAN/test_RCAN.yml
+++ b/options/test/RCAN/test_RCAN.yml
@@ -49,5 +49,5 @@ save_img: true
# path
path:
- pretrain_model_g: ./experiments/pretrained_models/RCAN_BIX4-official.pth
- strict_load: true
+ pretrain_network_g: ./experiments/pretrained_models/RCAN/RCAN_BIX4-official.pth
+ strict_load_g: true
diff --git a/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml b/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml
index 517150c..5fef091 100644
--- a/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml
+++ b/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml
@@ -40,8 +40,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb/models/net_g_400000.pth
- strict_load: true
+ pretrain_network_g: experiments/004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb/models/net_g_400000.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml
index 29c09f8..d76411d 100644
--- a/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml
+++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml
@@ -40,8 +40,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth
- strict_load: true
+ pretrain_network_g: experiments/002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml
index 91b4e7f..0e8dc78 100644
--- a/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml
+++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml
@@ -40,8 +40,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth
- strict_load: true
+ pretrain_network_g: experiments/003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
index c5b0e32..ce5e1cf 100644
--- a/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
+++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml
@@ -40,8 +40,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
- strict_load: true
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml
index 8e499cf..cdc8ea7 100644
--- a/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml
+++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml
@@ -29,8 +29,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
- strict_load: true
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/test/TOF/test_TOF_official.yml b/options/test/TOF/test_TOF_official.yml
index ab916c7..f61dbaf 100644
--- a/options/test/TOF/test_TOF_official.yml
+++ b/options/test/TOF/test_TOF_official.yml
@@ -26,8 +26,8 @@ save_img: true
# path
path:
- pretrain_model_g: experiments/pretrained_models/tof_official-e81c455f.pth
- strict_load: true
+ pretrain_network_g: experiments/pretrained_models/TOF/tof_official-e81c455f.pth
+ strict_load_g: true
# validation settings
val:
diff --git a/options/train/EDSR/train_EDSR_Lx2.yml b/options/train/EDSR/train_EDSR_Lx2.yml
index da645b7..bb3167e 100644
--- a/options/train/EDSR/train_EDSR_Lx2.yml
+++ b/options/train/EDSR/train_EDSR_Lx2.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/EDSR/train_EDSR_Lx3.yml b/options/train/EDSR/train_EDSR_Lx3.yml
index 7b6ae45..326d95e 100644
--- a/options/train/EDSR/train_EDSR_Lx3.yml
+++ b/options/train/EDSR/train_EDSR_Lx3.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
- strict_load: false
+ pretrain_network_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/EDSR/train_EDSR_Lx4.yml b/options/train/EDSR/train_EDSR_Lx4.yml
index 6fe945c..ffd3a60 100644
--- a/options/train/EDSR/train_EDSR_Lx4.yml
+++ b/options/train/EDSR/train_EDSR_Lx4.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
- strict_load: false
+ pretrain_network_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/EDSR/train_EDSR_Mx2.yml b/options/train/EDSR/train_EDSR_Mx2.yml
index 37410f0..b8c81f9 100644
--- a/options/train/EDSR/train_EDSR_Mx2.yml
+++ b/options/train/EDSR/train_EDSR_Mx2.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/EDSR/train_EDSR_Mx3.yml b/options/train/EDSR/train_EDSR_Mx3.yml
index 7f473a0..bd44e87 100644
--- a/options/train/EDSR/train_EDSR_Mx3.yml
+++ b/options/train/EDSR/train_EDSR_Mx3.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
- strict_load: false
+ pretrain_network_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/EDSR/train_EDSR_Mx4.yml b/options/train/EDSR/train_EDSR_Mx4.yml
index aa12b57..0f5e583 100644
--- a/options/train/EDSR/train_EDSR_Mx4.yml
+++ b/options/train/EDSR/train_EDSR_Mx4.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
- strict_load: false
+ pretrain_network_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml
index 0623d4c..ec5a78e 100644
--- a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml
+++ b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml
@@ -71,8 +71,8 @@ network_d:
# path
path:
- pretrain_model_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth
- strict_load: true
+ pretrain_network_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth
+ strict_load_g: true
resume_state: ~
# training settings
@@ -107,9 +107,9 @@ train:
'conv5_4': 1 # before relu
vgg_type: vgg19
use_input_norm: true
+ range_norm: false
perceptual_weight: 1.0
style_weight: 0
- norm_img: false
criterion: l1
gan_opt:
type: GANLoss
diff --git a/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml b/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml
index d0bb472..bcc6418 100644
--- a/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml
+++ b/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml
@@ -63,8 +63,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/103_EDVR_L_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth
- strict_load: false
+ pretrain_network_g: experiments/103_EDVR_L_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml b/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml
index becefe2..32645a4 100644
--- a/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml
+++ b/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml
@@ -63,8 +63,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml b/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml
index c463310..d79c8d3 100644
--- a/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml
+++ b/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml
@@ -63,8 +63,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth
- strict_load: false
+ pretrain_network_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml b/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml
index bb8dba3..75552e9 100644
--- a/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml
+++ b/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml
@@ -63,8 +63,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/ESRGAN/train_ESRGAN_x4.yml b/options/train/ESRGAN/train_ESRGAN_x4.yml
index 23acff2..057de06 100644
--- a/options/train/ESRGAN/train_ESRGAN_x4.yml
+++ b/options/train/ESRGAN/train_ESRGAN_x4.yml
@@ -55,8 +55,8 @@ network_d:
# path
path:
- pretrain_model_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
- strict_load: true
+ pretrain_network_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ strict_load_g: true
resume_state: ~
# training settings
@@ -91,9 +91,9 @@ train:
'conv5_4': 1 # before relu
vgg_type: vgg19
use_input_norm: true
+ range_norm: false
perceptual_weight: 1.0
style_weight: 0
- norm_img: false
criterion: l1
gan_opt:
type: GANLoss
diff --git a/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml b/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml
index a4ede70..c5882c8 100644
--- a/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml
+++ b/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml
@@ -51,8 +51,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/RCAN/train_RCAN_x2.yml b/options/train/RCAN/train_RCAN_x2.yml
index c525c0d..531b142 100644
--- a/options/train/RCAN/train_RCAN_x2.yml
+++ b/options/train/RCAN/train_RCAN_x2.yml
@@ -57,8 +57,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml
index 978b28e..e3681f2 100644
--- a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml
+++ b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml
@@ -60,8 +60,8 @@ network_d:
# path
path:
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
- strict_load: true
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ strict_load_g: true
resume_state: ~
# training settings
@@ -96,9 +96,9 @@ train:
'conv5_4': 1 # before relu
vgg_type: vgg19
use_input_norm: true
+ range_norm: false
perceptual_weight: 1.0
style_weight: 0
- norm_img: false
criterion: l1
gan_opt:
type: GANLoss
diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml
index f7e6014..3688a1a 100644
--- a/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml
+++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
- strict_load: false
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml
index 9b94d29..5c414ad 100644
--- a/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml
+++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
- strict_load: false
+ pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ strict_load_g: false
resume_state: ~
# training settings
diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
index 647b334..1fa782f 100644
--- a/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
+++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml
@@ -54,8 +54,8 @@ network_g:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml b/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml
index 00b77ba..e112d44 100644
--- a/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml
+++ b/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml
@@ -42,8 +42,8 @@ network_d:
# path
path:
- pretrain_model_g: ~
- strict_load: true
+ pretrain_network_g: ~
+ strict_load_g: true
resume_state: ~
# training settings
diff --git a/requirements.txt b/requirements.txt
index 8202611..c014cdc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,15 @@
addict
future
lmdb
-matplotlib
-mmcv>=0.6
numpy
opencv-python
+Pillow
pyyaml
+requests
scikit-image
scipy
tb-nightly
torch>=1.3
torchvision
+tqdm
yapf
diff --git a/scripts/create_lmdb.py b/scripts/data_preparation/create_lmdb.py
similarity index 87%
rename from scripts/create_lmdb.py
rename to scripts/data_preparation/create_lmdb.py
index 4fa359b..e8eec3b 100644
--- a/scripts/create_lmdb.py
+++ b/scripts/data_preparation/create_lmdb.py
@@ -1,7 +1,8 @@
-import mmcv
+import argparse
from os import path as osp
-from basicsr.utils.lmdb import make_lmdb_from_imgs
+from basicsr.utils import scandir
+from basicsr.utils.lmdb_util import make_lmdb_from_imgs
def create_lmdb_for_div2k():
@@ -53,7 +54,7 @@ def prepare_keys_div2k(folder_path):
"""
print('Reading image path list ...')
img_path_list = sorted(
- list(mmcv.scandir(folder_path, suffix='png', recursive=False)))
+ list(scandir(folder_path, suffix='png', recursive=False)))
keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)]
return img_path_list, keys
@@ -96,7 +97,7 @@ def prepare_keys_reds(folder_path):
"""
print('Reading image path list ...')
img_path_list = sorted(
- list(mmcv.scandir(folder_path, suffix='png', recursive=True)))
+ list(scandir(folder_path, suffix='png', recursive=True)))
keys = [v.split('.png')[0] for v in img_path_list] # example: 000/00000000
return img_path_list, keys
@@ -160,6 +161,22 @@ def prepare_keys_vimeo90k(folder_path, train_list_path, mode):
if __name__ == '__main__':
- create_lmdb_for_div2k()
- # create_lmdb_for_reds()
- # create_lmdb_for_vimeo90k()
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ '--dataset',
+ type=str,
+ help=(
+ "Options: 'DIV2K', 'REDS', 'Vimeo90K' "
+ 'You may need to modify the corresponding configurations in codes.'
+ ))
+ args = parser.parse_args()
+ dataset = args.dataset.lower()
+ if dataset == 'div2k':
+ create_lmdb_for_div2k()
+ elif dataset == 'reds':
+ create_lmdb_for_reds()
+ elif dataset == 'vimeo90k':
+ create_lmdb_for_vimeo90k()
+ else:
+ raise ValueError('Wrong dataset.')
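
A usage sketch for the new dataset dispatcher; the input/output folder paths are still configured inside the corresponding create_lmdb_for_* functions, so edit them before running:

    python scripts/data_preparation/create_lmdb.py --dataset DIV2K
    python scripts/data_preparation/create_lmdb.py --dataset REDS
    python scripts/data_preparation/create_lmdb.py --dataset Vimeo90K
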
diff --git a/scripts/data_preparation/download_datasets.py b/scripts/data_preparation/download_datasets.py
new file mode 100644
index 0000000..215e3c8
--- /dev/null
+++ b/scripts/data_preparation/download_datasets.py
@@ -0,0 +1,71 @@
+import argparse
+import glob
+import os
+from os import path as osp
+
+from basicsr.utils.download_util import download_file_from_google_drive
+
+
+def download_dataset(dataset, file_ids):
+ save_path_root = './datasets/'
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+ user_response = input(
+ f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+ if user_response.lower() == 'y':
+ print(f'Overwriting {file_name} at {save_path}')
+ download_file_from_google_drive(file_id, save_path)
+ elif user_response.lower() == 'n':
+ print(f'Skipping {file_name}')
+ else:
+ raise ValueError('Wrong input. Only accepts Y/N.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ download_file_from_google_drive(file_id, save_path)
+
+ # unzip
+ if save_path.endswith('.zip'):
+ extracted_path = save_path.replace('.zip', '')
+ print(f'Extract {save_path} to {extracted_path}')
+ import zipfile
+ with zipfile.ZipFile(save_path, 'r') as zip_ref:
+ zip_ref.extractall(extracted_path)
+
+ file_name = file_name.replace('.zip', '')
+ subfolder = osp.join(extracted_path, file_name)
+ if osp.isdir(subfolder):
+ print(f'Move {subfolder} to {extracted_path}')
+ import shutil
+ for path in glob.glob(osp.join(subfolder, '*')):
+ shutil.move(path, extracted_path)
+ shutil.rmtree(subfolder)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'dataset',
+ type=str,
+ help=("Options: 'Set5', 'Set14'. "
+ "Set to 'all' if you want to download all the dataset."))
+ args = parser.parse_args()
+
+ file_ids = {
+ 'Set5': {
+ 'Set5.zip': # file name
+ '1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9', # file id
+ },
+ 'Set14': {
+ 'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E',
+ }
+ }
+
+ if args.dataset == 'all':
+ for dataset in file_ids.keys():
+ download_dataset(dataset, file_ids[dataset])
+ else:
+ download_dataset(args.dataset, file_ids[args.dataset])
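
A usage sketch for the dataset downloader; files are fetched from Google Drive into ./datasets/ and zip archives are extracted in place:

    python scripts/data_preparation/download_datasets.py Set5
    python scripts/data_preparation/download_datasets.py all   # download every listed dataset
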
diff --git a/scripts/data_preparation/extract_images_from_tfrecords.py b/scripts/data_preparation/extract_images_from_tfrecords.py
new file mode 100644
index 0000000..14a4f67
--- /dev/null
+++ b/scripts/data_preparation/extract_images_from_tfrecords.py
@@ -0,0 +1,235 @@
+import argparse
+import cv2
+import glob
+import numpy as np
+import os
+
+from basicsr.utils.lmdb_util import LmdbMaker
+
+
+def convert_celeba_tfrecords(tf_file,
+ log_resolution,
+ save_root,
+ save_type='img',
+ compress_level=1):
+ """Convert CelebA tfrecords to images or lmdb files.
+
+ Args:
+ tf_file (str): Input tfrecords file in glob pattern.
+ Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords' # noqa:E501
+ log_resolution (int): Log scale of resolution.
+ save_root (str): Path root to save.
+ save_type (str): Save type. Options: img | lmdb. Default: img.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+ if 'validation' in tf_file:
+ phase = 'validation'
+ else:
+ phase = 'train'
+ if save_type == 'lmdb':
+ save_path = os.path.join(save_root,
+ f'celeba_{2**log_resolution}_{phase}.lmdb')
+ lmdb_maker = LmdbMaker(save_path)
+ elif save_type == 'img':
+ save_path = os.path.join(save_root,
+ f'celeba_{2**log_resolution}_{phase}')
+ else:
+ raise ValueError('Wrong save type.')
+
+ os.makedirs(save_path, exist_ok=True)
+
+ idx = 0
+ for record in sorted(glob.glob(tf_file)):
+ print('Processing record: ', record)
+ record_iterator = tf.python_io.tf_record_iterator(record)
+ for string_record in record_iterator:
+ example = tf.train.Example()
+ example.ParseFromString(string_record)
+
+ # label = example.features.feature['label'].int64_list.value[0]
+ # attr = example.features.feature['attr'].int64_list.value
+ # male = attr[20]
+ # young = attr[39]
+
+ shape = example.features.feature['shape'].int64_list.value
+ h, w, c = shape
+ img_str = example.features.feature['data'].bytes_list.value[0]
+ img = np.fromstring(img_str, dtype=np.uint8).reshape((h, w, c))
+
+ img = img[:, :, [2, 1, 0]]
+
+ if save_type == 'img':
+ cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
+ elif save_type == 'lmdb':
+ _, img_byte = cv2.imencode(
+ '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ key = f'{idx:08d}/r{log_resolution:02d}'
+ lmdb_maker.put(img_byte, key, (h, w, c))
+
+ idx += 1
+ print(idx)
+
+ if save_type == 'lmdb':
+ lmdb_maker.close()
+
+
+def convert_ffhq_tfrecords(tf_file,
+ log_resolution,
+ save_root,
+ save_type='img',
+ compress_level=1):
+ """Convert FFHQ tfrecords to images or lmdb files.
+
+ Args:
+ tf_file (str): Input tfrecords file.
+ log_resolution (int): Log scale of resolution.
+ save_root (str): Path root to save.
+ save_type (str): Save type. Options: img | lmdb. Default: img.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ if save_type == 'lmdb':
+ save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}.lmdb')
+ lmdb_maker = LmdbMaker(save_path)
+ elif save_type == 'img':
+ save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}')
+ else:
+ raise ValueError('Wrong save type.')
+
+ os.makedirs(save_path, exist_ok=True)
+
+ idx = 0
+ for record in sorted(glob.glob(tf_file)):
+ print('Processing record: ', record)
+ record_iterator = tf.python_io.tf_record_iterator(record)
+ for string_record in record_iterator:
+ example = tf.train.Example()
+ example.ParseFromString(string_record)
+
+ shape = example.features.feature['shape'].int64_list.value
+ c, h, w = shape
+ img_str = example.features.feature['data'].bytes_list.value[0]
+ img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w))
+
+ img = img.transpose(1, 2, 0)
+ img = img[:, :, [2, 1, 0]]
+ if save_type == 'img':
+ cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
+ elif save_type == 'lmdb':
+ _, img_byte = cv2.imencode(
+ '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ key = f'{idx:08d}/r{log_resolution:02d}'
+ lmdb_maker.put(img_byte, key, (h, w, c))
+
+ idx += 1
+ print(idx)
+
+ if save_type == 'lmdb':
+ lmdb_maker.close()
+
+
+def make_ffhq_lmdb_from_imgs(folder_path,
+ log_resolution,
+ save_root,
+ save_type='lmdb',
+ compress_level=1):
+ """Make FFHQ lmdb from images.
+
+ Args:
+ folder_path (str): Folder path.
+ log_resolution (int): Log scale of resolution.
+ save_root (str): Path root to save.
+ save_type (str): Save type. Only 'lmdb' is supported here. Default: lmdb.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ if save_type == 'lmdb':
+ save_path = os.path.join(save_root,
+ f'ffhq_{2**log_resolution}_crop1.2.lmdb')
+ lmdb_maker = LmdbMaker(save_path)
+ else:
+ raise ValueError('Wrong save type.')
+
+ os.makedirs(save_path, exist_ok=True)
+
+ img_list = sorted(glob.glob(os.path.join(folder_path, '*')))
+ for idx, img_path in enumerate(img_list):
+ print(f'Processing {idx}: ', img_path)
+ img = cv2.imread(img_path)
+ h, w, c = img.shape
+
+ if save_type == 'lmdb':
+ _, img_byte = cv2.imencode(
+ '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ key = f'{idx:08d}/r{log_resolution:02d}'
+ lmdb_maker.put(img_byte, key, (h, w, c))
+
+ if save_type == 'lmdb':
+ lmdb_maker.close()
+
+
+if __name__ == '__main__':
+ """Read tfrecords w/o define a graph.
+
+ We have tested it on TensorFlow 1.15
+
+ Ref:
+ http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--dataset',
+ type=str,
+ default='ffhq',
+ help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.")
+ parser.add_argument(
+ '--tf_file',
+ type=str,
+ default='datasets/ffhq/ffhq-r10.tfrecords',
+ help=(
+ 'Input tfrecords file. For celeba, it should be a glob pattern. '
+ 'Put quotes around the wildcard argument to prevent the shell '
+ 'from expanding it. '
+ "Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'" # noqa:E501
+ ))
+ parser.add_argument(
+ '--log_resolution',
+ type=int,
+ default=10,
+ help='Log scale of resolution.')
+ parser.add_argument(
+ '--save_root',
+ type=str,
+ default='datasets/ffhq/',
+ help='Save root path.')
+ parser.add_argument(
+ '--save_type',
+ type=str,
+ default='img',
+ help="Save type. Options: 'img' | 'lmdb'. Default: 'img'.")
+ parser.add_argument(
+ '--compress_level',
+ type=int,
+ default=1,
+ help='Compress level when encoding images. Default: 1.')
+ args = parser.parse_args()
+
+ try:
+ import tensorflow as tf
+ except Exception:
+ raise ImportError('You need to install tensorflow to read tfrecords.')
+
+ if args.dataset == 'ffhq':
+ convert_ffhq_tfrecords(
+ args.tf_file,
+ args.log_resolution,
+ args.save_root,
+ save_type=args.save_type,
+ compress_level=args.compress_level)
+ else:
+ convert_celeba_tfrecords(
+ args.tf_file,
+ args.log_resolution,
+ args.save_root,
+ save_type=args.save_type,
+ compress_level=args.compress_level)
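
A usage sketch for the tfrecords extractor (TensorFlow 1.15 is assumed to be installed, as noted in the docstring); the values below simply mirror the FFHQ defaults above:

    python scripts/data_preparation/extract_images_from_tfrecords.py \
        --dataset ffhq --tf_file datasets/ffhq/ffhq-r10.tfrecords \
        --log_resolution 10 --save_root datasets/ffhq/ --save_type img
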
diff --git a/scripts/extract_subimages.py b/scripts/data_preparation/extract_subimages.py
similarity index 96%
rename from scripts/extract_subimages.py
rename to scripts/data_preparation/extract_subimages.py
index 6cf06b1..6424e8d 100644
--- a/scripts/extract_subimages.py
+++ b/scripts/data_preparation/extract_subimages.py
@@ -1,12 +1,12 @@
import cv2
-import mmcv
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
+from tqdm import tqdm
-from basicsr.utils.util import ProgressBar
+from basicsr.utils import scandir
def main():
@@ -94,16 +94,16 @@ def extract_subimages(opt):
print(f'Folder {save_folder} already exists. Exit.')
sys.exit(1)
- img_list = list(mmcv.scandir(input_folder))
- img_list = [osp.join(input_folder, v) for v in img_list]
+ img_list = list(scandir(input_folder, full_path=True))
- pbar = ProgressBar(len(img_list))
+ pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(
- worker, args=(path, opt), callback=lambda arg: pbar.update(arg))
+ worker, args=(path, opt), callback=lambda arg: pbar.update(1))
pool.close()
pool.join()
+ pbar.close()
print('All processes done.')
diff --git a/scripts/generate_meta_info.py b/scripts/data_preparation/generate_meta_info.py
similarity index 91%
rename from scripts/generate_meta_info.py
rename to scripts/data_preparation/generate_meta_info.py
index 22d851e..7bb1aed 100644
--- a/scripts/generate_meta_info.py
+++ b/scripts/data_preparation/generate_meta_info.py
@@ -1,7 +1,8 @@
-import mmcv
from os import path as osp
from PIL import Image
+from basicsr.utils import scandir
+
def generate_meta_info_div2k():
"""Generate meta info for DIV2K dataset.
@@ -10,7 +11,7 @@ def generate_meta_info_div2k():
gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'
- img_list = sorted(list(mmcv.scandir(gt_folder)))
+ img_list = sorted(list(scandir(gt_folder)))
with open(meta_info_txt, 'w') as f:
for idx, img_path in enumerate(img_list):
diff --git a/scripts/regroup_reds_dataset.py b/scripts/data_preparation/regroup_reds_dataset.py
similarity index 86%
rename from scripts/regroup_reds_dataset.py
rename to scripts/data_preparation/regroup_reds_dataset.py
index 3ce71fa..7d3ddbf 100644
--- a/scripts/regroup_reds_dataset.py
+++ b/scripts/data_preparation/regroup_reds_dataset.py
@@ -18,8 +18,9 @@ def regroup_reds_dataset(train_path, val_path):
# move the validation data to the train folder
val_folders = glob.glob(os.path.join(val_path, '*'))
for folder in val_folders:
- new_folder_idx = int(folder.split(' / ')[-1]) + 240
- os.system(f'cp -r {folder} {os.path.join(train_path, new_folder_idx)}')
+ new_folder_idx = int(folder.split('/')[-1]) + 240
+ os.system(
+ f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}')
if __name__ == '__main__':
diff --git a/scripts/download_gdrive.py b/scripts/download_gdrive.py
new file mode 100644
index 0000000..c3e34c7
--- /dev/null
+++ b/scripts/download_gdrive.py
@@ -0,0 +1,12 @@
+import argparse
+
+from basicsr.utils.download_util import download_file_from_google_drive
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--id', type=str, help='File id')
+ parser.add_argument('--output', type=str, help='Save path')
+ args = parser.parse_args()
+
+ download_file_from_google_drive(args.id, args.output)
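
A usage sketch for this helper; <file-id> and <save-path> are placeholders for the Google Drive file id and the local destination:

    python scripts/download_gdrive.py --id <file-id> --output <save-path>
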
diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py
index cc26218..3eb6911 100644
--- a/scripts/download_pretrained_models.py
+++ b/scripts/download_pretrained_models.py
@@ -1,19 +1,19 @@
import argparse
-import mmcv
+import os
from os import path as osp
-from basicsr.utils.download import download_file_from_google_drive
+from basicsr.utils.download_util import download_file_from_google_drive
def download_pretrained_models(method, file_ids):
save_path_root = f'./experiments/pretrained_models/{method}'
- mmcv.mkdir_or_exist(save_path_root)
+ os.makedirs(save_path_root, exist_ok=True)
for file_name, file_id in file_ids.items():
save_path = osp.abspath(osp.join(save_path_root, file_name))
if osp.exists(save_path):
user_response = input(
- f'{file_name} already exist. Do you want to cover it? Y/N')
+ f'{file_name} already exist. Do you want to cover it? Y/N\n')
if user_response.lower() == 'y':
print(f'Covering {file_name} to {save_path}')
download_file_from_google_drive(file_id, save_path)
@@ -112,9 +112,7 @@ def download_pretrained_models(method, file_ids):
'DFDNet_dict_512-f79685f0.pth':
'1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
'DFDNet_official-d1fa5650.pth':
- '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe',
- 'FFHQ_5_landmarks_template_1024-90a00515.npy':
- '1IQdQcq9QnpW6YzRwDaNbpV-rJ1Cq7RUq'
+ '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
},
'dlib': {
'mmod_human_face_detector-4cb19393.dat':
diff --git a/scripts/extract_images_from_tfrecords.py b/scripts/extract_images_from_tfrecords.py
deleted file mode 100644
index 3ee902a..0000000
--- a/scripts/extract_images_from_tfrecords.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Read tfrecords w/o define a graph.
-
-Ref:
-http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
-"""
-
-import cv2
-import glob
-import numpy as np
-import os
-
-from basicsr.utils.lmdb import LmdbMaker
-
-
-def celeba_tfrecords():
- # Configurations
- file_pattern = '/home/xtwang/datasets/CelebA_tfrecords/celeba-full-tfr/train/train-r08-s-*-of-*.tfrecords' # noqa:E501
- # r08: resolution 2^8 = 256
- resolution = 128
- save_path = f'/home/xtwang/datasets/CelebA_tfrecords/tmptrain_{resolution}'
-
- save_all_path = os.path.join(save_path, f'all_{resolution}')
- os.makedirs(save_all_path)
-
- idx = 0
- print(glob.glob(file_pattern))
- for record in glob.glob(file_pattern):
- record_iterator = tf.python_io.tf_record_iterator(record)
- for string_record in record_iterator:
- example = tf.train.Example()
- example.ParseFromString(string_record)
- # label = example.features.feature['label'].int64_list.value[0]
-
- # attr = example.features.feature['attr'].int64_list.value
- # male = attr[20]
- # young = attr[39]
-
- shape = example.features.feature['shape'].int64_list.value
- h, w, c = shape
- img_str = example.features.feature['data'].bytes_list.value[0]
- img = np.fromstring(img_str, dtype=np.uint8).reshape((h, w, c))
-
- # save image
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(os.path.join(save_all_path, f'{idx:08d}.png'), img)
-
- idx += 1
- print(idx)
-
-
-def ffhq_tfrecords():
- # Configurations
- file_pattern = '/home/xtwang/datasets/ffhq/ffhq-r10.tfrecords'
- resolution = 1024
- save_path = f'/home/xtwang/datasets/ffhq/ffhq_imgs/ffhq_{resolution}'
-
- os.makedirs(save_path, exist_ok=True)
- idx = 0
- print(glob.glob(file_pattern))
- for record in glob.glob(file_pattern):
- record_iterator = tf.python_io.tf_record_iterator(record)
- for string_record in record_iterator:
- example = tf.train.Example()
- example.ParseFromString(string_record)
-
- shape = example.features.feature['shape'].int64_list.value
- c, h, w = shape
- img_str = example.features.feature['data'].bytes_list.value[0]
- img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w))
-
- # save image
- img = img.transpose(1, 2, 0)
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img)
-
- idx += 1
- print(idx)
-
-
-def ffhq_tfrecords_to_lmdb():
- # Configurations
- file_pattern = '/home/xtwang/datasets/ffhq/ffhq-r10.tfrecords'
- log_resolution = 10
- compress_level = 1
- lmdb_path = f'/home/xtwang/datasets/ffhq/ffhq_{2**log_resolution}.lmdb'
-
- idx = 0
- print(glob.glob(file_pattern))
-
- lmdb_maker = LmdbMaker(lmdb_path)
- for record in glob.glob(file_pattern):
- record_iterator = tf.python_io.tf_record_iterator(record)
- for string_record in record_iterator:
- example = tf.train.Example()
- example.ParseFromString(string_record)
-
- shape = example.features.feature['shape'].int64_list.value
- c, h, w = shape
- img_str = example.features.feature['data'].bytes_list.value[0]
- img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w))
-
- # write image to lmdb
- img = img.transpose(1, 2, 0)
- img = img[:, :, [2, 1, 0]]
- _, img_byte = cv2.imencode(
- '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
- key = f'{idx:08d}/r{log_resolution:02d}'
- lmdb_maker.put(img_byte, key, (h, w, c))
-
- idx += 1
- print(key)
- lmdb_maker.close()
-
-
-if __name__ == '__main__':
- # we have test on TensorFlow 1.15
- try:
- import tensorflow as tf
- except Exception:
- raise ImportError('You need to install tensorflow to read tfrecords.')
- # celeba_tfrecords()
- # ffhq_tfrecords()
- ffhq_tfrecords_to_lmdb()
diff --git a/scripts/metrics/calculate_fid_folder.py b/scripts/metrics/calculate_fid_folder.py
new file mode 100644
index 0000000..b903160
--- /dev/null
+++ b/scripts/metrics/calculate_fid_folder.py
@@ -0,0 +1,83 @@
+import argparse
+import math
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from basicsr.data import create_dataset
+from basicsr.metrics.fid import (calculate_fid, extract_inception_features,
+ load_patched_inception_v3)
+
+
+def calculate_fid_folder():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('folder', type=str, help='Path to the folder.')
+ parser.add_argument(
+ '--fid_stats', type=str, help='Path to the dataset fid statistics.')
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--num_sample', type=int, default=50000)
+ parser.add_argument('--num_workers', type=int, default=4)
+ parser.add_argument(
+ '--backend',
+ type=str,
+ default='disk',
+ help='io backend for dataset. Option: disk, lmdb')
+ args = parser.parse_args()
+
+ # inception model
+ inception = load_patched_inception_v3(device)
+
+ # create dataset
+ opt = {}
+ opt['name'] = 'SingleImageDataset'
+ opt['type'] = 'SingleImageDataset'
+ opt['dataroot_lq'] = args.folder
+ opt['io_backend'] = dict(type=args.backend)
+ opt['mean'] = [0.5, 0.5, 0.5]
+ opt['std'] = [0.5, 0.5, 0.5]
+ dataset = create_dataset(opt)
+
+ # create dataloader
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.num_workers,
+ sampler=None,
+ drop_last=False)
+ args.num_sample = min(args.num_sample, len(dataset))
+ total_batch = math.ceil(args.num_sample / args.batch_size)
+
+ def data_generator(data_loader, total_batch):
+ for idx, data in enumerate(data_loader):
+ if idx >= total_batch:
+ break
+ else:
+ yield data['lq']
+
+ features = extract_inception_features(
+ data_generator(data_loader, total_batch), inception, total_batch,
+ device)
+ features = features.numpy()
+ total_len = features.shape[0]
+ features = features[:args.num_sample]
+ print(f'Extracted {total_len} features, '
+ f'use the first {features.shape[0]} features to calculate stats.')
+
+ sample_mean = np.mean(features, 0)
+ sample_cov = np.cov(features, rowvar=False)
+
+ # load the dataset stats
+ stats = torch.load(args.fid_stats)
+ real_mean = stats['mean']
+ real_cov = stats['cov']
+
+ # calculate FID metric
+ fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
+ print('fid:', fid)
+
+
+if __name__ == '__main__':
+ calculate_fid_folder()
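
A usage sketch; <image-folder> is a placeholder for the folder of generated or restored images, and the --fid_stats file can be the one produced by calculate_fid_stats_from_datasets.py below (e.g. inception_FFHQ_512.pth):

    python scripts/metrics/calculate_fid_folder.py <image-folder> \
        --fid_stats inception_FFHQ_512.pth --backend disk
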
diff --git a/scripts/metrics/calculate_fid_stats_from_datasets.py b/scripts/metrics/calculate_fid_stats_from_datasets.py
new file mode 100644
index 0000000..8b61f5c
--- /dev/null
+++ b/scripts/metrics/calculate_fid_stats_from_datasets.py
@@ -0,0 +1,72 @@
+import argparse
+import math
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from basicsr.data import create_dataset
+from basicsr.metrics.fid import (extract_inception_features,
+ load_patched_inception_v3)
+
+
+def calculate_stats_from_dataset():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_sample', type=int, default=50000)
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--size', type=int, default=512)
+ parser.add_argument('--dataroot', type=str, default='datasets/ffhq')
+ args = parser.parse_args()
+
+ # inception model
+ inception = load_patched_inception_v3(device)
+
+ # create dataset
+ opt = {}
+ opt['name'] = 'FFHQ'
+ opt['type'] = 'FFHQDataset'
+ opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb'
+ opt['io_backend'] = dict(type='lmdb')
+ opt['use_hflip'] = False
+ opt['mean'] = [0.5, 0.5, 0.5]
+ opt['std'] = [0.5, 0.5, 0.5]
+ dataset = create_dataset(opt)
+
+ # create dataloader
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=4,
+ sampler=None,
+ drop_last=False)
+ total_batch = math.ceil(args.num_sample / args.batch_size)
+
+ def data_generator(data_loader, total_batch):
+ for idx, data in enumerate(data_loader):
+ if idx >= total_batch:
+ break
+ else:
+ yield data['gt']
+
+ features = extract_inception_features(
+ data_generator(data_loader, total_batch), inception, total_batch,
+ device)
+ features = features.numpy()
+ total_len = features.shape[0]
+ features = features[:args.num_sample]
+ print(f'Extracted {total_len} features, '
+ f'use the first {features.shape[0]} features to calculate stats.')
+ mean = np.mean(features, 0)
+ cov = np.cov(features, rowvar=False)
+
+ save_path = f'inception_{opt["name"]}_{args.size}.pth'
+ torch.save(
+ dict(name=opt['name'], size=args.size, mean=mean, cov=cov),
+ save_path,
+ _use_new_zipfile_serialization=False)
+
+
+if __name__ == '__main__':
+ calculate_stats_from_dataset()
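
A usage sketch; this assumes the FFHQ LMDB (datasets/ffhq/ffhq_512.lmdb by default) has already been prepared, and it writes inception_FFHQ_512.pth to the working directory:

    python scripts/metrics/calculate_fid_stats_from_datasets.py --size 512 --num_sample 50000
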
diff --git a/scripts/metrics/calculate_lpips.py b/scripts/metrics/calculate_lpips.py
new file mode 100644
index 0000000..d9fbd3c
--- /dev/null
+++ b/scripts/metrics/calculate_lpips.py
@@ -0,0 +1,56 @@
+import cv2
+import glob
+import numpy as np
+import os.path as osp
+from torchvision.transforms.functional import normalize
+
+from basicsr.utils import img2tensor
+
+try:
+ import lpips
+except ImportError:
+ print('Please install lpips: pip install lpips')
+
+
+def main():
+ # Configurations
+ # -------------------------------------------------------------------------
+ folder_gt = 'datasets/celeba/celeba_512_validation'
+ folder_restored = 'datasets/celeba/celeba_512_validation_lq'
+ # crop_border = 4
+ suffix = ''
+ # -------------------------------------------------------------------------
+ loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1]
+ lpips_all = []
+ img_list = sorted(glob.glob(osp.join(folder_gt, '*')))
+
+ mean = [0.5, 0.5, 0.5]
+ std = [0.5, 0.5, 0.5]
+ for i, img_path in enumerate(img_list):
+ basename, ext = osp.splitext(osp.basename(img_path))
+ img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(
+ np.float32) / 255.
+ img_restored = cv2.imread(
+ osp.join(folder_restored, basename + suffix + ext),
+ cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+
+ img_gt, img_restored = img2tensor([img_gt, img_restored],
+ bgr2rgb=True,
+ float32=True)
+ # norm to [-1, 1]
+ normalize(img_gt, mean, std, inplace=True)
+ normalize(img_restored, mean, std, inplace=True)
+
+ # calculate lpips
+ lpips_val = loss_fn_vgg(
+ img_restored.unsqueeze(0).cuda(),
+ img_gt.unsqueeze(0).cuda()).item()
+
+ print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
+ lpips_all.append(lpips_val)
+
+ print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
+
+
+if __name__ == '__main__':
+ main()
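
A usage sketch; the GT and restored folders are hard-coded at the top of main(), so edit them first. The lpips package and a CUDA device are assumed:

    pip install lpips
    python scripts/metrics/calculate_lpips.py
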
diff --git a/scripts/calculate_psnr_ssim.py b/scripts/metrics/calculate_psnr_ssim.py
similarity index 57%
rename from scripts/calculate_psnr_ssim.py
rename to scripts/metrics/calculate_psnr_ssim.py
index 7e802d1..1a14af5 100644
--- a/scripts/calculate_psnr_ssim.py
+++ b/scripts/metrics/calculate_psnr_ssim.py
@@ -1,8 +1,10 @@
-import mmcv
+import cv2
import numpy as np
from os import path as osp
from basicsr.metrics import calculate_psnr, calculate_ssim
+from basicsr.utils import scandir
+from basicsr.utils.matlab_functions import bgr2ycbcr
def main():
@@ -23,11 +25,12 @@ def main():
crop_border = 4
suffix = '_expname'
test_y_channel = False
+ correct_mean_var = False
# -------------------------------------------------------------------------
psnr_all = []
ssim_all = []
- img_list = sorted(mmcv.scandir(folder_gt, recursive=True))
+ img_list = sorted(scandir(folder_gt, recursive=True, full_path=True))
if test_y_channel:
print('Testing Y channel.')
@@ -36,16 +39,35 @@ def main():
for i, img_path in enumerate(img_list):
basename, ext = osp.splitext(osp.basename(img_path))
- img_gt = mmcv.imread(
- osp.join(folder_gt, img_path), flag='unchanged').astype(
- np.float32) / 255.
- img_restored = mmcv.imread(
+ img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(
+ np.float32) / 255.
+ img_restored = cv2.imread(
osp.join(folder_restored, basename + suffix + ext),
- flag='unchanged').astype(np.float32) / 255.
+ cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+
+ if correct_mean_var:
+ mean_l = []
+ std_l = []
+ for j in range(3):
+ mean_l.append(np.mean(img_gt[:, :, j]))
+ std_l.append(np.std(img_gt[:, :, j]))
+ for j in range(3):
+ # correct twice
+ mean = np.mean(img_restored[:, :, j])
+ img_restored[:, :,
+ j] = img_restored[:, :, j] - mean + mean_l[j]
+ std = np.std(img_restored[:, :, j])
+ img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
+
+ mean = np.mean(img_restored[:, :, j])
+ img_restored[:, :,
+ j] = img_restored[:, :, j] - mean + mean_l[j]
+ std = np.std(img_restored[:, :, j])
+ img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
- img_gt = mmcv.bgr2ycbcr(img_gt, y_only=True)
- img_restored = mmcv.bgr2ycbcr(img_restored, y_only=True)
+ img_gt = bgr2ycbcr(img_gt, y_only=True)
+ img_restored = bgr2ycbcr(img_restored, y_only=True)
# calculate PSNR and SSIM
psnr = calculate_psnr(
@@ -62,6 +84,8 @@ def main():
f'\tSSIM: {ssim:.6f}')
psnr_all.append(psnr)
ssim_all.append(ssim)
+ print(folder_gt)
+ print(folder_restored)
print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, '
f'SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
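The new correct_mean_var switch matches the restored image's per-channel mean and standard deviation to those of the ground truth before the metrics are computed; the correction is applied twice because rescaling the standard deviation shifts the mean again. A vectorized sketch of the same idea (not the patch's exact loop):

    import numpy as np

    def match_mean_std(img, ref):
        # Match per-channel mean/std of img (H, W, 3) to ref, both float32 in [0, 1].
        out = img.copy()
        for _ in range(2):  # repeated because the std rescaling perturbs the mean
            out = out - out.mean(axis=(0, 1)) + ref.mean(axis=(0, 1))
            out = out / out.std(axis=(0, 1)) * ref.std(axis=(0, 1))
        return out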
diff --git a/scripts/metrics/calculate_stylegan2_fid.py b/scripts/metrics/calculate_stylegan2_fid.py
new file mode 100644
index 0000000..bd3acb1
--- /dev/null
+++ b/scripts/metrics/calculate_stylegan2_fid.py
@@ -0,0 +1,79 @@
+import argparse
+import math
+import numpy as np
+import torch
+from torch import nn
+
+from basicsr.metrics.fid import (calculate_fid, extract_inception_features,
+ load_patched_inception_v3)
+from basicsr.models.archs.stylegan2_arch import StyleGAN2Generator
+
+
+def calculate_stylegan2_fid():
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'ckpt', type=str, help='Path to the stylegan2 checkpoint.')
+ parser.add_argument(
+ 'fid_stats', type=str, help='Path to the dataset fid statistics.')
+ parser.add_argument('--size', type=int, default=256)
+ parser.add_argument('--channel_multiplier', type=int, default=2)
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--num_sample', type=int, default=50000)
+ parser.add_argument('--truncation', type=float, default=1)
+ parser.add_argument('--truncation_mean', type=int, default=4096)
+ args = parser.parse_args()
+
+ # create stylegan2 model
+ generator = StyleGAN2Generator(
+ out_size=args.size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=args.channel_multiplier,
+ resample_kernel=(1, 3, 3, 1))
+ generator.load_state_dict(torch.load(args.ckpt)['params_ema'])
+ generator = nn.DataParallel(generator).eval().to(device)
+
+ if args.truncation < 1:
+ with torch.no_grad():
+ truncation_latent = generator.mean_latent(args.truncation_mean)
+ else:
+ truncation_latent = None
+
+ # inception model
+ inception = load_patched_inception_v3(device)
+
+ total_batch = math.ceil(args.num_sample / args.batch_size)
+
+ def sample_generator(total_batch):
+ for i in range(total_batch):
+ with torch.no_grad():
+ latent = torch.randn(args.batch_size, 512, device=device)
+ samples, _ = generator([latent],
+ truncation=args.truncation,
+ truncation_latent=truncation_latent)
+ yield samples
+
+ features = extract_inception_features(
+ sample_generator(total_batch), inception, total_batch, device)
+ features = features.numpy()
+ total_len = features.shape[0]
+ features = features[:args.num_sample]
+ print(f'Extracted {total_len} features, '
+ f'use the first {features.shape[0]} features to calculate stats.')
+ sample_mean = np.mean(features, 0)
+ sample_cov = np.cov(features, rowvar=False)
+
+ # load the dataset stats
+ stats = torch.load(args.fid_stats)
+ real_mean = stats['mean']
+ real_cov = stats['cov']
+
+ # calculate FID metric
+ fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
+ print('fid:', fid)
+
+
+if __name__ == '__main__':
+ calculate_stylegan2_fid()
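calculate_fid compares the two Gaussians fitted above. For reference, the standard Fréchet distance between (mu1, cov1) and (mu2, cov2) can be sketched as below; basicsr.metrics.fid.calculate_fid may handle numerical edge cases differently:

    import numpy as np
    from scipy import linalg

    def frechet_distance(mu1, cov1, mu2, cov2, eps=1e-6):
        diff = mu1 - mu2
        # matrix square root of the product of the two covariances
        covmean, _ = linalg.sqrtm(cov1.dot(cov2), disp=False)
        if not np.isfinite(covmean).all():
            # nudge the diagonals for numerical stability
            offset = np.eye(cov1.shape[0]) * eps
            covmean, _ = linalg.sqrtm((cov1 + offset).dot(cov2 + offset), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return diff.dot(diff) + np.trace(cov1) + np.trace(cov2) - 2 * np.trace(covmean)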
diff --git a/scripts/convert_dfdnet.py b/scripts/model_conversion/convert_dfdnet.py
similarity index 100%
rename from scripts/convert_dfdnet.py
rename to scripts/model_conversion/convert_dfdnet.py
diff --git a/scripts/convert_models.py b/scripts/model_conversion/convert_models.py
similarity index 100%
rename from scripts/convert_models.py
rename to scripts/model_conversion/convert_models.py
diff --git a/scripts/convert_stylegan.py b/scripts/model_conversion/convert_stylegan.py
similarity index 100%
rename from scripts/convert_stylegan.py
rename to scripts/model_conversion/convert_stylegan.py
diff --git a/scripts/publish_models.py b/scripts/publish_models.py
index ea2b5f4..ea4ae79 100644
--- a/scripts/publish_models.py
+++ b/scripts/publish_models.py
@@ -53,6 +53,7 @@ def convert_to_backward_compatible_models(paths):
if __name__ == '__main__':
- paths = glob.glob('experiments/pretrained_models/*.pth')
+ paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob(
+ 'experiments/pretrained_models/**/*.pth')
convert_to_backward_compatible_models(paths)
update_sha(paths)
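Note that glob treats '**' like a plain '*' unless recursive=True is passed, so the added pattern only reaches checkpoints one subdirectory deep. If deeper nesting ever needs to be covered, a recursive variant (hypothetical, not part of this patch) would be:

    paths = glob.glob('experiments/pretrained_models/**/*.pth', recursive=True)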
diff --git a/setup.cfg b/setup.cfg
index dccb00b..ae5a6eb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -16,6 +16,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr
-known_third_party = PIL,cv2,lmdb,matplotlib,mmcv,numpy,requests,scipy,skimage,torch,torchvision,yaml
+known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
diff --git a/setup.py b/setup.py
index 0a339ff..621007f 100644
--- a/setup.py
+++ b/setup.py
@@ -4,6 +4,7 @@
import os
import subprocess
+import sys
import time
import torch
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
@@ -85,8 +86,9 @@ def get_version():
return locals()['__version__']
-def make_cuda_ext(name, module, sources, sources_cuda=[]):
-
+def make_cuda_ext(name, module, sources, sources_cuda=None):
+ if sources_cuda is None:
+ sources_cuda = []
define_macros = []
extra_compile_args = {'cxx': []}
@@ -118,6 +120,31 @@ def get_requirements(filename='requirements.txt'):
if __name__ == '__main__':
+ if '--no_cuda_ext' in sys.argv:
+ ext_modules = []
+ sys.argv.remove('--no_cuda_ext')
+ else:
+ ext_modules = [
+ make_cuda_ext(
+ name='deform_conv_ext',
+ module='basicsr.models.ops.dcn',
+ sources=['src/deform_conv_ext.cpp'],
+ sources_cuda=[
+ 'src/deform_conv_cuda.cpp',
+ 'src/deform_conv_cuda_kernel.cu'
+ ]),
+ make_cuda_ext(
+ name='fused_act_ext',
+ module='basicsr.models.ops.fused_act',
+ sources=['src/fused_bias_act.cpp'],
+ sources_cuda=['src/fused_bias_act_kernel.cu']),
+ make_cuda_ext(
+ name='upfirdn2d_ext',
+ module='basicsr.models.ops.upfirdn2d',
+ sources=['src/upfirdn2d.cpp'],
+ sources_cuda=['src/upfirdn2d_kernel.cu']),
+ ]
+
write_version_py()
setup(
name='basicsr',
@@ -142,25 +169,6 @@ def get_requirements(filename='requirements.txt'):
license='Apache License 2.0',
setup_requires=['cython', 'numpy'],
install_requires=get_requirements(),
- ext_modules=[
- make_cuda_ext(
- name='deform_conv_ext',
- module='basicsr.models.ops.dcn',
- sources=['src/deform_conv_ext.cpp'],
- sources_cuda=[
- 'src/deform_conv_cuda.cpp',
- 'src/deform_conv_cuda_kernel.cu'
- ]),
- make_cuda_ext(
- name='fused_act_ext',
- module='basicsr.models.ops.fused_act',
- sources=['src/fused_bias_act.cpp'],
- sources_cuda=['src/fused_bias_act_kernel.cu']),
- make_cuda_ext(
- name='upfirdn2d_ext',
- module='basicsr.models.ops.upfirdn2d',
- sources=['src/upfirdn2d.cpp'],
- sources_cuda=['src/upfirdn2d_kernel.cu']),
- ],
+ ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension},
zip_safe=False)
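With ext_modules now gated on the command line, the compiled CUDA ops can be skipped on machines without a CUDA toolchain by appending the new flag to the install command, e.g. python setup.py develop --no_cuda_ext (the develop sub-command is just an example; the flag is stripped from sys.argv before setup() runs).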
diff --git a/tests/test_face_dfdnet.py b/tests/test_face_dfdnet.py
deleted file mode 100644
index ab2d6db..0000000
--- a/tests/test_face_dfdnet.py
+++ /dev/null
@@ -1,353 +0,0 @@
-import argparse
-import cv2
-import glob
-import mmcv
-import numpy as np
-import os
-import torch
-import torchvision.transforms as transforms
-from skimage import io
-from skimage import transform as trans
-
-from basicsr.models.archs.dfdnet_arch import DFDNet
-from basicsr.utils import tensor2img
-
-try:
- import dlib
-except ImportError:
- print('Please install dlib before testing face restoration.'
- 'Reference: https://github.com/davisking/dlib')
-
-
-class FaceRestorationHelper(object):
- """Helper for the face restoration pipeline."""
-
- def __init__(self, upscale_factor, face_template_path, out_size=512):
- self.upscale_factor = upscale_factor
- self.out_size = (out_size, out_size)
-
- # standard 5 landmarks for FFHQ faces with 1024 x 1024
- self.face_template = np.load(face_template_path) / (1024 // out_size)
- # for estimation the 2D similarity transformation
- self.similarity_trans = trans.SimilarityTransform()
-
- self.all_landmarks_5 = []
- self.all_landmarks_68 = []
- self.affine_matrices = []
- self.inverse_affine_matrices = []
- self.cropped_faces = []
- self.restored_faces = []
-
- def init_dlib(self, detection_path, landmark5_path, landmark68_path):
- """Initialize the dlib detectors and predictors."""
- self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
- self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
- self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
-
- def free_dlib_gpu_memory(self):
- del self.face_detector
- del self.shape_predictor_5
- del self.shape_predictor_68
-
- def read_input_image(self, img_path):
- # self.input_img is Numpy array, (h, w, c) with RGB order
- self.input_img = dlib.load_rgb_image(img_path)
-
- def detect_faces(self, img_path, upsample_num_times=1):
- """
- Args:
- img_path (str): Image path.
- upsample_num_times (int): Upsamples the image before running the
- face detector
-
- Returns:
- int: Number of detected faces.
- """
- self.read_input_image(img_path)
- self.det_faces = self.face_detector(self.input_img, upsample_num_times)
- if len(self.det_faces) == 0:
- print('No face detected. Try to increase upsample_num_times.')
- return len(self.det_faces)
-
- def get_face_landmarks_5(self):
- for face in self.det_faces:
- shape = self.shape_predictor_5(self.input_img, face.rect)
- landmark = np.array([[part.x, part.y] for part in shape.parts()])
- self.all_landmarks_5.append(landmark)
- return len(self.all_landmarks_5)
-
- def get_face_landmarks_68(self):
- """Get 68 densemarks for cropped images.
-
- Should only have one face at most in the cropped image.
- """
- num_detected_face = 0
- for idx, face in enumerate(self.cropped_faces):
- # face detection
- det_face = self.face_detector(face, 1) # TODO: can we remove it
- if len(det_face) == 0:
- print(f'Cannot find faces in cropped image with index {idx}.')
- self.all_landmarks_68.append(None)
- elif len(det_face) == 1:
- shape = self.shape_predictor_68(face, det_face[0].rect)
- landmark = np.array([[part.x, part.y]
- for part in shape.parts()])
- self.all_landmarks_68.append(landmark)
- num_detected_face += 1
- else:
- print('Should only have one face at most.')
- return num_detected_face
-
- def warp_crop_faces(self, save_cropped_path=None):
- """Get affine matrix, warp and cropped faces.
-
- Also get inverse affine matrix for post-processing.
- """
- for idx, landmark in enumerate(self.all_landmarks_5):
- # use 5 landmarks to get affine matrix
- self.similarity_trans.estimate(landmark, self.face_template)
- affine_matrix = self.similarity_trans.params[0:2, :]
- self.affine_matrices.append(affine_matrix)
- # warp and crop faces
- cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
- self.out_size)
- self.cropped_faces.append(cropped_face)
- # save the cropped face
- if save_cropped_path is not None:
- path, ext = os.path.splitext(save_cropped_path)
- save_path = f'{path}_{idx:02d}{ext}'
- mmcv.imwrite(mmcv.rgb2bgr(cropped_face), save_path)
-
- # get inverse affine matrix
- self.similarity_trans.estimate(self.face_template,
- landmark * self.upscale_factor)
- inverse_affine = self.similarity_trans.params[0:2, :]
- self.inverse_affine_matrices.append(inverse_affine)
-
- def add_restored_face(self, face):
- self.restored_faces.append(face)
-
- def paste_faces_to_input_image(self, save_path):
- # operate in the BGR order
- input_img = mmcv.rgb2bgr(self.input_img)
- h, w, _ = input_img.shape
- h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
- # simply resize the background
- upsample_img = cv2.resize(input_img, (w_up, h_up))
- for restored_face, inverse_affine in zip(self.restored_faces,
- self.inverse_affine_matrices):
- inv_restored = cv2.warpAffine(restored_face, inverse_affine,
- (w_up, h_up))
- mask = np.ones((*self.out_size, 3), dtype=np.float32)
- inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
- # remove the black borders
- inv_mask_erosion = cv2.erode(
- inv_mask,
- np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
- np.uint8))
- inv_restored_remove_border = inv_mask_erosion * inv_restored
- total_face_area = np.sum(inv_mask_erosion) // 3
- # compute the fusion edge based on the area of face
- w_edge = int(total_face_area**0.5) // 20
- erosion_radius = w_edge * 2
- inv_mask_center = cv2.erode(
- inv_mask_erosion,
- np.ones((erosion_radius, erosion_radius), np.uint8))
- blur_size = w_edge * 2
- inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
- (blur_size + 1, blur_size + 1), 0)
- upsample_img = inv_soft_mask * inv_restored_remove_border + (
- 1 - inv_soft_mask) * upsample_img
- mmcv.imwrite(upsample_img.astype(np.uint8), save_path)
-
- def clean_all(self):
- self.all_landmarks_5 = []
- self.all_landmarks_68 = []
- self.restored_faces = []
- self.affine_matrices = []
- self.cropped_faces = []
- self.inverse_affine_matrices = []
-
-
-def get_part_location(landmarks):
- """Get part locations from landmarks."""
- map_left_eye = list(np.hstack((range(17, 22), range(36, 42))))
- map_right_eye = list(np.hstack((range(22, 27), range(42, 48))))
- map_nose = list(range(29, 36))
- map_mouth = list(range(48, 68))
-
- # left eye
- mean_left_eye = np.mean(landmarks[map_left_eye], 0) # (x, y)
- half_len_left_eye = np.max((np.max(
- np.max(landmarks[map_left_eye], 0) -
- np.min(landmarks[map_left_eye], 0)) / 2, 16)) # A number
- loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1,
- mean_left_eye + half_len_left_eye)).astype(int)
- loc_left_eye = torch.from_numpy(loc_left_eye).unsqueeze(0)
- # (1, 4), the four numbers forms two coordinates in the diagonal
-
- # right eye
- mean_right_eye = np.mean(landmarks[map_right_eye], 0)
- half_len_right_eye = np.max((np.max(
- np.max(landmarks[map_right_eye], 0) -
- np.min(landmarks[map_right_eye], 0)) / 2, 16))
- loc_right_eye = np.hstack(
- (mean_right_eye - half_len_right_eye + 1,
- mean_right_eye + half_len_right_eye)).astype(int)
- loc_right_eye = torch.from_numpy(loc_right_eye).unsqueeze(0)
- # nose
- mean_nose = np.mean(landmarks[map_nose], 0)
- half_len_nose = np.max((np.max(
- np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2,
- 16)) # noqa: E126
- loc_nose = np.hstack(
- (mean_nose - half_len_nose + 1, mean_nose + half_len_nose)).astype(int)
- loc_nose = torch.from_numpy(loc_nose).unsqueeze(0)
- # mouth
- mean_mouth = np.mean(landmarks[map_mouth], 0)
- half_len_mouth = np.max((np.max(
- np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2,
- 16)) # noqa: E126
- loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1,
- mean_mouth + half_len_mouth)).astype(int)
- loc_mouth = torch.from_numpy(loc_mouth).unsqueeze(0)
-
- return loc_left_eye, loc_right_eye, loc_nose, loc_mouth
-
-
-if __name__ == '__main__':
- """We try to align to the official codes. But there are still slight
- differences: 1) we use dlib for 68 landmark detection; 2) the used image
- package are different (especially for reading and writing.)
- """
- device = 'cuda'
- parser = argparse.ArgumentParser()
-
- parser.add_argument('--upscale_factor', type=int, default=2)
- parser.add_argument(
- '--model_path',
- type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
- parser.add_argument(
- '--dict_path',
- type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
- parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
- parser.add_argument('--upsample_num_times', type=int, default=1)
- # The official codes use skimage.io to read the cropped images from disk
- # instead of directly using the intermediate results in the memory (as we
- # do). Such a different operation brings slight differences due to
- # skimage.io. For aligning with the official results, we could set the
- # official_adaption to True.
- parser.add_argument('--official_adaption', type=bool, default=True)
-
- # The following are the paths for face template and dlib models
- parser.add_argument(
- '--face_template_path',
- type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/DFDNet/FFHQ_5_landmarks_template_1024-90a00515.npy' # noqa: E501
- )
- parser.add_argument(
- '--detection_path',
- type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat' # noqa: E501
- )
- parser.add_argument(
- '--landmark5_path',
- type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat' # noqa: E501
- )
- parser.add_argument(
- '--landmark68_path',
- type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat' # noqa: E501
- )
-
- args = parser.parse_args()
- result_root = f'results/DFDNet/{args.test_path.split("/")[-1]}'
-
- # set up the DFDNet
- net = DFDNet(64, dict_path=args.dict_path).to(device)
- checkpoint = torch.load(
- args.model_path, map_location=lambda storage, loc: storage)
- net.load_state_dict(checkpoint['params'])
- net.eval()
-
- save_crop_root = os.path.join(result_root, 'cropped_faces')
- save_restore_root = os.path.join(result_root, 'restored_faces')
- save_final_root = os.path.join(result_root, 'final_results')
-
- face_helper = FaceRestorationHelper(
- args.upscale_factor, args.face_template_path, out_size=512)
-
- # scan all the jpg and png images
- for img_path in glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')):
- img_name = os.path.basename(img_path)
- print(f'Processing {img_name} image ...')
- save_crop_path = os.path.join(save_crop_root, img_name)
-
- face_helper.init_dlib(args.detection_path, args.landmark5_path,
- args.landmark68_path)
- # detect faces
- num_det_faces = face_helper.detect_faces(
- img_path, upsample_num_times=args.upsample_num_times)
- # get 5 face landmarks for each face
- num_landmarks = face_helper.get_face_landmarks_5()
- print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
- # warp and crop each face
- face_helper.warp_crop_faces(save_crop_path)
-
- if args.official_adaption:
- path, ext = os.path.splitext(save_crop_path)
- pathes = sorted(glob.glob(f'{path}_[0-9]*{ext}'))
- cropped_faces = [io.imread(path) for path in pathes]
- else:
- cropped_faces = face_helper.cropped_faces
-
- # get 68 landmarks for each cropped face
- num_landmarks = face_helper.get_face_landmarks_68()
- print(f'\tDetect {num_landmarks} faces for 68 landmarks.')
-
- face_helper.free_dlib_gpu_memory()
-
- print('\tFace restoration ...')
- # face restoration for each cropped face
- for idx, (cropped_face, landmarks) in enumerate(
- zip(cropped_faces, face_helper.all_landmarks_68)):
- if landmarks is None:
- print(f'Landmarks is None, skip cropped faces with idx {idx}.')
- else:
- # prepare data
- part_locations = get_part_location(landmarks)
- cropped_face = transforms.ToTensor()(cropped_face)
- cropped_face = transforms.Normalize((0.5, 0.5, 0.5),
- (0.5, 0.5, 0.5))(
- cropped_face)
- cropped_face = cropped_face.unsqueeze(0).to(device)
-
- with torch.no_grad():
- output = net(cropped_face, part_locations)
- im = tensor2img(output, min_max=(-1, 1))
- del output
- torch.cuda.empty_cache()
- path, ext = os.path.splitext(
- os.path.join(save_restore_root, img_name))
- save_path = f'{path}_{idx:02d}{ext}'
- mmcv.imwrite(im, save_path)
- face_helper.add_restored_face(im)
-
- print('\tGenerate the final result ...')
- # paste each restored face to the input image
- face_helper.paste_faces_to_input_image(
- os.path.join(save_final_root, img_name))
-
- # clean all the intermediate results to process the next image
- face_helper.clean_all()
-
- print(f'\nAll results are saved in {result_root}')
diff --git a/tests/test_ffhq_dataset.py b/tests/test_ffhq_dataset.py
index 5486385..655e402 100644
--- a/tests/test_ffhq_dataset.py
+++ b/tests/test_ffhq_dataset.py
@@ -1,5 +1,5 @@
import math
-import mmcv
+import os
import torch
import torchvision.utils
@@ -29,7 +29,7 @@ def main():
opt['dataset_enlarge_ratio'] = 1
- mmcv.mkdir_or_exist('tmp')
+ os.makedirs('tmp', exist_ok=True)
dataset = create_dataset(opt)
data_loader = create_dataloader(
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
index 9562ffd..b9642d1 100644
--- a/tests/test_lr_scheduler.py
+++ b/tests/test_lr_scheduler.py
@@ -1,10 +1,14 @@
-import matplotlib as mpl
import torch
-from matplotlib import pyplot as plt
-from matplotlib import ticker as mtick
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR
+try:
+ import matplotlib as mpl
+ from matplotlib import pyplot as plt
+ from matplotlib import ticker as mtick
+except ImportError:
+ print('Please install matplotlib.')
+
mpl.use('Agg')
diff --git a/tests/test_paired_image_dataset.py b/tests/test_paired_image_dataset.py
index 3c415a3..a133a36 100644
--- a/tests/test_paired_image_dataset.py
+++ b/tests/test_paired_image_dataset.py
@@ -1,5 +1,5 @@
import math
-import mmcv
+import os
import torchvision.utils
from basicsr.data import create_dataloader, create_dataset
@@ -44,7 +44,7 @@ def main(mode='folder'):
opt['dataset_enlarge_ratio'] = 1
- mmcv.mkdir_or_exist('tmp')
+ os.makedirs('tmp', exist_ok=True)
dataset = create_dataset(opt)
data_loader = create_dataloader(
diff --git a/tests/test_reds_dataset.py b/tests/test_reds_dataset.py
index 7863fe0..cbf23a6 100644
--- a/tests/test_reds_dataset.py
+++ b/tests/test_reds_dataset.py
@@ -1,5 +1,5 @@
import math
-import mmcv
+import os
import torchvision.utils
from basicsr.data import create_dataloader, create_dataset
@@ -45,7 +45,7 @@ def main(mode='folder'):
opt['dataset_enlarge_ratio'] = 1
- mmcv.mkdir_or_exist('tmp')
+ os.makedirs('tmp', exist_ok=True)
dataset = create_dataset(opt)
data_loader = create_dataloader(
diff --git a/tests/test_vimeo90k_dataset.py b/tests/test_vimeo90k_dataset.py
index 8a9661a..80bb45a 100644
--- a/tests/test_vimeo90k_dataset.py
+++ b/tests/test_vimeo90k_dataset.py
@@ -1,5 +1,5 @@
import math
-import mmcv
+import os
import torchvision.utils
from basicsr.data import create_dataloader, create_dataset
@@ -41,7 +41,7 @@ def main(mode='folder'):
opt['dataset_enlarge_ratio'] = 1
- mmcv.mkdir_or_exist('tmp')
+ os.makedirs('tmp', exist_ok=True)
dataset = create_dataset(opt)
data_loader = create_dataloader(