Commit: Merge branch 'master' of github.com:xinntao/BasicSR
Showing 33 changed files with 1,269 additions and 102 deletions.
README_CN.md
@@ -12,6 +12,32 @@
BasicSR is an **open-source** image and video super-resolution toolbox based on PyTorch (more restoration tasks will be supported in the future).<br>
<sub>([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN))</sub>

## :sparkles: New Features

- Sep 8, 2020. Added **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. Note that they differ slightly from the official codes.
    > Blind Face Restoration via Deep Multi-scale Component Dictionaries <br>
    > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang <br>
    > European Conference on Computer Vision (ECCV), 2020
- Aug 27, 2020. Added **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).
    > Analyzing and Improving the Image Quality of StyleGAN <br>
    > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila <br>
    > Computer Vision and Pattern Recognition (CVPR), 2020

<details>
  <summary>More</summary>
  <ul>
    <li>Aug 19, 2020. The brand-new BasicSR v1.0.0 is online.</li>
  </ul>
</details>

## :zap: HOWTOs

We provide simple pipelines for training/testing/inference to get started quickly. These commands do not cover all use cases; see the sections below for more details, and see the illustrative command after this list.

- :zap: [How to train StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2)
- :zap: [How to test StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2)
- :zap: [How to test DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet)
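As a rough illustration (the option file path below is hypothetical; the exact commands live in the HOWTO docs), runs are launched with a YAML option file:

```
python basicsr/train.py -opt options/train/StyleGAN/train_StyleGAN2_256_FFHQ.yml
```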
## Dependencies and Installation

- Python >= 3.7 (recommend using [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
@@ -28,13 +54,6 @@ python setup.py develop
Note: BasicSR is tested only on Ubuntu and may not work on Windows. You can try [Windows WSL with CUDA support](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (currently only the insider preview builds on the Fast ring can install it).

## HOWTOs

We provide simple pipelines for training/testing/inference to get started quickly. These commands do not cover all use cases; see the sections below for more details.

- [How to train StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2)
- [How to test StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2)

## TODO List

See the [project boards](https://github.com/xinntao/BasicSR/projects).
@@ -52,6 +71,9 @@ python setup.py develop
## Model Zoo and Baselines

**[Download official pre-trained models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing)** <br>
**[Download reproduced models and logs](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)**

- For descriptions of the currently supported models, see [Models_CN.md](docs/Models_CN.md).
- For **pre-trained models and log examples**, see **[ModelZoo_CN.md](docs/ModelZoo_CN.md)**.
- We also provide **training curves** and more on [wandb](https://app.wandb.ai/xintao/basicsr):
@@ -77,5 +99,3 @@ python setup.py develop
#### Contact

If you have any questions, please email `[email protected]`.

<sub><sup>[BasicSR-private](https://github.com/xinntao/BasicSR-private)</sup></sub>
VERSION
@@ -1 +1 @@
1.0.1
1.1.0
basicsr/models/archs/dfdnet_arch.py
@@ -0,0 +1,206 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm

from basicsr.models.archs.dfdnet_util import (AttentionBlock, Blur,
                                              MSDilationBlock, UpResBlock,
                                              adaptive_instance_normalization)
from basicsr.models.archs.vgg_arch import VGGFeatureExtractor

class SFTUpBlock(nn.Module):
    """Spatial feature transform (SFT) with upsampling block."""

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super(SFTUpBlock, self).__init__()
        self.conv1 = nn.Sequential(
            Blur(in_channel),
            SpectralNorm(
                nn.Conv2d(
                    in_channel, out_channel, kernel_size, padding=padding)),
            # The official code stacks two LeakyReLU(0.2) here;
            # 0.2 * 0.2 = 0.04 gives an equivalent negative slope.
            nn.LeakyReLU(0.04, True),
        )
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            SpectralNorm(
                nn.Conv2d(
                    out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2, True),
        )

        # convolutions predicting the SFT scale and shift maps
        self.scale_block = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
        # The official code ends the shift block with a sigmoid;
        # the reason is not documented.
        self.shift_block = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
            nn.LeakyReLU(0.2, True),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            nn.Sigmoid())

    def forward(self, x, updated_feat):
        out = self.conv1(x)
        # SFT: spatially-varying affine modulation by the guidance features
        scale = self.scale_block(updated_feat)
        shift = self.shift_block(updated_feat)
        out = out * scale + shift
        # upsample
        out = self.convup(out)
        return out

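# Illustrative usage (not part of the original file): SFTUpBlock modulates
# conv features with a spatial scale/shift predicted from guidance features,
# then upsamples 2x. Shapes below are assumptions for a quick shape check.
def _demo_sft_up_block():
    block = SFTUpBlock(in_channel=512, out_channel=256)
    x = torch.randn(1, 512, 32, 32)         # features to be modulated
    guidance = torch.randn(1, 512, 32, 32)  # dictionary-fused guidance
    out = block(x, guidance)
    assert out.shape == (1, 256, 64, 64)    # 256 channels, 2x upsampled
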
class VGGFaceFeatureExtractor(VGGFeatureExtractor):

    def preprocess(self, x):
        # normalize from [-1, 1] to [0, 1]
        x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        if x.shape[3] < 224:
            x = torch.nn.functional.interpolate(
                x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        x = self.preprocess(x)
        features = []
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                features.append(x)
        return features

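# Illustrative usage (not part of the original file): with the constructor
# arguments used by DFDNet below, the extractor returns one feature map per
# listed VGG-19 layer, e.g. spatial sizes 256/128/64/32 for a 512x512 input.
def _demo_vgg_face_features():
    extractor = VGGFaceFeatureExtractor(
        layer_name_list=['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'],
        vgg_type='vgg19',
        use_input_norm=True,
        requires_grad=False)
    feats = extractor(torch.randn(1, 3, 512, 512))  # input in [-1, 1]
    print([f.shape for f in feats])
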
class DFDNet(nn.Module):
    """DFDNet: Deep Face Dictionary Network.

    It only processes faces of size 512x512.
    """

    def __init__(self, num_feat, dict_path):
        super().__init__()
        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
        # part_sizes: [80, 80, 50, 110]
        channel_sizes = [128, 256, 512, 512]
        self.feature_sizes = np.array([256, 128, 64, 32])
        self.flag_dict_device = False

        # component dictionary: maps each feature size (as a string key) to a
        # dict of per-part feature tensors, i.e. dict[f'{f_size}'][part_name]
        self.dict = torch.load(dict_path)

        # VGG face feature extractor
        self.vgg_extractor = VGGFaceFeatureExtractor(
            layer_name_list=['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'],
            vgg_type='vgg19',
            use_input_norm=True,
            requires_grad=False)

        # attention blocks for fusing dictionary features and input features
        self.attn_blocks = nn.ModuleDict()
        for idx, feat_size in enumerate(self.feature_sizes):
            for name in self.parts:
                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(
                    channel_sizes[idx])

        # multi-scale dilation block
        self.multi_scale_dilation = MSDilationBlock(
            num_feat * 8, dilation=[4, 3, 2, 1])

        # upsampling and reconstruction
        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
        self.upsample4 = nn.Sequential(
            SpectralNorm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)),
            nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
            UpResBlock(num_feat),
            nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh())

    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name,
                  f_size):
        """Swap the features from the dictionary."""
        # crop the original VGG features for this part
        part_feat = vgg_feat[:, :, location[1]:location[3],
                             location[0]:location[2]].clone()
        # resize the cropped features to the dictionary atom size
        part_resize_feat = F.interpolate(
            part_feat,
            dict_feat.size()[2:4],
            mode='bilinear',
            align_corners=False)
        # use adaptive instance normalization to adjust color and illumination
        dict_feat = adaptive_instance_normalization(dict_feat,
                                                    part_resize_feat)
        # similarity scores: with the dictionary atoms as conv kernels, this
        # conv2d computes one inner product per atom (a (1, N, 1, 1) output)
        similarity_score = F.conv2d(part_resize_feat, dict_feat)
        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
        # select the most similar feature in the dict (after normalization)
        select_idx = torch.argmax(similarity_score)
        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1],
                                  part_feat.size()[2:4])
        # attention on the residual between swapped and original features
        attn = self.attn_blocks[f'{part_name}_{f_size}'](
            swap_feat - part_feat)
        attn_feat = attn * swap_feat
        # update features
        updated_feat[:, :, location[1]:location[3],
                     location[0]:location[2]] = attn_feat + part_feat
        return updated_feat

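    # Note (illustrative, not part of the original file): the imported
    # adaptive_instance_normalization follows the standard AdaIN recipe,
    #     AdaIN(content, style)
    #         = std(style) * (content - mean(content)) / std(content)
    #           + mean(style),
    # with mean/std taken per channel, so swap_feat re-normalizes the
    # dictionary atoms to the color/illumination statistics of the input
    # part features before matching.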
    def put_dict_to_device(self, x):
        # lazily move the dictionary tensors to the input device/dtype once
        if self.flag_dict_device is False:
            for k, v in self.dict.items():
                for kk, vv in v.items():
                    self.dict[k][kk] = vv.to(x)
            self.flag_dict_device = True

    def forward(self, x, part_locations):
        """
        Only supports testing with batch size = 1.

        Args:
            x (Tensor): Input faces with shape (b, c, 512, 512).
            part_locations (list[Tensor]): Part locations.
        """
        self.put_dict_to_device(x)
        # extract VGG face features
        vgg_features = self.vgg_extractor(x)
        # update the VGG face features using the dictionary for each part
        updated_vgg_features = []
        batch = 0  # only supports testing with batch size = 1
        for i, f_size in enumerate(self.feature_sizes):
            dict_features = self.dict[f'{f_size}']
            vgg_feat = vgg_features[i]
            updated_feat = vgg_feat.clone()

            # swap features from the dictionary
            for part_idx, part_name in enumerate(self.parts):
                location = (part_locations[part_idx][batch] //
                            (512 / f_size)).int()
                updated_feat = self.swap_feat(vgg_feat, updated_feat,
                                              dict_features[part_name],
                                              location, part_name, f_size)

            updated_vgg_features.append(updated_feat)

        vgg_feat_dilation = self.multi_scale_dilation(vgg_features[3])
        # use the updated VGG features to modulate the upsampled features in
        # an SFT (Spatial Feature Transform) scale-and-shift manner
        upsampled_feat = self.upsample0(vgg_feat_dilation,
                                        updated_vgg_features[3])
        upsampled_feat = self.upsample1(upsampled_feat,
                                        updated_vgg_features[2])
        upsampled_feat = self.upsample2(upsampled_feat,
                                        updated_vgg_features[1])
        upsampled_feat = self.upsample3(upsampled_feat,
                                        updated_vgg_features[0])
        out = self.upsample4(upsampled_feat)

        return out
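For orientation, here is a minimal inference sketch. The dictionary path, module path, and part boxes are assumptions for illustration (the real pipeline also detects and aligns the face and its landmarks; see the HOWTO docs); `num_feat=64` is consistent with the channel sizes above.

```python
import torch

from basicsr.models.archs.dfdnet_arch import DFDNet  # assumed module path

# hypothetical dictionary path; the real file ships with the pretrained models
net = DFDNet(num_feat=64,
             dict_path='experiments/pretrained_models/DFDNet_dict_512.pth')
net.eval()

face = torch.randn(1, 3, 512, 512)  # an aligned face in [-1, 1]
# one (b, 4) tensor of (x1, y1, x2, y2) boxes per part, in 512x512 coordinates,
# ordered as left_eye, right_eye, nose, mouth (sizes follow part_sizes above)
part_locations = [
    torch.tensor([[128, 160, 208, 240]]),  # left eye, 80x80
    torch.tensor([[304, 160, 384, 240]]),  # right eye, 80x80
    torch.tensor([[231, 220, 281, 270]]),  # nose, 50x50
    torch.tensor([[201, 330, 311, 440]]),  # mouth, 110x110
]
with torch.no_grad():
    restored = net(face, part_locations)   # (1, 3, 512, 512) in [-1, 1]
```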