diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index dc4f3cf..0b61a71 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -1,6 +1,6 @@ name: Python Lint -on: [push] +on: [push, pull_request] jobs: build: @@ -25,5 +25,5 @@ jobs: - name: Lint run: | flake8 . - isort --check-only --diff basicsr/ options/ scripts/ tests/ setup.py - yapf -r -d basicsr/ options/ scripts/ tests/ setup.py + isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py + yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py diff --git a/.gitignore b/.gitignore index 5d63054..4e4055d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ version.py # ignored files with suffix *.html *.png +*.jpeg *.jpg *.gif *.pth diff --git a/LICENSE/README.md b/LICENSE/README.md index 159c492..a5d8b6f 100644 --- a/LICENSE/README.md +++ b/LICENSE/README.md @@ -13,3 +13,5 @@ This BasicSR project is released under the Apache 2.0 license. 1. NIQE metric: the codes are translated from the [official MATLAB codes](http://live.ece.utexas.edu/research/quality/niqe_release.zip) > A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", IEEE Signal Processing Letters, 2012. + +1. FID metric: the codes are modified from [pytorch-fid](https://github.com/mseitzer/pytorch-fid) and [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). diff --git a/README.md b/README.md index 9678432..e8fb360 100644 --- a/README.md +++ b/README.md @@ -7,32 +7,32 @@ Note that this version is not compatible with previous versions. If you want to [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) -:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) +google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
+:m: [Model Zoo](docs/ModelZoo.md) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ)
+:file_folder: [Datasets](docs/DatasetPreparation.md) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/1AZDcEAFwwc1OC3KCd7EDnQ) (提取码:basr)
:chart_with_upwards_trend: [Training curves in wandb](https://app.wandb.ai/xintao/basicsr)
:computer: [Commands for training and testing](docs/TrainTest.md)
:zap: [HOWTOs](#zap-howtos) --- -BasicSR is an **open source** image and video super-resolution toolbox based on PyTorch (will extend to more restoration tasks in the future).
+BasicSR (**Basic** **S**uper **R**estoration) is an open source **image and video restoration** toolbox based on PyTorch, covering tasks such as super-resolution, denoising, deblurring, and JPEG artifact removal.
([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN)) +([HandyView](https://github.com/xinntao/HandyView), [HandyFigure](https://github.com/xinntao/HandyFigure), [HandyCrawler](https://github.com/xinntao/HandyCrawler), [HandyWriting](https://github.com/xinntao/HandyWriting)) -## :sparkles: New Feature +## :sparkles: New Features -- Sep 8, 2020. Add **blind face restoration inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet)**. Note that it is slightly different from the official testing codes. - > Blind Face Restoration via Deep Multi-scale Component Dictionaries
- > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
- > European Conference on Computer Vision (ECCV), 2020 +- Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab). +- Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet). - Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > Analyzing and Improving the Image Quality of StyleGAN
- > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
- > Computer Vision and Pattern Recognition (CVPR), 2020
More
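The Nov 29 entry above adds ESRGAN and DFDNet Colab demos. For orientation only, the sketch below shows what a minimal ESRGAN-style inference pass with BasicSR looks like; it is an editor's illustration, not code from this patch, and the `RRDBNet` import path, the `'params'` checkpoint key, and the file paths are assumptions that should be checked against the repository's `inference/` scripts and `docs/ModelZoo.md`.

```python
import cv2
import torch

from basicsr.models.archs.rrdbnet_arch import RRDBNet  # assumed import path
from basicsr.utils import img2tensor, tensor2img

# Assumed paths and checkpoint layout; adjust to the actual model file.
model_path = 'experiments/pretrained_models/ESRGAN_x4.pth'
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23)
state = torch.load(model_path, map_location='cpu')
# BasicSR checkpoints typically nest the weights under a 'params' key.
model.load_state_dict(state.get('params', state), strict=True)
model.eval()

# Read a low-resolution image, run a x4 forward pass, save the result.
img = cv2.imread('input.png').astype('float32') / 255.
lq = img2tensor(img, bgr2rgb=True, float32=True).unsqueeze(0)
with torch.no_grad():
    output = model(lq)
cv2.imwrite('output_ESRGAN_x4.png', tensor2img(output))
```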
@@ -41,9 +41,21 @@ BasicSR is an **open source** image and video super-resolution toolbox based on We provides simple pipelines to train/test/inference models for quick start. These pipelines/commands cannot cover all the cases and more details are in the following sections. -- [How to train StyleGAN2](docs/HOWTOs.md#How-to-train-StyleGAN2) -- [How to test StyleGAN2](docs/HOWTOs.md#How-to-test-StyleGAN2) -- [How to test DFDNet](docs/HOWTOs.md#How-to-test-DFDNet) +| GAN | | | | | | +| :--- | :---: | :---: | :--- | :---: | :---: | +| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | | +| **Face Restoration** | | | | | | +| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | | +| **Super Resolution** | | | | | | +| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*| +| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*| +| RCAN | *TODO* | *TODO* | | | | +| EDVR | *TODO* | *TODO* | DUF | - | *TODO* | +| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* | +| **Deblurring** | | | | | | +| DeblurGANv2 | - | *TODO* | | | | +| **Denoise** | | | | | | +| RIDNet | - | *TODO* | CBDNet | - | *TODO*| ## :wrench: Dependencies and Installation @@ -51,13 +63,46 @@ These pipelines/commands cannot cover all the cases and more details are in the - [PyTorch >= 1.3](https://pytorch.org/) - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) -Please run the following commands in the **BasicSR root path** to install BasicSR:
-(Make sure that your GCC version: gcc >= 5) +1. Clone repo -```bash -pip install -r requirements.txt -python setup.py develop -``` + ```bash + git clone https://github.com/xinntao/BasicSR.git + ``` + +1. Install dependent packages + + ```bash + cd BasicSR + pip install -r requirements.txt + ``` + +1. Install BasicSR + + Please run the following commands in the **BasicSR root path** to install BasicSR:
+ (Make sure that your GCC version is gcc >= 5)
+ If you do not need the following CUDA extensions:
+  [*dcn* for EDVR](basicsr/models/ops)
+  [*upfirdn2d* and *fused_act* for StyleGAN2](basicsr/models/ops)
+ please add `--no_cuda_ext` when installing + + ```bash + python setup.py develop --no_cuda_ext + ``` + + If you use the EDVR and StyleGAN2 model, the above cuda extensions are necessary. + + ```bash + python setup.py develop + ``` + + You may also want to specify the CUDA paths: + + ```bash + CUDA_HOME=/usr/local/cuda \ + CUDNN_INCLUDE_DIR=/usr/local/cuda \ + CUDNN_LIB_DIR=/usr/local/cuda \ + python setup.py develop + ``` Note that BasicSR is only tested in Ubuntu, and may be not suitable for Windows. You may try [Windows WSL with CUDA supports](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (It is now only available for insider build with Fast ring). @@ -76,7 +121,7 @@ Please see [project boards](https://github.com/xinntao/BasicSR/projects). - **Options/Configs**: Please refer to [Config.md](docs/Config.md). - **Logging**: Please refer to [Logging.md](docs/Logging.md). -## :card_file_box: Model Zoo and Baselines +## :european_castle: Model Zoo and Baselines - The descriptions of currently supported models are in [Models.md](docs/Models.md). - **Pre-trained models and log examples** are available in **[ModelZoo.md](docs/ModelZoo.md)**. @@ -97,8 +142,25 @@ The figure below shows the overall framework. More descriptions for each compone ## :scroll: License and Acknowledgement -This project is released under the Apache 2.0 license. -More details about license and acknowledgement are in [LICENSE](LICENSE/README.md). +This project is released under the Apache 2.0 license.
+More details about **license** and **acknowledgement** are in [LICENSE](LICENSE/README.md). + +## :earth_asia: Citations + +If BasicSR helps your research or work, please consider citing BasicSR.
+The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package. + +``` latex +@misc{wang2020basicsr, + author = {Xintao Wang and Ke Yu and Kelvin C.K. Chan and + Chao Dong and Chen Change Loy}, + title = {BasicSR}, + howpublished = {\url{https://github.com/xinntao/BasicSR}}, + year = {2020} +} +``` + +> Xintao Wang, Ke Yu, Kelvin C.K. Chan, Chao Dong and Chen Change Loy. BasicSR. https://github.com/xinntao/BasicSR, 2020. ## :e-mail: Contact diff --git a/README_CN.md b/README_CN.md index c963f68..452c804 100644 --- a/README_CN.md +++ b/README_CN.md @@ -8,32 +8,32 @@ [English](README.md) **|** [简体中文](README_CN.md)   [GitHub](https://github.com/xinntao/BasicSR) **|** [Gitee码云](https://gitee.com/xinntao/BasicSR) -:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) +google colab logo Google Colab: [GitHub Link](colab) **|** [Google Drive Link](https://drive.google.com/drive/folders/1G_qcpvkT5ixmw5XoN6MupkOzcK1km625?usp=sharing)
+:m: [模型库](docs/ModelZoo_CN.md) :arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) :arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing)
+:file_folder: [数据](docs/DatasetPreparation_CN.md) :arrow_double_down: [百度网盘](https://pan.baidu.com/s/1AZDcEAFwwc1OC3KCd7EDnQ) (提取码:basr) :arrow_double_down: [Google Drive](https://drive.google.com/drive/folders/1gt5eT293esqY0yr1Anbm36EdnxWW_5oH?usp=sharing)
:chart_with_upwards_trend: [wandb的训练曲线](https://app.wandb.ai/xintao/basicsr)
:computer: [训练和测试的命令](docs/TrainTest_CN.md)
:zap: [HOWTOs](#zap-howtos) --- -BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Resolution) 工具箱 (之后会支持更多的 Restoration 任务).
+BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源图像视频复原工具箱, 比如 超分辨率, 去噪, 去模糊, 去 JPEG 压缩噪声等.
([ESRGAN](https://github.com/xinntao/ESRGAN), [EDVR](https://github.com/xinntao/EDVR), [DNI](https://github.com/xinntao/DNI), [SFTGAN](https://github.com/xinntao/SFTGAN)) +([HandyView](https://gitee.com/xinntao/HandyView), [HandyFigure](https://gitee.com/xinntao/HandyFigure), [HandyCrawler](https://gitee.com/xinntao/HandyCrawler), [HandyWriting](https://gitee.com/xinntao/HandyWriting)) ## :sparkles: 新的特性 -- Sep 8, 2020. 添加 **盲人脸复原推理代码: [DFDNet](https://github.com/csxmli2016/DFDNet)**. 注意和官方代码有些微差异. - > Blind Face Restoration via Deep Multi-scale Component Dictionaries
- > Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang
- > European Conference on Computer Vision (ECCV), 2020 -- Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch). - > Analyzing and Improving the Image Quality of StyleGAN
- > Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila
- > Computer Vision and Pattern Recognition (CVPR), 2020 +- Nov 29, 2020. 添加 **ESRGAN** and **DFDNet** [colab demo](colab). +- Sep 8, 2020. 添加 **盲人脸复原**测试代码: [DFDNet](https://github.com/csxmli2016/DFDNet). +- Aug 27, 2020. 添加 **StyleGAN2 训练和测试** 代码: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).
更多
@@ -41,9 +41,21 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res 我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分. -- [如何训练 StyleGAN2](docs/HOWTOs_CN.md#如何训练-StyleGAN2) -- [如何测试 StyleGAN2](docs/HOWTOs_CN.md#如何测试-StyleGAN2) -- [如何测试 DFDNet](docs/HOWTOs_CN.md#如何测试-DFDNet) +| GAN | | | | | | +| :--- | :---: | :---: | :--- | :---: | :---: | +| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | | +| **Face Restoration** | | | | | | +| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | | +| **Super Resolution** | | | | | | +| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*| +| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*| +| RCAN | *TODO* | *TODO* | | | | +| EDVR | *TODO* | *TODO* | DUF | - | *TODO* | +| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* | +| **Deblurring** | | | | | | +| DeblurGANv2 | - | *TODO* | | | | +| **Denoise** | | | | | | +| RIDNet | - | *TODO* | CBDNet | - | *TODO*| ## :wrench: 依赖和安装 @@ -51,13 +63,46 @@ BasicSR 是一个基于 PyTorch 的**开源**图像视频超分辨率 (Super-Res - [PyTorch >= 1.3](https://pytorch.org/) - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) -在BasicSR的**根目录**下运行以下命令:
-(确保 GCC 版本: gcc >= 5) +1. Clone repo -```bash -pip install -r requirements.txt -python setup.py develop -``` + ```bash + git clone https://github.com/xinntao/BasicSR.git + ``` + +1. 安装依赖包 + + ```bash + cd BasicSR + pip install -r requirements.txt + ``` + +1. 安装 BasicSR + + 在BasicSR的**根目录**下运行以下命令:
+ (确保 GCC 版本: gcc >= 5)
+ 如果你不需要以下 cuda 扩展算子:
+  [*dcn* for EDVR](basicsr/models/ops)
+  [*upfirdn2d* and *fused_act* for StyleGAN2](basicsr/models/ops)
+ 在安装命令后添加 `--no_cuda_ext` + + ```bash + python setup.py develop --no_cuda_ext + ``` + + 如果使用 EDVR 和 StyleGAN2 模型, 则需要使用上面的 cuda 扩展算子. + + ```bash + python setup.py develop + ``` + + 你或许需要指定 CUDA 路径: + + ```bash + CUDA_HOME=/usr/local/cuda \ + CUDNN_INCLUDE_DIR=/usr/local/cuda \ + CUDNN_LIB_DIR=/usr/local/cuda \ + python setup.py develop + ``` 注意: BasicSR 仅在 Ubuntu 下进行测试,或许不支持Windows. 可以在Windows下尝试[支持CUDA的Windows WSL](https://docs.microsoft.com/en-us/windows/win32/direct3d12/gpu-cuda-in-wsl) :-) (目前只有Fast ring的预览版系统可以安装). @@ -76,7 +121,7 @@ python setup.py develop - **Options/Configs**配置文件的说明, 参见 [Config_CN.md](docs/Config_CN.md). - **Logging**日志系统的说明, 参见 [Logging_CN.md](docs/Logging_CN.md). -## :card_file_box: 模型库和基准 +## :european_castle: 模型库和基准 - 目前支持的模型描述, 参见 [Models_CN.md](docs/Models_CN.md). - **预训练模型和log样例**, 参见 **[ModelZoo_CN.md](docs/ModelZoo_CN.md)**. @@ -97,8 +142,25 @@ python setup.py develop ## :scroll: 许可 -本项目使用 Apache 2.0 license. -更多细节参见 [LICENSE](LICENSE/README.md). +本项目使用 Apache 2.0 license.
+更多关于**许可**和**致谢**, 请参见 [LICENSE](LICENSE/README.md). + +## :earth_asia: 引用 + +如果 BasicSR 对你有所帮助, 可以考虑引用BasicSR.
+下面是一个 BibTex 引用条目, 它需要 `url` LaTeX package. + +``` latex +@misc{wang2020basicsr, + author = {Xintao Wang and Ke Yu and Kelvin C.K. Chan and + Chao Dong and Chen Change Loy}, + title = {BasicSR}, + howpublished = {\url{https://github.com/xinntao/BasicSR}}, + year = {2020} +} +``` + +> Xintao Wang, Ke Yu, Kelvin C.K. Chan, Chao Dong and Chen Change Loy. BasicSR. https://github.com/xinntao/BasicSR, 2020. ## :e-mail: 联系 diff --git a/VERSION b/VERSION index 524cb55..26aaba0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.1.1 +1.2.0 diff --git a/assets/basicsr.png b/assets/basicsr.png new file mode 100644 index 0000000..cb35770 Binary files /dev/null and b/assets/basicsr.png differ diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py index c7b09bc..22b0c8b 100644 --- a/basicsr/data/__init__.py +++ b/basicsr/data/__init__.py @@ -1,15 +1,14 @@ import importlib -import mmcv import numpy as np import random import torch import torch.utils.data from functools import partial -from mmcv.runner import get_dist_info from os import path as osp from basicsr.data.prefetch_dataloader import PrefetchDataLoader -from basicsr.utils import get_root_logger +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util import get_dist_info __all__ = ['create_dataset', 'create_dataloader'] @@ -17,7 +16,7 @@ # scan all the files under the data folder with '_dataset' in file names data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py') ] # import all the dataset modules @@ -99,7 +98,7 @@ def create_dataloader(dataset, seed=seed) if seed is not None else None elif phase in ['val', 'test']: # validation dataloader_args = dict( - dataset=dataset, batch_size=1, shuffle=False, num_workers=1) + dataset=dataset, batch_size=1, shuffle=False, num_workers=0) else: raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") diff --git a/basicsr/data/util.py b/basicsr/data/data_util.py similarity index 93% rename from basicsr/data/util.py rename to basicsr/data/data_util.py index 50245ac..975c0c0 100644 --- a/basicsr/data/util.py +++ b/basicsr/data/data_util.py @@ -1,10 +1,11 @@ -import mmcv +import cv2 import numpy as np import torch from os import path as osp from torch.nn import functional as F -from basicsr.data.transforms import mod_crop, totensor +from basicsr.data.transforms import mod_crop +from basicsr.utils import img2tensor, scandir def read_img_seq(path, require_mod_crop=False, scale=1): @@ -22,11 +23,11 @@ def read_img_seq(path, require_mod_crop=False, scale=1): if isinstance(path, list): img_paths = path else: - img_paths = sorted([osp.join(path, v) for v in mmcv.scandir(path)]) - imgs = [mmcv.imread(v).astype(np.float32) / 255. for v in img_paths] + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. 
for v in img_paths] if require_mod_crop: imgs = [mod_crop(img, scale) for img in imgs] - imgs = totensor(imgs, bgr2rgb=True, float32=True) + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) imgs = torch.stack(imgs, dim=0) return imgs @@ -227,8 +228,8 @@ def paired_paths_from_folder(folders, keys, filename_tmpl): input_folder, gt_folder = folders input_key, gt_key = keys - input_paths = list(mmcv.scandir(input_folder)) - gt_paths = list(mmcv.scandir(gt_folder)) + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) assert len(input_paths) == len(gt_paths), ( f'{input_key} and {gt_key} datasets have different number of images: ' f'{len(input_paths)}, {len(gt_paths)}.') @@ -256,11 +257,27 @@ def paths_from_folder(folder): list[str]: Returned path list. """ - paths = list(mmcv.scandir(folder)) + paths = list(scandir(folder)) paths = [osp.join(folder, path) for path in paths] return paths +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + def generate_gaussian_kernel(kernel_size=13, sigma=1.6): """Generate Gaussian kernel used in `duf_downsample`. diff --git a/basicsr/data/ffhq_dataset.py b/basicsr/data/ffhq_dataset.py index b4cabab..ef93ed6 100644 --- a/basicsr/data/ffhq_dataset.py +++ b/basicsr/data/ffhq_dataset.py @@ -1,11 +1,9 @@ -import mmcv -import numpy as np from os import path as osp from torch.utils import data as data from torchvision.transforms.functional import normalize -from basicsr.data.transforms import augment, totensor -from basicsr.utils import FileClient +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, imfrombytes, img2tensor class FFHQDataset(data.Dataset): @@ -53,12 +51,12 @@ def __getitem__(self, index): # load gt image gt_path = self.paths[index] img_bytes = self.file_client.get(gt_path) - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. 
+ img_gt = imfrombytes(img_bytes, float32=True) # random horizontal flip img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) # BGR to RGB, HWC to CHW, numpy to tensor - img_gt = totensor(img_gt, bgr2rgb=True, float32=True) + img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) # normalize normalize(img_gt, self.mean, self.std, inplace=True) return {'gt': img_gt, 'gt_path': gt_path} diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index c5b01a8..66c042f 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -1,12 +1,11 @@ -import mmcv -import numpy as np from torch.utils import data as data +from torchvision.transforms.functional import normalize -from basicsr.data.transforms import augment, paired_random_crop, totensor -from basicsr.data.util import (paired_paths_from_folder, - paired_paths_from_lmdb, - paired_paths_from_meta_info_file) -from basicsr.utils import FileClient +from basicsr.data.data_util import (paired_paths_from_folder, + paired_paths_from_lmdb, + paired_paths_from_meta_info_file) +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor class PairedImageDataset(data.Dataset): @@ -46,6 +45,8 @@ def __init__(self, opt): # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] if 'filename_tmpl' in opt: @@ -79,10 +80,10 @@ def __getitem__(self, index): # image range: [0, 1], float32. gt_path = self.paths[index]['gt_path'] img_bytes = self.file_client.get(gt_path, 'gt') - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_gt = imfrombytes(img_bytes, float32=True) lq_path = self.paths[index]['lq_path'] img_bytes = self.file_client.get(lq_path, 'lq') - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_lq = imfrombytes(img_bytes, float32=True) # augmentation for training if self.opt['phase'] == 'train': @@ -96,7 +97,13 @@ def __getitem__(self, index): # TODO: color space transform # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = totensor([img_gt, img_lq], bgr2rgb=True, float32=True) + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) return { 'lq': img_lq, diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py index 8f5f1de..7f7db26 100644 --- a/basicsr/data/reds_dataset.py +++ b/basicsr/data/reds_dataset.py @@ -1,12 +1,12 @@ -import mmcv import numpy as np import random import torch from pathlib import Path from torch.utils import data as data -from basicsr.data.transforms import augment, paired_random_crop, totensor -from basicsr.utils import FileClient, get_root_logger +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.flow_util import dequantize_flow class REDSDataset(data.Dataset): @@ -144,7 +144,7 @@ def __getitem__(self, index): else: img_gt_path = self.gt_root / clip_name / f'{frame_name}.png' img_bytes = self.file_client.get(img_gt_path, 'gt') - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. 
+ img_gt = imfrombytes(img_bytes, float32=True) # get the neighboring LQ frames img_lqs = [] @@ -154,7 +154,7 @@ def __getitem__(self, index): else: img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png' img_bytes = self.file_client.get(img_lq_path, 'lq') - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_lq = imfrombytes(img_bytes, float32=True) img_lqs.append(img_lq) # get flows @@ -168,10 +168,11 @@ def __getitem__(self, index): flow_path = ( self.flow_root / clip_name / f'{frame_name}_p{i}.png') img_bytes = self.file_client.get(flow_path, 'flow') - cat_flow = mmcv.imfrombytes( - img_bytes, flag='grayscale') # uint8, [0, 255] + cat_flow = imfrombytes( + img_bytes, flag='grayscale', + float32=False) # uint8, [0, 255] dx, dy = np.split(cat_flow, 2, axis=0) - flow = mmcv.video.dequantize_flow( + flow = dequantize_flow( dx, dy, max_val=20, denorm=False) # we use max_val 20 here. img_flows.append(flow) @@ -183,9 +184,11 @@ def __getitem__(self, index): flow_path = ( self.flow_root / clip_name / f'{frame_name}_n{i}.png') img_bytes = self.file_client.get(flow_path, 'flow') - cat_flow = mmcv.imfrombytes(img_bytes, flag='grayscale') + cat_flow = imfrombytes( + img_bytes, flag='grayscale', + float32=False) # uint8, [0, 255] dx, dy = np.split(cat_flow, 2, axis=0) - flow = mmcv.video.dequantize_flow( + flow = dequantize_flow( dx, dy, max_val=20, denorm=False) # we use max_val 20 here. img_flows.append(flow) @@ -210,12 +213,12 @@ def __getitem__(self, index): img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot']) - img_results = totensor(img_results) + img_results = img2tensor(img_results) img_lqs = torch.stack(img_results[0:-1], dim=0) img_gt = img_results[-1] if self.flow_root is not None: - img_flows = totensor(img_flows) + img_flows = img2tensor(img_flows) # add the zero center flow img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0])) diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py index 3a15934..b752b00 100644 --- a/basicsr/data/single_image_dataset.py +++ b/basicsr/data/single_image_dataset.py @@ -1,11 +1,9 @@ -import mmcv -import numpy as np from os import path as osp from torch.utils import data as data from torchvision.transforms.functional import normalize -from basicsr.data.transforms import totensor -from basicsr.utils import FileClient +from basicsr.data.data_util import paths_from_lmdb +from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir class SingleImageDataset(data.Dataset): @@ -33,17 +31,19 @@ def __init__(self, opt): self.mean = opt['mean'] if 'mean' in opt else None self.std = opt['std'] if 'std' in opt else None self.lq_folder = opt['dataroot_lq'] - if 'meta_info_file' in self.opt: + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder] + self.io_backend_opt['client_keys'] = ['lq'] + self.paths = paths_from_lmdb(self.lq_folder) + elif 'meta_info_file' in self.opt: with open(self.opt['meta_info_file'], 'r') as fin: self.paths = [ osp.join(self.lq_folder, line.split(' ')[0]) for line in fin ] else: - self.paths = [ - osp.join(self.lq_folder, v) - for v in mmcv.scandir(self.lq_folder) - ] + self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) def __getitem__(self, index): if self.file_client is None: @@ -52,12 +52,12 @@ def __getitem__(self, index): # load lq image lq_path = self.paths[index] - img_bytes = self.file_client.get(lq_path) - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. 
+ img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) # TODO: color space transform # BGR to RGB, HWC to CHW, numpy to tensor - img_lq = totensor(img_lq, bgr2rgb=True, float32=True) + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) # normalize if self.mean is not None or self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py index 6c7eb80..b6d04ff 100644 --- a/basicsr/data/transforms.py +++ b/basicsr/data/transforms.py @@ -1,6 +1,5 @@ -import mmcv +import cv2 import random -import torch def mod_crop(img, scale): @@ -85,7 +84,7 @@ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): return img_gts, img_lqs -def augment(imgs, hflip=True, rotation=True, flows=None): +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). We use vertical flip and transpose for rotation implementation. @@ -99,6 +98,8 @@ def augment(imgs, hflip=True, rotation=True, flows=None): flows (list[ndarray]: Flows to be augmented. If the input is an ndarray, it will be transformed to a list. Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. Returns: list[ndarray] | ndarray: Augmented images and flows. If returned @@ -110,20 +111,20 @@ def augment(imgs, hflip=True, rotation=True, flows=None): rot90 = rotation and random.random() < 0.5 def _augment(img): - if hflip: - mmcv.imflip_(img, 'horizontal') - if vflip: - mmcv.imflip_(img, 'vertical') + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) if rot90: img = img.transpose(1, 0, 2) return img def _augment_flow(flow): - if hflip: - mmcv.imflip_(flow, 'horizontal') + if hflip: # horizontal + cv2.flip(flow, 1, flow) flow[:, :, 0] *= -1 - if vflip: - mmcv.imflip_(flow, 'vertical') + if vflip: # vertical + cv2.flip(flow, 0, flow) flow[:, :, 1] *= -1 if rot90: flow = flow.transpose(1, 0, 2) @@ -144,31 +145,28 @@ def _augment_flow(flow): flows = flows[0] return imgs, flows else: - return imgs + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs -def totensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. 
""" + (h, w) = img.shape[:2] - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - img = mmcv.bgr2rgb(img) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img + if center is None: + center = (w // 2, h // 2) - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py index 0ab7d99..01b876a 100644 --- a/basicsr/data/video_test_dataset.py +++ b/basicsr/data/video_test_dataset.py @@ -1,12 +1,11 @@ import glob -import mmcv import torch from os import path as osp from torch.utils import data as data -from basicsr.data import util as util -from basicsr.data.util import duf_downsample -from basicsr.utils import get_root_logger +from basicsr.data.data_util import (duf_downsample, generate_frame_indices, + read_img_seq) +from basicsr.utils import get_root_logger, scandir class VideoTestDataset(data.Dataset): @@ -81,14 +80,10 @@ def __init__(self, opt): subfolders_gt): # get frame list for lq and gt subfolder_name = osp.basename(subfolder_lq) - img_paths_lq = sorted([ - osp.join(subfolder_lq, v) - for v in mmcv.scandir(subfolder_lq) - ]) - img_paths_gt = sorted([ - osp.join(subfolder_gt, v) - for v in mmcv.scandir(subfolder_gt) - ]) + img_paths_lq = sorted( + list(scandir(subfolder_lq, full_path=True))) + img_paths_gt = sorted( + list(scandir(subfolder_gt, full_path=True))) max_idx = len(img_paths_lq) assert max_idx == len(img_paths_gt), ( @@ -110,10 +105,8 @@ def __init__(self, opt): if self.cache_data: logger.info( f'Cache {subfolder_name} for VideoTestDataset...') - self.imgs_lq[subfolder_name] = util.read_img_seq( - img_paths_lq) - self.imgs_gt[subfolder_name] = util.read_img_seq( - img_paths_gt) + self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq) + self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt) else: self.imgs_lq[subfolder_name] = img_paths_lq self.imgs_gt[subfolder_name] = img_paths_gt @@ -128,7 +121,7 @@ def __getitem__(self, index): border = self.data_info['border'][index] lq_path = self.data_info['lq_path'][index] - select_idx = util.generate_frame_indices( + select_idx = generate_frame_indices( idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) if self.cache_data: @@ -137,8 +130,8 @@ def __getitem__(self, index): img_gt = self.imgs_gt[folder][idx] else: img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] - imgs_lq = util.read_img_seq(img_paths_lq) - img_gt = util.read_img_seq([self.imgs_gt[folder][idx]]) + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]]) img_gt.squeeze_(0) return { @@ -218,8 +211,8 @@ def __init__(self, opt): def __getitem__(self, index): lq_path = self.data_info['lq_path'][index] gt_path = self.data_info['gt_path'][index] - imgs_lq = util.read_img_seq(lq_path) - img_gt = util.read_img_seq([gt_path]) + imgs_lq = read_img_seq(lq_path) + img_gt = read_img_seq([gt_path]) img_gt.squeeze_(0) return { @@ -255,7 +248,7 @@ def __getitem__(self, index): border = self.data_info['border'][index] lq_path = self.data_info['lq_path'][index] - select_idx = util.generate_frame_indices( + select_idx = generate_frame_indices( idx, max_idx, self.opt['num_frame'], padding=self.opt['padding']) if self.cache_data: @@ -273,7 +266,7 @@ def __getitem__(self, 
index): if self.opt['use_duf_downsampling']: img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx] # read imgs_gt to generate low-resolution frames - imgs_lq = util.read_img_seq( + imgs_lq = read_img_seq( img_paths_lq, require_mod_crop=True, scale=self.opt['scale']) @@ -281,10 +274,10 @@ def __getitem__(self, index): imgs_lq, kernel_size=13, scale=self.opt['scale']) else: img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx] - imgs_lq = util.read_img_seq(img_paths_lq) - img_gt = util.read_img_seq([self.imgs_gt[folder][idx]], - require_mod_crop=True, - scale=self.opt['scale']) + imgs_lq = read_img_seq(img_paths_lq) + img_gt = read_img_seq([self.imgs_gt[folder][idx]], + require_mod_crop=True, + scale=self.opt['scale']) img_gt.squeeze_(0) return { diff --git a/basicsr/data/vimeo90k_dataset.py b/basicsr/data/vimeo90k_dataset.py index a88216e..71d5d11 100644 --- a/basicsr/data/vimeo90k_dataset.py +++ b/basicsr/data/vimeo90k_dataset.py @@ -1,12 +1,10 @@ -import mmcv -import numpy as np import random import torch from pathlib import Path from torch.utils import data as data -from basicsr.data.transforms import augment, paired_random_crop, totensor -from basicsr.utils import FileClient, get_root_logger +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor class Vimeo90KDataset(data.Dataset): @@ -97,7 +95,7 @@ def __getitem__(self, index): else: img_gt_path = self.gt_root / clip / seq / 'im4.png' img_bytes = self.file_client.get(img_gt_path, 'gt') - img_gt = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_gt = imfrombytes(img_bytes, float32=True) # get the neighboring LQ frames img_lqs = [] @@ -107,7 +105,7 @@ def __getitem__(self, index): else: img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' img_bytes = self.file_client.get(img_lq_path, 'lq') - img_lq = mmcv.imfrombytes(img_bytes).astype(np.float32) / 255. + img_lq = imfrombytes(img_bytes, float32=True) img_lqs.append(img_lq) # randomly crop @@ -119,7 +117,7 @@ def __getitem__(self, index): img_results = augment(img_lqs, self.opt['use_flip'], self.opt['use_rot']) - img_results = totensor(img_results) + img_results = img2tensor(img_results) img_lqs = torch.stack(img_results[0:-1], dim=0) img_gt = img_results[-1] diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py new file mode 100644 index 0000000..35fc23d --- /dev/null +++ b/basicsr/metrics/fid.py @@ -0,0 +1,102 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy import linalg +from tqdm import tqdm + +from basicsr.models.archs.inception import InceptionV3 + + +def load_patched_inception_v3(device='cuda', + resize_input=True, + normalize_input=False): + # we may not resize the input, but in [rosinality/stylegan2-pytorch] it + # does resize the input. + inception = InceptionV3([3], + resize_input=resize_input, + normalize_input=normalize_input) + inception = nn.DataParallel(inception).eval().to(device) + return inception + + +@torch.no_grad() +def extract_inception_features(data_generator, + inception, + len_generator=None, + device='cuda'): + """Extract inception features. + + Args: + data_generator (generator): A data generator. + inception (nn.Module): Inception model. + len_generator (int): Length of the data_generator to show the + progressbar. Default: None. + device (str): Device. Default: cuda. + + Returns: + Tensor: Extracted features. 
+ """ + if len_generator is not None: + pbar = tqdm(total=len_generator, unit='batch', desc='Extract') + else: + pbar = None + features = [] + + for data in data_generator: + if pbar: + pbar.update(1) + data = data.to(device) + feature = inception(data)[0].view(data.shape[0], -1) + features.append(feature.to('cpu')) + if pbar: + pbar.close() + features = torch.cat(features, 0) + return features + + +def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + + Args: + mu1 (np.array): The sample mean over activations. + sigma1 (np.array): The covariance matrix over activations for + generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an + representative data set. + sigma2 (np.array): The covariance matrix over activations, + precalculated on an representative data set. + + Returns: + float: The Frechet Distance. + """ + assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, ( + 'Two covariances have different dimensions') + + cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) + + # Product might be almost singular + if not np.isfinite(cov_sqrt).all(): + print('Product of cov matrices is singular. Adding {eps} to diagonal ' + 'of cov estimates') + offset = np.eye(sigma1.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + raise ValueError(f'Imaginary component {m}') + cov_sqrt = cov_sqrt.real + + mean_diff = mu1 - mu2 + mean_norm = mean_diff @ mean_diff + trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) + fid = mean_norm + trace + + return fid diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py index cac7026..fb38e1b 100644 --- a/basicsr/metrics/metric_util.py +++ b/basicsr/metrics/metric_util.py @@ -1,6 +1,7 @@ -import mmcv import numpy as np +from basicsr.utils.matlab_functions import bgr2ycbcr + def reorder_image(img, input_order='HWC'): """Reorder images to 'HWC' order. @@ -25,7 +26,6 @@ def reorder_image(img, input_order='HWC'): "'HWC' and 'CHW'") if len(img.shape) == 2: img = img[..., None] - return img if input_order == 'CHW': img = img.transpose(1, 2, 0) return img @@ -42,6 +42,6 @@ def to_y_channel(img): """ img = img.astype(np.float32) / 255. if img.ndim == 3 and img.shape[2] == 3: - img = mmcv.bgr2ycbcr(img, y_only=True) + img = bgr2ycbcr(img, y_only=True) img = img[..., None] return img * 255. diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py index 7447bb4..a16c0d6 100644 --- a/basicsr/metrics/niqe.py +++ b/basicsr/metrics/niqe.py @@ -141,7 +141,9 @@ def niqe(img, # fit a MVG (multivariate Gaussian) model to distorted patch features mu_distparam = np.nanmean(distparam, axis=0) - cov_distparam = np.cov(distparam, rowvar=False) # TODO: use nancov + # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html + distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)] + cov_distparam = np.cov(distparam_no_nan, rowvar=False) # compute niqe quality, Eq. 
10 in the paper invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py index 22dc0fd..faef700 100644 --- a/basicsr/metrics/psnr_ssim.py +++ b/basicsr/metrics/psnr_ssim.py @@ -34,6 +34,8 @@ def calculate_psnr(img1, '"HWC" and "CHW"') img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) if crop_border != 0: img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] @@ -122,6 +124,8 @@ def calculate_ssim(img1, '"HWC" and "CHW"') img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) if crop_border != 0: img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py index 4f173de..10f3b9f 100644 --- a/basicsr/models/__init__.py +++ b/basicsr/models/__init__.py @@ -1,15 +1,14 @@ import importlib -import mmcv from os import path as osp -from basicsr.utils import get_root_logger +from basicsr.utils import get_root_logger, scandir # automatically scan and import model modules # scan all the files under the 'models' folder and collect files ending with # '_model.py' model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(model_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py') ] # import all the model modules diff --git a/basicsr/models/archs/__init__.py b/basicsr/models/archs/__init__.py index a00982a..40410be 100644 --- a/basicsr/models/archs/__init__.py +++ b/basicsr/models/archs/__init__.py @@ -1,13 +1,14 @@ import importlib -import mmcv from os import path as osp +from basicsr.utils import scandir + # automatically scan and import arch modules # scan all the files under the 'archs' folder and collect files ending with # '_arch.py' arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(arch_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py') ] # import all the arch modules diff --git a/basicsr/models/archs/arch_util.py b/basicsr/models/archs/arch_util.py index 9aebf99..19b4ed8 100644 --- a/basicsr/models/archs/arch_util.py +++ b/basicsr/models/archs/arch_util.py @@ -5,10 +5,17 @@ from torch.nn import init as init from torch.nn.modules.batchnorm import _BatchNorm -from basicsr.models.ops.dcn import (ModulatedDeformConvPack, - modulated_deform_conv) from basicsr.utils import get_root_logger +try: + from basicsr.models.ops.dcn import (ModulatedDeformConvPack, + modulated_deform_conv) +except ImportError: + print('Cannot import dcn. Ignore this warning if dcn is not used. 
' + 'Otherwise install BasicSR with compiling dcn.') + ModulatedDeformConvPack = object + modulated_deform_conv = None + @torch.no_grad() def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): @@ -150,7 +157,7 @@ def flow_warp(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, - align_corners=True) + align_corners=align_corners) # TODO, what if align_corners=False return output diff --git a/basicsr/models/archs/dfdnet_arch.py b/basicsr/models/archs/dfdnet_arch.py index e03dfc0..c887d90 100644 --- a/basicsr/models/archs/dfdnet_arch.py +++ b/basicsr/models/archs/dfdnet_arch.py @@ -54,28 +54,6 @@ def forward(self, x, updated_feat): return out -class VGGFaceFeatureExtractor(VGGFeatureExtractor): - - def preprocess(self, x): - # norm to [0, 1] - x = (x + 1) / 2 - if self.use_input_norm: - x = (x - self.mean) / self.std - if x.shape[3] < 224: - x = torch.nn.functional.interpolate( - x, size=(224, 224), mode='bilinear', align_corners=False) - return x - - def forward(self, x): - x = self.preprocess(x) - features = [] - for key, layer in self.vgg_net._modules.items(): - x = layer(x) - if key in self.layer_name_list: - features.append(x) - return features - - class DFDNet(nn.Module): """DFDNet: Deep Face Dictionary Network. @@ -88,16 +66,18 @@ def __init__(self, num_feat, dict_path): # part_sizes: [80, 80, 50, 110] channel_sizes = [128, 256, 512, 512] self.feature_sizes = np.array([256, 128, 64, 32]) + self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4'] self.flag_dict_device = False # dict self.dict = torch.load(dict_path) # vgg face extractor - self.vgg_extractor = VGGFaceFeatureExtractor( - layer_name_list=['conv2_2', 'conv3_4', 'conv4_4', 'conv5_4'], + self.vgg_extractor = VGGFeatureExtractor( + layer_name_list=self.vgg_layers, vgg_type='vgg19', use_input_norm=True, + range_norm=True, requires_grad=False) # attention block for fusing dictionary features and input features @@ -175,9 +155,9 @@ def forward(self, x, part_locations): # update vggface features using the dictionary for each part updated_vgg_features = [] batch = 0 # only supports testing with batch size = 0 - for i, f_size in enumerate(self.feature_sizes): + for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes): dict_features = self.dict[f'{f_size}'] - vgg_feat = vgg_features[i] + vgg_feat = vgg_features[vgg_layer] updated_feat = vgg_feat.clone() # swap features from dictionary @@ -190,7 +170,7 @@ def forward(self, x, part_locations): updated_vgg_features.append(updated_feat) - vgg_feat_dilation = self.multi_scale_dilation(vgg_features[3]) + vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4']) # use updated vgg features to modulate the upsampled features with # SFT (Spatial Feature Transform) scaling and shifting manner. 
upsampled_feat = self.upsample0(vgg_feat_dilation, diff --git a/basicsr/models/archs/inception.py b/basicsr/models/archs/inception.py new file mode 100644 index 0000000..3efdf62 --- /dev/null +++ b/basicsr/models/archs/inception.py @@ -0,0 +1,323 @@ +# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501 +# For FID metric + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.model_zoo import load_url +from torchvision import models + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 +LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling features + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3. + + Args: + output_blocks (list[int]): Indices of blocks to return features of. + Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input (bool): If true, bilinearly resizes input to width and + height 299 before feeding input to model. As the network + without fully connected layers is fully convolutional, it + should be able to handle inputs of arbitrary size, so resizing + might not be strictly needed. Default: True. + normalize_input (bool): If true, scales the input from range (0, 1) + to the range the pretrained Inception network expects, + namely (-1, 1). Default: True. + requires_grad (bool): If true, parameters of the model require + gradients. Possibly useful for finetuning the network. + Default: False. + use_fid_inception (bool): If true, uses the pretrained Inception + model used in Tensorflow's FID implementation. + If false, uses the pretrained Inception model available in + torchvision. The FID Inception model has different weights + and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get + comparable results. Default: True. 
+ """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, ( + 'Last possible output block index is 3') + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + try: + inception = models.inception_v3( + pretrained=True, init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, x): + """Get Inception feature maps. + + Args: + x (Tensor): Input tensor of shape (b, 3, h, w). + Values are expected to be in range (-1, 1). You can also input + (0, 1) with setting normalize_input = True. + + Returns: + list[Tensor]: Corresponding to the selected output block, sorted + ascending by index. + """ + output = [] + + if self.resize_input: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + output.append(x) + + if idx == self.last_needed_block: + break + + return output + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation. + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + try: + inception = models.inception_v3( + num_classes=1008, + aux_logits=False, + pretrained=False, + init_weights=False) + except TypeError: + # pytorch < 1.5 does not have init_weights for inception_v3 + inception = models.inception_v3( + num_classes=1008, aux_logits=False, pretrained=False) + + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + if os.path.exists(LOCAL_FID_WEIGHTS): + state_dict = torch.load( + LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage) + else: + state_dict = load_url(FID_WEIGHTS_URL, progress=True) + + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = 
torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d( + x, kernel_size=3, stride=1, padding=1, count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/basicsr/models/archs/stylegan2_arch.py b/basicsr/models/archs/stylegan2_arch.py index f0d3453..3d53cff 100644 --- a/basicsr/models/archs/stylegan2_arch.py +++ b/basicsr/models/archs/stylegan2_arch.py @@ -4,8 +4,13 @@ from torch import nn from torch.nn import functional as F -from basicsr.models.ops.fused_act import FusedLeakyReLU, fused_leaky_relu -from basicsr.models.ops.upfirdn2d import upfirdn2d +try: + from basicsr.models.ops.fused_act import FusedLeakyReLU, fused_leaky_relu + from basicsr.models.ops.upfirdn2d import upfirdn2d +except ImportError: + print('Cannot import fused_act and upfirdn2d. Ignore this warning if ' + 'they are not used. Otherwise install BasicSR with compiling them.') + FusedLeakyReLU, fused_leaky_relu, upfirdn2d = None, None, None class NormStyleCode(nn.Module): @@ -211,7 +216,7 @@ class ModulatedConv2d(nn.Module): sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. resample_kernel (list[int]): A list indicating the 1D resample kernel - magnitude. Default: [1, 3, 3, 1]. + magnitude. Default: (1, 3, 3, 1). eps (float): A value added to the denominator for numerical stability. Default: 1e-8. """ @@ -223,7 +228,7 @@ def __init__(self, num_style_feat, demodulate=True, sample_mode=None, - resample_kernel=[1, 3, 3, 1], + resample_kernel=(1, 3, 3, 1), eps=1e-8): super(ModulatedConv2d, self).__init__() self.in_channels = in_channels @@ -333,7 +338,7 @@ class StyleConv(nn.Module): sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. resample_kernel (list[int]): A list indicating the 1D resample kernel - magnitude. Default: [1, 3, 3, 1]. + magnitude. Default: (1, 3, 3, 1). """ def __init__(self, @@ -343,7 +348,7 @@ def __init__(self, num_style_feat, demodulate=True, sample_mode=None, - resample_kernel=[1, 3, 3, 1]): + resample_kernel=(1, 3, 3, 1)): super(StyleConv, self).__init__() self.modulated_conv = ModulatedConv2d( in_channels, @@ -377,14 +382,14 @@ class ToRGB(nn.Module): num_style_feat (int): Channel number of style features. 
upsample (bool): Whether to upsample. Default: True. resample_kernel (list[int]): A list indicating the 1D resample kernel - magnitude. Default: [1, 3, 3, 1]. + magnitude. Default: (1, 3, 3, 1). """ def __init__(self, in_channels, num_style_feat, upsample=True, - resample_kernel=[1, 3, 3, 1]): + resample_kernel=(1, 3, 3, 1)): super(ToRGB, self).__init__() if upsample: self.upsample = UpFirDnUpsample(resample_kernel, factor=2) @@ -447,8 +452,9 @@ class StyleGAN2Generator(nn.Module): StyleGAN2. Default: 2. resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample - kenrel to 2D resample kernel. Default: [1, 3, 3, 1]. + kenrel to 2D resample kernel. Default: (1, 3, 3, 1). lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. """ def __init__(self, @@ -456,8 +462,9 @@ def __init__(self, num_style_feat=512, num_mlp=8, channel_multiplier=2, - resample_kernel=[1, 3, 3, 1], - lr_mlp=0.01): + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1): super(StyleGAN2Generator, self).__init__() # Style MLP layers self.num_style_feat = num_style_feat @@ -474,16 +481,17 @@ def __init__(self, self.style_mlp = nn.Sequential(*style_mlp_layers) channels = { - '4': 512, - '8': 512, - '16': 512, - '32': 512, - '64': 256 * channel_multiplier, - '128': 128 * channel_multiplier, - '256': 64 * channel_multiplier, - '512': 32 * channel_multiplier, - '1024': 16 * channel_multiplier + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) } + self.channels = channels self.constant_input = ConstantInput(channels['4'], size=4) self.style_conv1 = StyleConv( @@ -736,7 +744,7 @@ class ConvLayer(nn.Sequential): resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample kenrel to 2D resample kernel. - Default: [1, 3, 3, 1]. + Default: (1, 3, 3, 1). bias (bool): Whether with bias. Default: True. activate (bool): Whether use activateion. Default: True. """ @@ -746,7 +754,7 @@ def __init__(self, out_channels, kernel_size, downsample=False, - resample_kernel=[1, 3, 3, 1], + resample_kernel=(1, 3, 3, 1), bias=True, activate=True): layers = [] @@ -791,13 +799,11 @@ class ResBlock(nn.Module): resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample kenrel to 2D resample kernel. - Default: [1, 3, 3, 1]. + Default: (1, 3, 3, 1). """ - def __init__(self, - in_channels, - out_channels, - resample_kernel=[1, 3, 3, 1]): + def __init__(self, in_channels, out_channels, + resample_kernel=(1, 3, 3, 1)): super(ResBlock, self).__init__() self.conv1 = ConvLayer( @@ -836,26 +842,31 @@ class StyleGAN2Discriminator(nn.Module): StyleGAN2. Default: 2. resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be applied to extent 1D resample - kenrel to 2D resample kernel. Default: [1, 3, 3, 1]. + kenrel to 2D resample kernel. Default: (1, 3, 3, 1). + stddev_group (int): For group stddev statistics. Default: 4. + narrow (float): Narrow ratio for channels. Default: 1.0. 
""" def __init__(self, out_size, channel_multiplier=2, - resample_kernel=[1, 3, 3, 1]): + resample_kernel=(1, 3, 3, 1), + stddev_group=4, + narrow=1): super(StyleGAN2Discriminator, self).__init__() channels = { - '4': 512, - '8': 512, - '16': 512, - '32': 512, - '64': 256 * channel_multiplier, - '128': 128 * channel_multiplier, - '256': 64 * channel_multiplier, - '512': 32 * channel_multiplier, - '1024': 16 * channel_multiplier + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) } + log_size = int(math.log(out_size, 2)) conv_body = [ @@ -888,7 +899,7 @@ def __init__(self, lr_mul=1, activation=None), ) - self.stddev_group = 4 + self.stddev_group = stddev_group self.stddev_feat = 1 def forward(self, x): diff --git a/basicsr/models/archs/vgg_arch.py b/basicsr/models/archs/vgg_arch.py index 89c8772..5b1574a 100644 --- a/basicsr/models/archs/vgg_arch.py +++ b/basicsr/models/archs/vgg_arch.py @@ -1,8 +1,10 @@ +import os import torch from collections import OrderedDict from torch import nn as nn from torchvision.models import vgg as vgg +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' NAMES = { 'vgg11': [ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', @@ -68,6 +70,8 @@ class VGGFeatureExtractor(nn.Module): vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net @@ -79,6 +83,7 @@ def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True, + range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2): @@ -86,6 +91,7 @@ def __init__(self, self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm + self.range_norm = range_norm self.names = NAMES[vgg_type.replace('_bn', '')] if 'bn' in vgg_type: @@ -97,8 +103,16 @@ def __init__(self, idx = self.names.index(v) if idx > max_idx: max_idx = idx - features = getattr(vgg, - vgg_type)(pretrained=True).features[:max_idx + 1] + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load( + VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): @@ -143,7 +157,8 @@ def forward(self, x): Returns: Tensor: Forward results. 
""" - + if self.range_norm: + x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index 5baa524..f8987bf 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -3,10 +3,10 @@ import torch from collections import OrderedDict from copy import deepcopy -from mmcv.runner import master_only from torch.nn.parallel import DataParallel, DistributedDataParallel from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils.dist_util import master_only logger = logging.getLogger('basicsr') @@ -242,14 +242,17 @@ def load_network(self, net, load_path, strict=True, param_key='params'): load_path (str): The path of networks to be loaded. net (nn.Module): Network. strict (bool): Whether strictly loaded. - param_key (str): The parameter key of loaded network. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. Default: 'params'. """ net = self.get_bare_model(net) logger.info( f'Loading {net.__class__.__name__} model from {load_path}.') load_net = torch.load( - load_path, map_location=lambda storage, loc: storage)[param_key] + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + load_net = load_net[param_key] # remove unnecessary 'module.' for k, v in deepcopy(load_net).items(): if k.startswith('module.'): diff --git a/basicsr/models/losses/loss_utils.py b/basicsr/models/losses/loss_util.py similarity index 100% rename from basicsr/models/losses/loss_utils.py rename to basicsr/models/losses/loss_util.py diff --git a/basicsr/models/losses/losses.py b/basicsr/models/losses/losses.py index 4cbc5e8..2df8d75 100644 --- a/basicsr/models/losses/losses.py +++ b/basicsr/models/losses/losses.py @@ -5,7 +5,7 @@ from torch.nn import functional as F from basicsr.models.archs.vgg_arch import VGGFeatureExtractor -from basicsr.models.losses.loss_utils import weighted_loss +from basicsr.models.losses.loss_util import weighted_loss _reduction_modes = ['none', 'mean', 'sum'] @@ -155,17 +155,14 @@ class PerceptualLoss(nn.Module): Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image in vgg. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. perceptual_weight (float): If `perceptual_weight > 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. - norm_img (bool): If True, the image will be normed to [0, 1]. Note that - this is different from the `use_input_norm` which norm the input in - in forward function of vgg according to the statistics of dataset. - Importantly, the input image must be in range [-1, 1]. - Default: False. criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
""" @@ -173,19 +170,19 @@ def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, + range_norm=False, perceptual_weight=1.0, style_weight=0., - norm_img=False, criterion='l1'): super(PerceptualLoss, self).__init__() - self.norm_img = norm_img self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, - use_input_norm=use_input_norm) + use_input_norm=use_input_norm, + range_norm=range_norm) self.criterion_type = criterion if self.criterion_type == 'l1': @@ -208,11 +205,6 @@ def forward(self, x, gt): Returns: Tensor: Forward results. """ - - if self.norm_img: - x = (x + 1.) * 0.5 - gt = (gt + 1.) * 0.5 - # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py index eaa0b53..a2b4d35 100644 --- a/basicsr/models/lr_scheduler.py +++ b/basicsr/models/lr_scheduler.py @@ -20,8 +20,8 @@ def __init__(self, optimizer, milestones, gamma=0.1, - restarts=[0], - restart_weights=[1], + restarts=(0, ), + restart_weights=(1, ), last_epoch=-1): self.milestones = Counter(milestones) self.gamma = gamma @@ -90,7 +90,7 @@ class CosineAnnealingRestartLR(_LRScheduler): def __init__(self, optimizer, periods, - restart_weights=[1], + restart_weights=(1, ), eta_min=0, last_epoch=-1): self.periods = periods diff --git a/basicsr/models/ops/fused_act/src/fused_bias_act.cpp b/basicsr/models/ops/fused_act/src/fused_bias_act.cpp index cc9b8f7..85ed0a7 100755 --- a/basicsr/models/ops/fused_act/src/fused_bias_act.cpp +++ b/basicsr/models/ops/fused_act/src/fused_bias_act.cpp @@ -2,15 +2,19 @@ #include -torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale); +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) -torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, - int act, int grad, float alpha, float scale) { +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { CHECK_CUDA(input); CHECK_CUDA(bias); diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 66a98b6..92bbf4b 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -1,13 +1,13 @@ import importlib -import mmcv import torch from collections import OrderedDict from copy import deepcopy from os import path as osp +from tqdm import tqdm from basicsr.models.archs import define_network from basicsr.models.base_model import BaseModel -from basicsr.utils import ProgressBar, get_root_logger, tensor2img +from basicsr.utils import get_root_logger, imwrite, tensor2img loss_module = importlib.import_module('basicsr.models.losses') metric_module = importlib.import_module('basicsr.metrics') @@ -25,10 +25,10 @@ def __init__(self, opt): self.print_network(self.net_g) # load pretrained models - load_path = self.opt['path'].get('pretrain_model_g', None) + load_path = 
self.opt['path'].get('pretrain_network_g', None) if load_path is not None: self.load_network(self.net_g, load_path, - self.opt['path']['strict_load']) + self.opt['path'].get('strict_load_g', True)) if self.is_train: self.init_training_settings() @@ -131,7 +131,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, metric: 0 for metric in self.opt['val']['metrics'].keys() } - pbar = ProgressBar(len(dataloader)) + pbar = tqdm(total=len(dataloader), unit='image') for idx, val_data in enumerate(dataloader): img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] @@ -163,7 +163,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img_path = osp.join( self.opt['path']['visualization'], dataset_name, f'{img_name}_{self.opt["name"]}.png') - mmcv.imwrite(sr_img, save_img_path) + imwrite(sr_img, save_img_path) if with_metrics: # calculate metrics @@ -172,7 +172,9 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, metric_type = opt_.pop('type') self.metric_results[name] += getattr( metric_module, metric_type)(sr_img, gt_img, **opt_) - pbar.update(f'Test {img_name}') + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() if with_metrics: for metric in self.metric_results.keys(): diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py index d927773..7d08d7b 100644 --- a/basicsr/models/srgan_model.py +++ b/basicsr/models/srgan_model.py @@ -21,10 +21,10 @@ def init_training_settings(self): self.print_network(self.net_d) # load pretrained models - load_path = self.opt['path'].get('pretrain_model_d', None) + load_path = self.opt['path'].get('pretrain_network_d', None) if load_path is not None: self.load_network(self.net_d, load_path, - self.opt['path']['strict_load']) + self.opt['path'].get('strict_load_d', True)) self.net_g.train() self.net_d.train() diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py index 7cf7aec..c1ac6cf 100644 --- a/basicsr/models/stylegan2_model.py +++ b/basicsr/models/stylegan2_model.py @@ -1,6 +1,6 @@ +import cv2 import importlib import math -import mmcv import numpy as np import random import torch @@ -11,7 +11,7 @@ from basicsr.models.archs import define_network from basicsr.models.base_model import BaseModel from basicsr.models.losses.losses import g_path_regularize, r1_penalty -from basicsr.utils import tensor2img +from basicsr.utils import imwrite, tensor2img loss_module = importlib.import_module('basicsr.models.losses') @@ -27,11 +27,12 @@ def __init__(self, opt): self.net_g = self.model_to_device(self.net_g) self.print_network(self.net_g) # load pretrained model - load_path = self.opt['path'].get('pretrain_model_g', None) + load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: param_key = self.opt['path'].get('param_key_g', 'params') self.load_network(self.net_g, load_path, - self.opt['path']['strict_load'], param_key) + self.opt['path'].get('strict_load_g', + True), param_key) # latent dimension: self.num_style_feat self.num_style_feat = opt['network_g']['num_style_feat'] @@ -51,10 +52,10 @@ def init_training_settings(self): self.print_network(self.net_d) # load pretrained model - load_path = self.opt['path'].get('pretrain_model_d', None) + load_path = self.opt['path'].get('pretrain_network_d', None) if load_path is not None: self.load_network(self.net_d, load_path, - self.opt['path']['strict_load']) + self.opt['path'].get('strict_load_d', True)) # define network net_g with Exponential Moving Average (EMA) # 
net_g_ema only used for testing on one GPU and saving, do not need to @@ -62,10 +63,11 @@ def init_training_settings(self): self.net_g_ema = define_network(deepcopy(self.opt['network_g'])).to( self.device) # load pretrained model - load_path = self.opt['path'].get('pretrain_model_g', None) + load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: self.load_network(self.net_g_ema, load_path, - self.opt['path']['strict_load'], 'params_ema') + self.opt['path'].get('strict_load_g', + True), 'params_ema') else: self.model_ema(0) # copy net_g weight @@ -311,10 +313,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, else: save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png') - mmcv.imwrite(result, save_img_path) + imwrite(result, save_img_path) # add sample images to tb_logger result = (result / 255.).astype(np.float32) - result = mmcv.bgr2rgb(result) + result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) if tb_logger is not None: tb_logger.add_image( 'samples', result, global_step=current_iter, dataformats='HWC') diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py index 6e70eed..c8e8d26 100644 --- a/basicsr/models/video_base_model.py +++ b/basicsr/models/video_base_model.py @@ -1,14 +1,14 @@ import importlib -import mmcv import torch from collections import Counter from copy import deepcopy -from mmcv.runner import get_dist_info from os import path as osp from torch import distributed as dist +from tqdm import tqdm from basicsr.models.sr_model import SRModel -from basicsr.utils import ProgressBar, get_root_logger, tensor2img +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.dist_util import get_dist_info metric_module = importlib.import_module('basicsr.metrics') @@ -34,13 +34,13 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda') - rank, world_size = get_dist_info() - for _, tensor in self.metric_results.items(): - tensor.zero_() + if with_metrics: + for _, tensor in self.metric_results.items(): + tensor.zero_() # record all frames (border and center frames) if rank == 0: - pbar = ProgressBar(len(dataset)) + pbar = tqdm(total=len(dataset), unit='frame') for idx in range(rank, len(dataset), world_size): val_data = dataset[idx] val_data['lq'].unsqueeze_(0) @@ -83,7 +83,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): save_img_path = osp.join( self.opt['path']['visualization'], dataset_name, folder, f'{img_name}_{self.opt["name"]}.png') - mmcv.imwrite(result_img, save_img_path) + imwrite(result_img, save_img_path) if with_metrics: # calculate metrics @@ -98,8 +98,12 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img): # progress bar if rank == 0: for _ in range(world_size): - pbar.update(f'Test {folder} - ' - f'{int(frame_idx) + world_size}/{max_idx}') + pbar.update(1) + pbar.set_description( + f'Test {folder}:' + f'{int(frame_idx) + world_size}/{max_idx}') + if rank == 0: + pbar.close() if with_metrics: if self.opt['dist']: diff --git a/basicsr/models/video_gan_model.py b/basicsr/models/video_gan_model.py index 94ccf4b..290434b 100644 --- a/basicsr/models/video_gan_model.py +++ b/basicsr/models/video_gan_model.py @@ -1,142 +1,15 @@ -import importlib -import torch -from collections import OrderedDict -from copy import deepcopy - -from basicsr.models.archs import define_network +from 
basicsr.models.srgan_model import SRGANModel from basicsr.models.video_base_model import VideoBaseModel -loss_module = importlib.import_module('basicsr.models.losses') - - -class VideoGANModel(VideoBaseModel): - """Video GAN model.""" - - def init_training_settings(self): - train_opt = self.opt['train'] - - # define network net_d - self.net_d = define_network(deepcopy(self.opt['network_d'])) - self.net_d = self.model_to_device(self.net_d) - self.print_network(self.net_d) - - # load pretrained models - load_path = self.opt['path'].get('pretrain_model_d', None) - if load_path is not None: - self.load_network(self.net_d, load_path, - self.opt['path']['strict_load']) - - self.net_g.train() - self.net_d.train() - - # define losses - if train_opt.get('pixel_opt'): - pixel_type = train_opt['pixel_opt'].pop('type') - cri_pix_cls = getattr(loss_module, pixel_type) - self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to( - self.device) - else: - self.cri_pix = None - - if train_opt.get('perceptual_opt'): - percep_type = train_opt['perceptual_opt'].pop('type') - cri_perceptual_cls = getattr(loss_module, percep_type) - self.cri_perceptual = cri_perceptual_cls( - **train_opt['perceptual_opt']).to(self.device) - else: - self.cri_perceptual = None - - if train_opt.get('gan_opt'): - gan_type = train_opt['gan_opt'].pop('type') - cri_gan_cls = getattr(loss_module, gan_type) - self.cri_gan = cri_gan_cls(**train_opt['gan_opt']).to(self.device) - - self.net_d_iters = train_opt.get('net_d_iters', 1) - self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) - - # set up optimizers and schedulers - self.setup_optimizers() - self.setup_schedulers() - - def setup_optimizers(self): - train_opt = self.opt['train'] - # optimizer g - optim_type = train_opt['optim_g'].pop('type') - if optim_type == 'Adam': - self.optimizer_g = torch.optim.Adam(self.net_g.parameters(), - **train_opt['optim_g']) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - self.optimizers.append(self.optimizer_g) - # optimizer d - optim_type = train_opt['optim_d'].pop('type') - if optim_type == 'Adam': - self.optimizer_d = torch.optim.Adam(self.net_d.parameters(), - **train_opt['optim_d']) - else: - raise NotImplementedError( - f'optimizer {optim_type} is not supperted yet.') - self.optimizers.append(self.optimizer_d) - - def optimize_parameters(self, current_iter): - # optimize net_g - for p in self.net_d.parameters(): - p.requires_grad = False - - self.optimizer_g.zero_grad() - self.output = self.net_g(self.lq) - - l_g_total = 0 - loss_dict = OrderedDict() - if (current_iter % self.net_d_iters == 0 - and current_iter > self.net_d_init_iters): - # pixel loss - if self.cri_pix: - l_g_pix = self.cri_pix(self.output, self.gt) - l_g_total += l_g_pix - loss_dict['l_g_pix'] = l_g_pix - # perceptual loss - if self.cri_perceptual: - l_g_percep, l_g_style = self.cri_perceptual( - self.output, self.gt) - if l_g_percep is not None: - l_g_total += l_g_percep - loss_dict['l_g_percep'] = l_g_percep - if l_g_style is not None: - l_g_total += l_g_style - loss_dict['l_g_style'] = l_g_style - # gan loss - fake_g_pred = self.net_d(self.output) - l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) - l_g_total += l_g_gan - loss_dict['l_g_gan'] = l_g_gan - - l_g_total.backward() - self.optimizer_g.step() - - # optimize net_d - for p in self.net_d.parameters(): - p.requires_grad = True - - self.optimizer_d.zero_grad() - # real - real_d_pred = self.net_d(self.gt) - l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) 
- loss_dict['l_d_real'] = l_d_real - loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) - l_d_real.backward() - # fake - fake_d_pred = self.net_d(self.output.detach()) - l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) - loss_dict['l_d_fake'] = l_d_fake - loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) - l_d_fake.backward() - self.optimizer_d.step() - self.log_dict = self.reduce_loss_dict(loss_dict) +class VideoGANModel(SRGANModel, VideoBaseModel): + """Video GAN model. - def save(self, epoch, current_iter): - self.save_network(self.net_g, 'net_g', current_iter) - self.save_network(self.net_d, 'net_d', current_iter) - self.save_training_state(epoch, current_iter) + Use multiple inheritance. + It will first use the functions of SRGANModel: + init_training_settings + setup_optimizers + optimize_parameters + save + Then find functions in VideoBaseModel. + """ diff --git a/basicsr/test.py b/basicsr/test.py index 7bdae15..622df4e 100644 --- a/basicsr/test.py +++ b/basicsr/test.py @@ -1,46 +1,23 @@ -import argparse import logging -import random import torch -from mmcv.runner import get_dist_info, get_time_str, init_dist from os import path as osp from basicsr.data import create_dataloader, create_dataset from basicsr.models import create_model -from basicsr.utils import (get_env_info, get_root_logger, make_exp_dirs, - set_random_seed) -from basicsr.utils.options import dict2str, parse +from basicsr.train import parse_options +from basicsr.utils import (get_env_info, get_root_logger, get_time_str, + make_exp_dirs) +from basicsr.utils.options import dict2str def main(): - # options - parser = argparse.ArgumentParser() - parser.add_argument( - '-opt', type=str, required=True, help='Path to option YAML file.') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) - args = parser.parse_args() - opt = parse(args.opt, is_train=False) + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=False) - # distributed testing settings - if args.launcher == 'none': # non-distributed testing - opt['dist'] = False - print('Disable distributed testing.', flush=True) - else: - opt['dist'] = True - if args.launcher == 'slurm' and 'dist_params' in opt: - init_dist(args.launcher, **opt['dist_params']) - else: - init_dist(args.launcher) - - rank, world_size = get_dist_info() - opt['rank'] = rank - opt['world_size'] = world_size + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + # mkdir and initialize loggers make_exp_dirs(opt) log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") @@ -49,17 +26,6 @@ def main(): logger.info(get_env_info()) logger.info(dict2str(opt)) - # random seed - seed = opt['manual_seed'] - if seed is None: - seed = random.randint(1, 10000) - opt['manual_seed'] = seed - logger.info(f'Random seed: {seed}') - set_random_seed(seed + rank) - - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True - # create test dataset and dataloader test_loaders = [] for phase, dataset_opt in sorted(opt['datasets'].items()): @@ -70,7 +36,7 @@ def main(): num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, - seed=seed) + seed=opt['manual_seed']) logger.info( f"Number of test images in {dataset_opt['name']}: {len(test_set)}") test_loaders.append(test_loader) diff --git a/basicsr/train.py b/basicsr/train.py index 0d769c8..02a460f 100644 
--- a/basicsr/train.py +++ b/basicsr/train.py @@ -5,7 +5,6 @@ import random import time import torch -from mmcv.runner import get_dist_info, get_time_str, init_dist from os import path as osp from basicsr.data import create_dataloader, create_dataset @@ -13,13 +12,14 @@ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher from basicsr.models import create_model from basicsr.utils import (MessageLogger, check_resume, get_env_info, - get_root_logger, init_tb_logger, init_wandb_logger, - make_exp_dirs, mkdir_and_rename, set_random_seed) + get_root_logger, get_time_str, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, + set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist from basicsr.utils.options import dict2str, parse -def main(): - # options +def parse_options(is_train=True): parser = argparse.ArgumentParser() parser.add_argument( '-opt', type=str, required=True, help='Path to option YAML file.') @@ -30,12 +30,12 @@ def main(): help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() - opt = parse(args.opt, is_train=True) + opt = parse(args.opt, is_train=is_train) - # distributed training settings - if args.launcher == 'none': # non-distributed training + # distributed settings + if args.launcher == 'none': opt['dist'] = False - print('Disable distributed training.', flush=True) + print('Disable distributed.', flush=True) else: opt['dist'] = True if args.launcher == 'slurm' and 'dist_params' in opt: @@ -43,68 +43,55 @@ def main(): else: init_dist(args.launcher) - rank, world_size = get_dist_info() - opt['rank'] = rank - opt['world_size'] = world_size + opt['rank'], opt['world_size'] = get_dist_info() - # load resume states if exists - if opt['path'].get('resume_state'): - device_id = torch.cuda.current_device() - resume_state = torch.load( - opt['path']['resume_state'], - map_location=lambda storage, loc: storage.cuda(device_id)) - else: - resume_state = None + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) - # mkdir and loggers - if resume_state is None: - make_exp_dirs(opt) + return opt + + +def init_loggers(opt): log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log") logger = get_root_logger( logger_name='basicsr', log_level=logging.INFO, log_file=log_file) logger.info(get_env_info()) logger.info(dict2str(opt)) + # initialize tensorboard logger and wandb logger tb_logger = None if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: - log_dir = './tb_logger/' + opt['name'] - if resume_state is None and opt['rank'] == 0: - mkdir_and_rename(log_dir) - tb_logger = init_tb_logger(log_dir=log_dir) + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None) and ('debug' not in opt['name']): assert opt['logger'].get('use_tb_logger') is True, ( 'should turn on tensorboard when using wandb') init_wandb_logger(opt) + return logger, tb_logger - # random seed - seed = opt['manual_seed'] - if seed is None: - seed = random.randint(1, 10000) - opt['manual_seed'] = seed - logger.info(f'Random seed: {seed}') - set_random_seed(seed + rank) - - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True +def create_train_val_dataloader(opt, logger): # create train and val dataloaders train_loader, 
val_loader = None, None for phase, dataset_opt in opt['datasets'].items(): if phase == 'train': dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) train_set = create_dataset(dataset_opt) - train_sampler = EnlargedSampler(train_set, world_size, rank, - dataset_enlarge_ratio) + train_sampler = EnlargedSampler(train_set, opt['world_size'], + opt['rank'], dataset_enlarge_ratio) train_loader = create_dataloader( train_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=train_sampler, - seed=seed) + seed=opt['manual_seed']) num_iter_per_epoch = math.ceil( len(train_set) * dataset_enlarge_ratio / @@ -119,6 +106,7 @@ def main(): f'\n\tWorld size (gpu number): {opt["world_size"]}' f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase == 'val': val_set = create_dataset(dataset_opt) val_loader = create_dataloader( @@ -127,27 +115,57 @@ def main(): num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, - seed=seed) + seed=opt['manual_seed']) logger.info( f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') else: raise ValueError(f'Dataset phase {phase} is not recognized.') - assert train_loader is not None - # create model - if resume_state: - check_resume(opt, resume_state['iter']) # modify pretrain_model paths - model = create_model(opt) + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def main(): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=True) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ + 'name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result - # resume training - if resume_state: + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = create_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") start_epoch = resume_state['epoch'] current_iter = resume_state['iter'] - model.resume_training(resume_state) # handle optimizers and schedulers else: + model = create_model(opt) start_epoch = 0 current_iter = 0 diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index 95f7a50..2b91571 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -1,12 +1,31 @@ from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img from .logger import (MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger) -from .util import (ProgressBar, check_resume, crop_border, make_exp_dirs, - mkdir_and_rename, set_random_seed, tensor2img) +from .misc import (check_resume, get_time_str, make_exp_dirs, 
mkdir_and_rename, + scandir, set_random_seed, sizeof_fmt) __all__ = [ - 'FileClient', 'MessageLogger', 'get_root_logger', 'make_exp_dirs', - 'init_tb_logger', 'init_wandb_logger', 'set_random_seed', 'ProgressBar', - 'tensor2img', 'crop_border', 'check_resume', 'mkdir_and_rename', - 'get_env_info' + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt' ] diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py new file mode 100644 index 0000000..43cf4cd --- /dev/null +++ b/basicsr/utils/dist_util.py @@ -0,0 +1,83 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/basicsr/utils/download.py b/basicsr/utils/download_util.py similarity index 73% rename from basicsr/utils/download.py rename to basicsr/utils/download_util.py index e03516c..64a0016 100644 --- a/basicsr/utils/download.py +++ b/basicsr/utils/download_util.py @@ -1,7 +1,8 @@ import math import requests +from tqdm import tqdm -from basicsr.utils import ProgressBar +from .misc import sizeof_fmt def download_file_from_google_drive(file_id, save_path): @@ -49,7 +50,8 @@ def save_response_content(response, file_size=None, chunk_size=32768): if file_size is not None: - pbar = ProgressBar(math.ceil(file_size / chunk_size)) + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + readable_file_size = sizeof_fmt(file_size) else: pbar = None @@ -59,24 +61,10 @@ def save_response_content(response, for chunk in response.iter_content(chunk_size): downloaded_size += chunk_size if pbar is not None: - pbar.update(f'Downloading {sizeof_fmt(downloaded_size)} ' - f'/ {readable_file_size}') + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' + f'/ {readable_file_size}') if chunk: # filter out keep-alive new chunks f.write(chunk) - - -def sizeof_fmt(size, suffix='B'): - """Get human readable file size. - - Args: - size (int): File size. - suffix (str): Suffix. Default: 'B'. - - Return: - str: Formated file siz. - """ - for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: - if abs(size) < 1024.0: - return f'{size:3.1f} {unit}{suffix}' - size /= 1024.0 - return f'{size:3.1f} Y{suffix}' + if pbar is not None: + pbar.close() diff --git a/basicsr/utils/face_util.py b/basicsr/utils/face_util.py new file mode 100644 index 0000000..33fe178 --- /dev/null +++ b/basicsr/utils/face_util.py @@ -0,0 +1,217 @@ +import cv2 +import numpy as np +import os +import torch +from skimage import transform as trans + +from basicsr.utils import imwrite + +try: + import dlib +except ImportError: + print('Please install dlib before testing face restoration.' 
+ 'Reference: https://github.com/davisking/dlib') + + +class FaceRestorationHelper(object): + """Helper for the face restoration pipeline.""" + + def __init__(self, upscale_factor, face_size=512): + self.upscale_factor = upscale_factor + self.face_size = (face_size, face_size) + + # standard 5 landmarks for FFHQ faces with 1024 x 1024 + self.face_template = np.array([[686.77227723, 488.62376238], + [586.77227723, 493.59405941], + [337.91089109, 488.38613861], + [437.95049505, 493.51485149], + [513.58415842, 678.5049505]]) + self.face_template = self.face_template / (1024 // face_size) + # for estimation the 2D similarity transformation + self.similarity_trans = trans.SimilarityTransform() + + self.all_landmarks_5 = [] + self.all_landmarks_68 = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.save_png = True + + def init_dlib(self, detection_path, landmark5_path, landmark68_path): + """Initialize the dlib detectors and predictors.""" + self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) + self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) + self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) + + def free_dlib_gpu_memory(self): + del self.face_detector + del self.shape_predictor_5 + del self.shape_predictor_68 + + def read_input_image(self, img_path): + # self.input_img is Numpy array, (h, w, c) with RGB order + self.input_img = dlib.load_rgb_image(img_path) + + def detect_faces(self, + img_path, + upsample_num_times=1, + only_keep_largest=False): + """ + Args: + img_path (str): Image path. + upsample_num_times (int): Upsamples the image before running the + face detector + + Returns: + int: Number of detected faces. + """ + self.read_input_image(img_path) + det_faces = self.face_detector(self.input_img, upsample_num_times) + if len(det_faces) == 0: + print('No face detected. Try to increase upsample_num_times.') + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - + det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - + det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces + return len(self.det_faces) + + def get_face_landmarks_5(self): + for face in self.det_faces: + shape = self.shape_predictor_5(self.input_img, face.rect) + landmark = np.array([[part.x, part.y] for part in shape.parts()]) + self.all_landmarks_5.append(landmark) + return len(self.all_landmarks_5) + + def get_face_landmarks_68(self): + """Get 68 densemarks for cropped images. + + Should only have one face at most in the cropped image. + """ + num_detected_face = 0 + for idx, face in enumerate(self.cropped_faces): + # face detection + det_face = self.face_detector(face, 1) # TODO: can we remove it? + if len(det_face) == 0: + print(f'Cannot find faces in cropped image with index {idx}.') + self.all_landmarks_68.append(None) + else: + if len(det_face) > 1: + print('Detect several faces in the cropped face. Use the ' + ' largest one. 
Note that it will also cause overlap ' + 'during paste_faces_to_input_image.') + face_areas = [] + for i in range(len(det_face)): + face_area = (det_face[i].rect.right() - + det_face[i].rect.left()) * ( + det_face[i].rect.bottom() - + det_face[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + face_rect = det_face[largest_idx].rect + else: + face_rect = det_face[0].rect + shape = self.shape_predictor_68(face, face_rect) + landmark = np.array([[part.x, part.y] + for part in shape.parts()]) + self.all_landmarks_68.append(landmark) + num_detected_face += 1 + + return num_detected_face + + def warp_crop_faces(self, + save_cropped_path=None, + save_inverse_affine_path=None): + """Get affine matrix, warp and cropped faces. + + Also get inverse affine matrix for post-processing. + """ + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + self.similarity_trans.estimate(landmark, self.face_template) + affine_matrix = self.similarity_trans.params[0:2, :] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + cropped_face = cv2.warpAffine(self.input_img, affine_matrix, + self.face_size) + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path, ext = os.path.splitext(save_cropped_path) + if self.save_png: + save_path = f'{path}_{idx:02d}.png' + else: + save_path = f'{path}_{idx:02d}{ext}' + + imwrite( + cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) + + # get inverse affine matrix + self.similarity_trans.estimate(self.face_template, + landmark * self.upscale_factor) + inverse_affine = self.similarity_trans.params[0:2, :] + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + def add_restored_face(self, face): + self.restored_faces.append(face) + + def paste_faces_to_input_image(self, save_path): + # operate in the BGR order + input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) + h, w, _ = input_img.shape + h_up, w_up = h * self.upscale_factor, w * self.upscale_factor + # simply resize the background + upsample_img = cv2.resize(input_img, (w_up, h_up)) + assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( + 'length of restored_faces and affine_matrices are different.') + for restored_face, inverse_affine in zip(self.restored_faces, + self.inverse_affine_matrices): + inv_restored = cv2.warpAffine(restored_face, inverse_affine, + (w_up, h_up)) + mask = np.ones((*self.face_size, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, + np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), + np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode( + inv_mask_erosion, + np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, + (blur_size + 1, blur_size + 1), 0) + upsample_img = inv_soft_mask * inv_restored_remove_border + ( + 1 - inv_soft_mask) * upsample_img + if self.save_png: + save_path = 
save_path.replace('.jpg', + '.png').replace('.jpeg', '.png') + imwrite(upsample_img.astype(np.uint8), save_path) + + def clean_all(self): + self.all_landmarks_5 = [] + self.all_landmarks_68 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py index 1d8e5cf..066b22f 100644 --- a/basicsr/utils/file_client.py +++ b/basicsr/utils/file_client.py @@ -1,113 +1,183 @@ -from mmcv.fileio.file_client import (BaseStorageBackend, CephBackend, - HardDiskBackend, MemcachedBackend) - - -class LmdbBackend(BaseStorageBackend): - """Lmdb storage backend. - - Args: - db_paths (str | list[str]): Lmdb database paths. - client_keys (str | list[str]): Lmdb client keys. Default: 'default'. - readonly (bool, optional): Lmdb environment parameter. If True, - disallow any write operations. Default: True. - lock (bool, optional): Lmdb environment parameter. If False, when - concurrent access occurs, do not lock the database. Default: False. - readahead (bool, optional): Lmdb environment parameter. If False, - disable the OS filesystem readahead mechanism, which may improve - random read performance when a database is larger than RAM. - Default: False. - - Attributes: - db_paths (list): Lmdb database path. - _client (list): A list of several lmdb envs. - """ - - def __init__(self, - db_paths, - client_keys='default', - readonly=True, - lock=False, - readahead=False, - **kwargs): - try: - import lmdb - except ImportError: - raise ImportError('Please install lmdb to enable LmdbBackend.') - - if isinstance(client_keys, str): - client_keys = [client_keys] - - if isinstance(db_paths, list): - self.db_paths = [str(v) for v in db_paths] - elif isinstance(db_paths, str): - self.db_paths = [str(db_paths)] - assert len(client_keys) == len(self.db_paths), ( - 'client_keys and db_paths should have the same length, ' - f'but received {len(client_keys)} and {len(self.db_paths)}.') - - self._client = {} - for client, path in zip(client_keys, self.db_paths): - self._client[client] = lmdb.open( - path, - readonly=readonly, - lock=lock, - readahead=readahead, - **kwargs) - - def get(self, filepath, client_key): - """Get values according to the filepath from one lmdb named client_key. - - Args: - filepath (str | obj:`Path`): Here, filepath is the lmdb key. - client_key (str): Used for distinguishing differnet lmdb envs. - """ - filepath = str(filepath) - assert client_key in self._client, (f'client_key {client_key} is not ' - 'in lmdb clients.') - client = self._client[client_key] - with client.begin(write=False) as txn: - value_buf = txn.get(filepath.encode('ascii')) - return value_buf - - def get_text(self, filepath): - raise NotImplementedError - - -class FileClient(object): - """A general file client to access files in different backend. - - The client loads a file or text in a specified backend from its path - and return it as a binary file. it can also register other backend - accessor with a given name and backend class. - - Attributes: - backend (str): The storage backend type. Options are "disk", "ceph", - "memcached" and "lmdb". - client (:obj:`BaseStorageBackend`): The backend object. - """ - - _backends = { - 'disk': HardDiskBackend, - 'ceph': CephBackend, - 'memcached': MemcachedBackend, - 'lmdb': LmdbBackend, - } - - def __init__(self, backend='disk', **kwargs): - if backend not in self._backends: - raise ValueError( - f'Backend {backend} is not supported. 
Currently supported ones' - f' are {list(self._backends.keys())}') - self.backend = backend - self.client = self._backends[backend](**kwargs) - - def get(self, filepath, client_key='default'): - # client_key is used only for lmdb, where different fileclients have - # different lmdb environments. - if self.backend == 'lmdb': - return self.client.get(filepath, client_key) - else: - return self.client.get(filepath) - - def get_text(self, filepath): - return self.client.get_text(filepath) +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError( + 'Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, + self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. 
+ """ + + def __init__(self, + db_paths, + client_keys='default', + readonly=True, + lock=False, + readahead=False, + **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ( + 'client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open( + path, + readonly=readonly, + lock=lock, + readahead=readahead, + **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' + 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError( + f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py new file mode 100644 index 0000000..2b052cc --- /dev/null +++ b/basicsr/utils/flow_util.py @@ -0,0 +1,180 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 +import cv2 +import numpy as np +import os + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. 
+ + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, ' + f'its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, ' + 'header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + os.makedirs(filename, exist_ok=True) + cv2.imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [ + quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] + ] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. 
+ """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum( + np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - + min_val) / levels + min_val + + return dequantized_arr diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py new file mode 100644 index 0000000..152be01 --- /dev/null +++ b/basicsr/utils/img_util.py @@ -0,0 +1,165 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. 
+
+    Returns:
+        (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+        shape (H x W). The channel order is BGR.
+    """
+    if not (torch.is_tensor(tensor) or
+            (isinstance(tensor, list)
+             and all(torch.is_tensor(t) for t in tensor))):
+        raise TypeError(
+            f'tensor or list of tensors expected, got {type(tensor)}')
+
+    if torch.is_tensor(tensor):
+        tensor = [tensor]
+    result = []
+    for _tensor in tensor:
+        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+        n_dim = _tensor.dim()
+        if n_dim == 4:
+            img_np = make_grid(
+                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
+                normalize=False).numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if rgb2bgr:
+                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 3:
+            img_np = _tensor.numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if img_np.shape[2] == 1:  # gray image
+                img_np = np.squeeze(img_np, axis=2)
+            else:
+                if rgb2bgr:
+                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 2:
+            img_np = _tensor.numpy()
+        else:
+            raise TypeError('Only support 4D, 3D or 2D tensor. '
+                            f'But received with dimension: {n_dim}')
+        if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+            img_np = (img_np * 255.0).round()
+        img_np = img_np.astype(out_type)
+        result.append(img_np)
+    if len(result) == 1:
+        result = result[0]
+    return result
+
+
+def imfrombytes(content, flag='color', float32=False):
+    """Read an image from bytes.
+
+    Args:
+        content (bytes): Image bytes got from files or other streams.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, will also
+            normalize to [0, 1]. Default: False.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+    img_np = np.frombuffer(content, np.uint8)
+    imread_flags = {
+        'color': cv2.IMREAD_COLOR,
+        'grayscale': cv2.IMREAD_GRAYSCALE,
+        'unchanged': cv2.IMREAD_UNCHANGED
+    }
+    img = cv2.imdecode(img_np, imread_flags[flag])
+    if float32:
+        img = img.astype(np.float32) / 255.
+    return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = os.path.abspath(os.path.dirname(file_path))
+        os.makedirs(dir_name, exist_ok=True)
+    return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+    """Crop borders of images.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+    Returns:
+        list[ndarray]: Cropped images.
+    """
+    if crop_border == 0:
+        return imgs
+    else:
+        if isinstance(imgs, list):
+            return [
+                v[crop_border:-crop_border, crop_border:-crop_border, ...]
+                for v in imgs
+            ]
+        else:
+            return imgs[crop_border:-crop_border, crop_border:-crop_border,
+                        ...]
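
The `FileClient` and the new `img_util.py` helpers above are designed to be chained in the data pipelines. Below is a minimal round-trip sketch, not part of this patch: the file paths are hypothetical, and the `FileClient` import path is assumed to be `basicsr/utils/file_client.py` (only the tail of that file is shown in this diff).

```python
# Hedged usage sketch for the IO / conversion utilities added in this patch.
# Assumptions: basicsr is installed, 'datasets/example.png' exists, and
# FileClient lives in basicsr/utils/file_client.py (not shown in full here).
import numpy as np

from basicsr.utils.file_client import FileClient
from basicsr.utils.img_util import (imfrombytes, img2tensor, imwrite,
                                    tensor2img)

# Read raw bytes through the generic client (disk backend here; for lmdb,
# pass backend='lmdb' with db_paths/client_keys and call get(key, client_key)).
file_client = FileClient(backend='disk')
img_bytes = file_client.get('datasets/example.png')

# Decode bytes -> float32 BGR ndarray in [0, 1].
img = imfrombytes(img_bytes, float32=True)

# BGR (h, w, c) ndarray -> RGB (c, h, w) float tensor, as the datasets expect.
tensor = img2tensor(img, bgr2rgb=True, float32=True)

# Tensor -> uint8 BGR ndarray; imwrite creates the parent folder if needed.
restored = tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))
imwrite(restored, 'tmp/roundtrip.png')
```

If these helpers are also re-exported from `basicsr.utils`, the two explicit module imports could collapse into a single `from basicsr.utils import ...`; that re-export is not visible in this hunk, so it is only an assumption here.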
diff --git a/basicsr/utils/lmdb.py b/basicsr/utils/lmdb_util.py similarity index 93% rename from basicsr/utils/lmdb.py rename to basicsr/utils/lmdb_util.py index 8e3e99d..a81278f 100644 --- a/basicsr/utils/lmdb.py +++ b/basicsr/utils/lmdb_util.py @@ -1,11 +1,9 @@ import cv2 import lmdb -import mmcv import sys from multiprocessing import Pool from os import path as osp - -from .util import ProgressBar +from tqdm import tqdm def make_lmdb_from_imgs(data_path, @@ -76,12 +74,13 @@ def make_lmdb_from_imgs(data_path, dataset = {} # use dict to keep the order for multiprocessing shapes = {} print(f'Read images with multiprocessing, #thread: {n_thread} ...') - pbar = ProgressBar(len(img_path_list)) + pbar = tqdm(total=len(img_path_list), unit='image') def callback(arg): """get the image data and update pbar.""" key, dataset[key], shapes[key] = arg - pbar.update('Reading {}'.format(key)) + pbar.update(1) + pbar.set_description(f'Read {key}') pool = Pool(n_thread) for path, key in zip(img_path_list, keys): @@ -91,13 +90,14 @@ def callback(arg): callback=callback) pool.close() pool.join() + pbar.close() print(f'Finish reading {len(img_path_list)} images.') # create lmdb environment if map_size is None: # obtain data size for one image - img = mmcv.imread( - osp.join(data_path, img_path_list[0]), flag='unchanged') + img = cv2.imread( + osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) _, img_byte = cv2.imencode( '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) data_size_per_img = img_byte.nbytes @@ -108,11 +108,12 @@ def callback(arg): env = lmdb.open(lmdb_path, map_size=map_size) # write data to lmdb - pbar = ProgressBar(len(img_path_list)) + pbar = tqdm(total=len(img_path_list), unit='chunk') txn = env.begin(write=True) txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') for idx, (path, key) in enumerate(zip(img_path_list, keys)): - pbar.update(f'Write {key}') + pbar.update(1) + pbar.set_description(f'Write {key}') key_byte = key.encode('ascii') if multiprocessing_read: img_byte = dataset[key] @@ -128,6 +129,7 @@ def callback(arg): if idx % batch == 0: txn.commit() txn = env.begin(write=True) + pbar.close() txn.commit() env.close() txt_file.close() @@ -148,7 +150,7 @@ def read_img_worker(path, key, compress_level): tuple[int]: Image shape. """ - img = mmcv.imread(path, flag='unchanged') + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) if img.ndim == 2: h, w = img.shape c = 1 diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py index 6aee50b..48671ed 100644 --- a/basicsr/utils/logger.py +++ b/basicsr/utils/logger.py @@ -1,7 +1,8 @@ import datetime import logging import time -from mmcv.runner import get_dist_info, master_only + +from .dist_util import get_dist_info, master_only class MessageLogger(): @@ -153,7 +154,6 @@ def get_env_info(): Currently, only log the software version. 
""" - import mmcv import torch import torchvision @@ -173,6 +173,5 @@ def get_env_info(): msg += ('\nVersion Information: ' f'\n\tBasicSR: {__version__}' f'\n\tPyTorch: {torch.__version__}' - f'\n\tTorchVision: {torchvision.__version__}' - f'\n\tMMCV: {mmcv.__version__}') + f'\n\tTorchVision: {torchvision.__version__}') return msg diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000..cd96a2c --- /dev/null +++ b/basicsr/utils/matlab_functions.py @@ -0,0 +1,361 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, + kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace( + 0, p - 1, p).view(1, p).expand(out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices( + in_h, out_h, scale, kernel, kernel_width, antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices( + in_w, out_w, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose( + 0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, + idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = 
out_2.numpy().transpose(1, 2, 0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. 
+ + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' + f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py new file mode 100644 index 0000000..200527c --- /dev/null +++ b/basicsr/utils/misc.py @@ -0,0 +1,139 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only +from .logger import get_root_logger + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. 
If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' not in key) and ('pretrain_network' + not in key) and ('resume' + not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir( + entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning( + 'pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or ( + basename not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join( + opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. 
+ """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index f7717f0..3670d17 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -57,9 +57,10 @@ def parse(opt_path, is_train=True): dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) # paths - for key, path in opt['path'].items(): - if path and key != 'strict_load': - opt['path'][key] = osp.expanduser(path) + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key + or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) opt['path']['root'] = osp.abspath( osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) if is_train: diff --git a/basicsr/utils/util.py b/basicsr/utils/util.py deleted file mode 100644 index 7419e7b..0000000 --- a/basicsr/utils/util.py +++ /dev/null @@ -1,218 +0,0 @@ -import math -import mmcv -import numpy as np -import os -import random -import sys -import time -import torch -from mmcv.runner import get_time_str, master_only -from os import path as osp -from shutil import get_terminal_size -from torchvision.utils import make_grid - -from basicsr.utils import get_root_logger - - -def check_resume(opt, resume_iter): - """Check resume states and pretrain_model paths. - - Args: - opt (dict): Options. - resume_iter (int): Resume iteration. - """ - logger = get_root_logger() - if opt['path']['resume_state']: - # ignore pretrained model paths - if opt['path'].get('pretrain_model_g') is not None or opt['path'].get( - 'pretrain_model_d') is not None: - logger.warning( - 'pretrain_model path will be ignored during resuming.') - - # set pretrained model paths - opt['path']['pretrain_model_g'] = osp.join(opt['path']['models'], - f'net_g_{resume_iter}.pth') - logger.info( - f"Set pretrain_model_g to {opt['path']['pretrain_model_g']}") - - opt['path']['pretrain_model_d'] = osp.join(opt['path']['models'], - f'net_d_{resume_iter}.pth') - logger.info( - f"Set pretrain_model_d to {opt['path']['pretrain_model_d']}") - - -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - mmcv.mkdir_or_exist(path) - - -@master_only -def make_exp_dirs(opt): - """Make dirs for experiments.""" - path_opt = opt['path'].copy() - if opt['is_train']: - mkdir_and_rename(path_opt.pop('experiments_root')) - else: - mkdir_and_rename(path_opt.pop('results_root')) - path_opt.pop('strict_load') - for key, path in path_opt.items(): - if 'pretrain_model' not in key and 'resume' not in key: - mmcv.mkdir_or_exist(path) - - -def set_random_seed(seed): - """Set random seeds.""" - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def crop_border(imgs, crop_border): - """Crop borders of images. - - Args: - imgs (list[ndarray] | ndarray): Images with shape (h, w, c). - crop_border (int): Crop border for each end of height and weight. - - Returns: - list[ndarray]: Cropped images. - """ - if crop_border == 0: - return imgs - else: - if isinstance(imgs, list): - return [ - v[crop_border:-crop_border, crop_border:-crop_border, ...] 
- for v in imgs - ] - else: - return imgs[crop_border:-crop_border, crop_border:-crop_border, - ...] - - -def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - """Convert torch Tensors into image numpy arrays. - - After clamping to [min, max], values will be normalized to [0, 1]. - - Args: - tensor (Tensor or list[Tensor]): Accept shapes: - 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); - 2) 3D Tensor of shape (3/1 x H x W); - 3) 2D Tensor of shape (H x W). - Tensor channel should be in RGB order. - out_type (numpy type): output types. If ``np.uint8``, transform outputs - to uint8 type with range [0, 255]; otherwise, float type with - range [0, 1]. Default: ``np.uint8``. - min_max (tuple[int]): min and max values for clamp. - - Returns: - (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of - shape (H x W). The channel order is BGR. - """ - if not (torch.is_tensor(tensor) or - (isinstance(tensor, list) - and all(torch.is_tensor(t) for t in tensor))): - raise TypeError( - f'tensor or list of tensors expected, got {type(tensor)}') - - if torch.is_tensor(tensor): - tensor = [tensor] - result = [] - for _tensor in tensor: - _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) - _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) - - n_dim = _tensor.dim() - if n_dim == 4: - img_np = make_grid( - _tensor, nrow=int(math.sqrt(_tensor.size(0))), - normalize=False).numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], - (1, 2, 0)) # HWC, BGR - elif n_dim == 3: - img_np = _tensor.numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], - (1, 2, 0)) # HWC, BGR - elif n_dim == 2: - img_np = _tensor.numpy() - else: - raise TypeError('Only support 4D, 3D or 2D tensor. ' - f'But received with dimension: {n_dim}') - if out_type == np.uint8: - # Unlike MATLAB, numpy.unit8() WILL NOT round by default. - img_np = (img_np * 255.0).round() - img_np = img_np.astype(out_type) - result.append(img_np) - if len(result) == 1: - result = result[0] - return result - - -class ProgressBar(object): - """A progress bar that can print the progress. 
- - Modified from: - https://github.com/hellock/cvbase/blob/master/cvbase/progress.py - """ - - def __init__(self, task_num=0, bar_width=50, start=True): - self.task_num = task_num - max_bar_width = self._get_max_bar_width() - self.bar_width = ( - bar_width if bar_width <= max_bar_width else max_bar_width) - self.completed = 0 - if start: - self.start() - - def _get_max_bar_width(self): - terminal_width, _ = get_terminal_size() - max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) - if max_bar_width < 10: - print(f'terminal width is too small ({terminal_width}), ' - 'please consider widen the terminal for better ' - 'progressbar visualization') - max_bar_width = 10 - return max_bar_width - - def start(self): - if self.task_num > 0: - sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, " - f'elapsed: 0s, ETA:\nStart...\n') - else: - sys.stdout.write('completed: 0, elapsed: 0s') - sys.stdout.flush() - self.start_time = time.time() - - def update(self, msg='In progress...'): - self.completed += 1 - elapsed = time.time() - self.start_time + 1e-8 - fps = self.completed / elapsed - if self.task_num > 0: - percentage = self.completed / float(self.task_num) - eta = int(elapsed * (1 - percentage) / percentage + 0.5) - mark_width = int(self.bar_width * percentage) - bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) - sys.stdout.write('\033[2F') # cursor up 2 lines - sys.stdout.write( - '\033[J' - ) # clean the output (remove extra chars since last display) - sys.stdout.write( - f'[{bar_chars}] {self.completed}/{self.task_num}, ' - f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' - f'ETA: {eta:5}s\n{msg}\n') - else: - sys.stdout.write( - f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' - f' {fps:.1f} tasks/s') - sys.stdout.flush() diff --git a/colab/README.md b/colab/README.md new file mode 100644 index 0000000..0e83739 --- /dev/null +++ b/colab/README.md @@ -0,0 +1,13 @@ +# Colab + +google colab logo + +To maintain a small size of BasicSR repo, we do not include the original colab notebooks in this repo, but provide links to the google colab. + +| Face Restoration| | +| :--- | :---: | +|DFDNet | [BasicSR_inference_DFDNet.ipynb](https://colab.research.google.com/drive/1RoNDeipp9yPjI3EbpEbUhn66k5Uzg4n8?usp=sharing)| +| **Super-Resolution**| | +|ESRGAN |[BasicSR_inference_ESRGAN.ipynb](https://colab.research.google.com/drive/1JQScYICvEC3VqaabLu-lxvq9h7kSV1ML?usp=sharing)| +| **Deblurring**| | +| **Denoise**| | diff --git a/docs/Config.md b/docs/Config.md index f2a0775..5a3b04f 100644 --- a/docs/Config.md +++ b/docs/Config.md @@ -127,11 +127,11 @@ network_g: ######################################################### path: # Path for pretrained models, usually end with pth - pretrain_model_g: ~ + pretrain_network_g: ~ # Whether to load pretrained models strictly, that is the corresponding parameter names should be the same - strict_load: true + strict_load_g: true # Path for resume state. 
Usually in the `experiments/exp_name/training_states` folder - # This argument will over-write the `pretrain_model_g` + # This argument will over-write the `pretrain_network_g` resume_state: ~ @@ -302,9 +302,9 @@ network_g: ################################################# path: ## Path for pretrained models, usually end with pth - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth # Whether to load pretrained models strictly, that is the corresponding parameter names should be the same - strict_load: true + strict_load_g: true ########################################################## # The following are validation settings (Also for testing) diff --git a/docs/Config_CN.md b/docs/Config_CN.md index 6fa159d..6517110 100644 --- a/docs/Config_CN.md +++ b/docs/Config_CN.md @@ -126,11 +126,11 @@ network_g: ###################################### path: # 预训练模型的路径, 需要以pth结尾的模型 - pretrain_model_g: ~ + pretrain_network_g: ~ # 加载预训练模型的时候, 是否需要网络参数的名称严格对应 - strict_load: true + strict_load_g: true # 重启训练的状态路径, 一般在`experiments/exp_name/training_states`目录下 - # 这个设置了, 会覆盖 pretrain_model_g 的设定 + # 这个设置了, 会覆盖 pretrain_network_g 的设定 resume_state: ~ @@ -299,9 +299,9 @@ network_g: ############################# path: # 预训练模型的路径, 需要以pth结尾的模型 - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth # 加载预训练模型的时候, 是否需要网络参数的名称严格对应 - strict_load: true + strict_load_g: true ################################## # 以下为Validation (也是测试)的设置 diff --git a/docs/DatasetPreparation.md b/docs/DatasetPreparation.md index b579f58..207df31 100644 --- a/docs/DatasetPreparation.md +++ b/docs/DatasetPreparation.md @@ -24,7 +24,7 @@ At present, there are three types of data storage formats supported: 1. Store in `hard disk` directly in the format of images / video frames. 1. Make [LMDB](https://lmdb.readthedocs.io/en/release/), which could accelerate the IO and decompression speed during training. -1. [memcached](https://memcached.org/) or [CEPH](https://ceph.io/) are also supported, if they are installed (usually on clusters). +1. [memcached](https://memcached.org/) is also supported, if they are installed (usually on clusters). #### How to Use @@ -115,7 +115,7 @@ For convenience, the binary content stored in LMDB dataset is encoded image by c **How to Make LMDB** We provide a script to make LMDB. Before running the script, we need to modify the corresponding parameters accordingly. At present, we support DIV2K, REDS and Vimeo90K datasets; other datasets can also be made in a similar way.
- `python scripts/create_lmdb.py` + `python scripts/data_preparation/create_lmdb.py` #### Data Pre-fetcher @@ -155,17 +155,17 @@ It is recommended to symlink the dataset root to `datasets` with the command `ln 1. Download the datasets from the [official DIV2K website](https://data.vision.ee.ethz.ch/cvl/DIV2K/).
1. Crop to sub-images: DIV2K has 2K resolution (e.g., 2048 × 1080) images but the training patches are usually small (e.g., 128x128 or 192x192). So there is a waste if reading the whole image but only using a very small part of it. In order to accelerate the IO speed during training, we crop the 2K resolution images to sub-images (here, we crop to 480x480 sub-images).
Note that the size of sub-images is different from the training patch size (`gt_size`) defined in the config file. Specifically, the cropped sub-images with 480x480 are stored. The dataloader will further randomly crop the sub-images to `GT_size x GT_size` patches for training.
- Run the script [extract_subimages.py](../scripts/extract_subimages.py): + Run the script [extract_subimages.py](../scripts/data_preparation/extract_subimages.py): ```python - python scripts/extract_subimages.py + python scripts/data_preparation/extract_subimages.py ``` Remember to modify the paths and configurations if you have different settings. -1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly. +1. [Optional] Create LMDB files. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_div2k` function and remember to modify the paths and configurations accordingly. 1. Test the dataloader with the script `tests/test_paired_image_dataset.py`. Remember to modify the paths and configurations accordingly. -1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/generate_meta_info.py` to generate the meta_info_file. +1. [Optional] If you want to use meta_info_file, you may need to run `python scripts/data_preparation/generate_meta_info.py` to generate the meta_info_file. ### Common Image SR Datasets @@ -182,7 +182,7 @@ We provide a list of common image super-resolution datasets. Classical SR Training T91 91 images for training - Google Drive / Baidu Drive + Google Drive / Baidu Drive BSDS200 @@ -277,8 +277,8 @@ All the left clips are used for training. Note that it it not required to explic **Preparation Steps** 1. Download the datasets from the [official website](https://seungjunnah.github.io/Datasets/reds.html). -1. Regroup the training and validation datasets: `python scripts/regroup_reds_dataset.py` -1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly. +1. Regroup the training and validation datasets: `python scripts/data_preparation/regroup_reds_dataset.py` +1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_reds` function and remember to modify the paths and configurations accordingly. 1. Test the dataloader with the script `tests/test_reds_dataset.py`. Remember to modify the paths and configurations accordingly. @@ -289,7 +289,7 @@ Remember to modify the paths and configurations accordingly. 1. Download the dataset: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip).This is the Ground-Truth (GT). There is a `sep_trainlist.txt` file listing the training samples in the download zip file. 1. Generate the low-resolution images (TODO) The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images. -1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). `python scripts/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly. +1. [Optional] Make LMDB files when necessary. Please refer to [LMDB Description](#LMDB-Description). 
`python scripts/data_preparation/create_lmdb.py`. Use the `create_lmdb_for_vimeo90k` function and remember to modify the paths and configurations accordingly. 1. Test the dataloader with the script `tests/test_vimeo90k_dataset.py`. Remember to modify the paths and configurations accordingly. @@ -303,5 +303,5 @@ Training dataset: [FFHQ](https://github.com/NVlabs/ffhq-dataset). 1. Extract tfrecords to images or LMDBs. (TensorFlow is required to read tfrecords). For each resolution, we will create images folder or LMDB files separately. ```bash - python scripts/extract_images_from_tfrecords.py + python scripts/data_preparation/extract_images_from_tfrecords.py ``` diff --git a/docs/DatasetPreparation_CN.md b/docs/DatasetPreparation_CN.md index 7600256..b3e90a0 100644 --- a/docs/DatasetPreparation_CN.md +++ b/docs/DatasetPreparation_CN.md @@ -24,7 +24,7 @@ 1. 直接以图像/视频帧的格式存放在硬盘 2. 制作成 [LMDB](https://lmdb.readthedocs.io/en/release/). 训练数据使用这种形式, 一般会加快读取速度. -3. 若是支持 [Memcached](https://memcached.org/) 或 [Ceph](https://ceph.io/), 则可以使用. 它们一般应用在集群上. +3. 若是支持 [Memcached](https://memcached.org/), 则可以使用. 它们一般应用在集群上. #### 如何使用 @@ -116,7 +116,7 @@ DIV2K_train_HR_sub.lmdb **如何制作** 我们提供了脚本来制作. 在运行脚本前, 需要根据需求修改相应的参数. 目前支持 DIV2K, REDS 和 Vimeo90K 数据集; 其他数据集可仿照进行制作.
- `python scripts/create_lmdb.py` + `python scripts/data_preparation/create_lmdb.py` #### 预读取数据 @@ -155,17 +155,17 @@ DIV2K 数据集被广泛使用在图像复原的任务中. 1. 从[官网](https://data.vision.ee.ethz.ch/cvl/DIV2K)下载数据. 1. Crop to sub-images: 因为 DIV2K 数据集是 2K 分辨率的 (比如: 2048x1080), 而我们在训练的时候往往并不要那么大 (常见的是 128x128 或者 192x192 的训练patch). 因此我们可以先把2K的图片裁剪成有overlap的 480x480 的子图像块. 然后再由 dataloader 从这个 480x480 的子图像块中随机crop出 128x128 或者 192x192 的训练patch.
- 运行脚本 [extract_subimages.py](../scripts/extract_subimages.py): + 运行脚本 [extract_subimages.py](../scripts/data_preparation/extract_subimages.py): ```python - python scripts/extract_subimages.py + python scripts/data_preparation/extract_subimages.py ``` 使用之前可能需要修改文件里面的路径和配置参数. **注意**: sub-image 的尺寸和训练patch的尺寸 (`gt_size`) 是不同的. 我们先把2K分辨率的图像 crop 成 sub-images (往往是 480x480), 然后存储起来. 在训练的时候, dataloader会读取这些sub-images, 然后进一步随机裁剪成 `gt_size` x `gt_size`的大小. -1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_div2k`函数, 并需要修改函数相应的配置和路径. +1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_div2k`函数, 并需要修改函数相应的配置和路径. 1. 测试: `tests/test_paired_image_dataset.py`, 注意修改函数相应的配置和路径. -1. [可选] 若需要使用 meta_info_file, 运行 `python scripts/generate_meta_info.py` 来生成 meta_info_file. +1. [可选] 若需要使用 meta_info_file, 运行 `python scripts/data_preparation/generate_meta_info.py` 来生成 meta_info_file. ### 其他常见图像超分数据集 @@ -182,7 +182,7 @@ DIV2K 数据集被广泛使用在图像复原的任务中. Classical SR Training T91 91 images for training - Google Drive / Baidu Drive + Google Drive / Baidu Drive BSDS200 @@ -277,8 +277,8 @@ DIV2K 数据集被广泛使用在图像复原的任务中. **数据准备步骤** 1. 从[官网](https://seungjunnah.github.io/Datasets/reds.html)下载数据 -1. 整合 training 和 validation 数据: `python scripts/regroup_reds_dataset.py` -1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_reds`函数, 并需要修改函数相应的配置和路径. +1. 整合 training 和 validation 数据: `python scripts/data_preparation/regroup_reds_dataset.py` +1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_reds`函数, 并需要修改函数相应的配置和路径. 1. 测试: `python tests/test_reds_dataset.py`, 注意修改函数相应的配置和路径. ### Vimeo90K @@ -290,7 +290,7 @@ DIV2K 数据集被广泛使用在图像复原的任务中. 1. 下载数据: [`Septuplets dataset --> The original training + test set (82GB)`](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip). 这些是Ground-Truth. 里面有`sep_trainlist.txt`文件来区分训练数据. 1. 生成低分辨率图片. (TODO) The low-resolution images in the Vimeo90K test dataset are generated with the MATLAB bicubic downsampling kernel. Use the script `data_scripts/generate_LR_Vimeo90K.m` (run in MATLAB) to generate the low-resolution images. -1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/create_lmdb.py`, 注意选择`create_lmdb_for_vimeo90k`函数, 并需要修改函数相应的配置和路径. +1. [可选] 若需要使用 LMDB, 则需要制作 LMDB, 参考 [LMDB具体说明](#LMDB具体说明). `python scripts/data_preparation/create_lmdb.py`, 注意选择`create_lmdb_for_vimeo90k`函数, 并需要修改函数相应的配置和路径. 1. 测试: `python tests/test_vimeo90k_dataset.py`, 注意修改函数相应的配置和路径. ## StyleGAN2 @@ -303,5 +303,5 @@ The low-resolution images in the Vimeo90K test dataset are generated with the MA 1. 从 tfrecords 提取到*图片*或者*LMDB*. (需要安装 TensorFlow 来读取 tfrecords). 我们对每一个分辨率的人脸都单独创建文件夹或者LMDB文件. ```bash - python scripts/extract_images_from_tfrecords.py + python scripts/data_preparation/extract_images_from_tfrecords.py ``` diff --git a/docs/DesignConvention.md b/docs/DesignConvention.md index 35ee55f..10d737a 100644 --- a/docs/DesignConvention.md +++ b/docs/DesignConvention.md @@ -34,7 +34,7 @@ Specifically, we implement it through `importlib` and `getattr`. 
Taking the data # scan all the files under the data folder with '_dataset' in file names data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py') ] # import all the dataset modules diff --git a/docs/DesignConvention_CN.md b/docs/DesignConvention_CN.md index d3c16d3..536d6a6 100644 --- a/docs/DesignConvention_CN.md +++ b/docs/DesignConvention_CN.md @@ -36,7 +36,7 @@ # scan all the files under the data folder with '_dataset' in file names data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [ - osp.splitext(osp.basename(v))[0] for v in mmcv.scandir(data_folder) + osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py') ] # import all the dataset modules diff --git a/docs/HOWTOs.md b/docs/HOWTOs.md index a2d6433..ddda95f 100644 --- a/docs/HOWTOs.md +++ b/docs/HOWTOs.md @@ -8,23 +8,23 @@ 1. Download FFHQ dataset. Recommend to download the tfrecords files from [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset). 1. Extract tfrecords to images or LMDBs (TensorFlow is required to read tfrecords): - > python scripts/extract_images_from_tfrecords.py + > python scripts/data_preparation/extract_images_from_tfrecords.py 1. Modify the config file in `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml` 1. Train with distributed training. More training commands are in [TrainTest.md](TrainTest.md). > python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ_800k.yml --launcher pytorch -## How to test StyleGAN2 +## How to inference StyleGAN2 1. Download pre-trained models from **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) to the `experiments/pretrained_models` folder. 1. Test. - > python tests/test_stylegan2.py + > python inference/inference_stylegan2.py 1. The results are in the `samples` folder. -## How to test DFDNet +## How to inference DFDNet 1. Install [dlib](http://dlib.net/), because DFDNet uses dlib to do face recognition and landmark detection. [Installation reference](https://github.com/davisking/dlib). 1. Clone dlib repo: `git clone git@github.com:davisking/dlib.git` @@ -43,6 +43,6 @@ 4. Prepare the testing dataset in the `datasets`, for example, we put images in the `datasets/TestWhole` folder. 5. Test. - > python tests/test_face_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole + > python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole 6. The results are in the `results/DFDNet` folder. diff --git a/docs/HOWTOs_CN.md b/docs/HOWTOs_CN.md index aad7f25..df2ab25 100644 --- a/docs/HOWTOs_CN.md +++ b/docs/HOWTOs_CN.md @@ -8,7 +8,7 @@ 1. 下载 FFHQ 数据集. 推荐从 [NVlabs/ffhq-dataset](https://github.com/NVlabs/ffhq-dataset) 下载 tfrecords 文件. 1. 从tfrecords 提取到*图片*或者*LMDB*. (需要安装 TensorFlow 来读取 tfrecords). - > python scripts/extract_images_from_tfrecords.py + > python scripts/data_preparation/extract_images_from_tfrecords.py 1. 修改配置文件 `options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml` 1. 使用分布式训练. 更多训练命令: [TrainTest_CN.md](TrainTest_CN.md) @@ -20,7 +20,7 @@ 1. 
从 **ModelZoo** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) 下载预训练模型到 `experiments/pretrained_models` 文件夹. 1. 测试. - > python tests/test_stylegan2.py + > python inference/inference_stylegan2.py 1. 结果在 `samples` 文件夹 @@ -43,6 +43,6 @@ 4. 准备测试图片到 `datasets`, 比如说我们把测试图片放在 `datasets/TestWhole` 文件夹. 5. 测试. - > python tests/test_face_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole + > python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole 6. 结果在 `results/DFDNet` 文件夹. diff --git a/docs/Metrics.md b/docs/Metrics.md new file mode 100644 index 0000000..c4f0cb1 --- /dev/null +++ b/docs/Metrics.md @@ -0,0 +1,35 @@ +# Metrics + +[English](Metrics.md) **|** [简体中文](Metrics_CN.md) + +## PSNR and SSIM + +## NIQE + +## FID + +> FID measures the similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. +> FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. + +References + +- https://github.com/mseitzer/pytorch-fid +- [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500) +- [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337) + +### Pre-calculated FFHQ inception feature statistics + +Usually, we put the downloaded inception feature statistics in `basicsr/metrics`. + +:arrow_double_down: Google Drive: [metrics data](https://drive.google.com/drive/folders/13cWIQyHX3iNmZRJ5v8v3kdyrT9RBTAi9?usp=sharing) +:arrow_double_down: 百度网盘: [评价指标数据](https://pan.baidu.com/s/10mMKXSEgrC5y7m63W5vbMQ)
+ +| File Name | Dataset | Image Shape | Sample Numbers| +| :------------- | :----------:|:----------:|:----------:| +| inception_FFHQ_256-0948f50d.pth | FFHQ | 256 x 256 | 50,000 | +| inception_FFHQ_512-f7b384ab.pth | FFHQ | 512 x 512 | 50,000 | +| inception_FFHQ_1024-75f195dc.pth | FFHQ | 1024 x 1024 | 50,000 | +| inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth | FFHQ | 256 x 256 | 50,000 | + +- All the FFHQ inception feature statistics calculated on the resized 299 x 299 size. +- `inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth` is converted from the statistics in [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). diff --git a/docs/Metrics_CN.md b/docs/Metrics_CN.md new file mode 100644 index 0000000..c5f518c --- /dev/null +++ b/docs/Metrics_CN.md @@ -0,0 +1,36 @@ +# 评价指标 + +[English](Metrics.md) **|** [简体中文](Metrics_CN.md) + +## PSNR and SSIM + +## NIQE + +## FID + +> FID measures the similarity between two datasets of images. It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. +> FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. + +参考 + +- https://github.com/mseitzer/pytorch-fid +- [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500) +- [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337) + +### Pre-calculated FFHQ inception feature statistics + +通常, 我们把下载的 inception 网络的特征统计数据 (用于计算FID) 放在 `basicsr/metrics`. + + +:arrow_double_down: 百度网盘: [评价指标数据](https://pan.baidu.com/s/10mMKXSEgrC5y7m63W5vbMQ) +:arrow_double_down: Google Drive: [metrics data](https://drive.google.com/drive/folders/13cWIQyHX3iNmZRJ5v8v3kdyrT9RBTAi9?usp=sharing)
+ +| File Name | Dataset | Image Shape | Sample Numbers| +| :------------- | :----------:|:----------:|:----------:| +| inception_FFHQ_256-0948f50d.pth | FFHQ | 256 x 256 | 50,000 | +| inception_FFHQ_512-f7b384ab.pth | FFHQ | 512 x 512 | 50,000 | +| inception_FFHQ_1024-75f195dc.pth | FFHQ | 1024 x 1024 | 50,000 | +| inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth | FFHQ | 256 x 256 | 50,000 | + +- All the FFHQ inception feature statistics calculated on the resized 299 x 299 size. +- `inception_FFHQ_256_stylegan2_pytorch-abba9d31.pth` is converted from the statistics in [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). diff --git a/docs/ModelZoo.md b/docs/ModelZoo.md index af6579c..4dd25aa 100644 --- a/docs/ModelZoo.md +++ b/docs/ModelZoo.md @@ -2,6 +2,11 @@ [English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md) +:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) +:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) + +--- + We provide: 1. Official models converted directly from official released models @@ -9,7 +14,7 @@ We provide: You can put the downloaded models in the `experiments/pretrained_models` folder. -**[Download official pre-trained models]** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g))(https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing)) +**[Download official pre-trained models]** ([Google Drive](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing), [百度网盘](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g)) You can use the scrip to download pre-trained models from Google Drive. @@ -93,7 +98,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) - **L** (Large): # of channels = 128, # of back residual blocks = 40. This setting is used in our competition submission. - **M** (Moderate): # of channels = 64, # of back residual blocks = 10. -[Download Models from Google Drive](https://drive.google.com/open?id=1WfROVUqKOBS5gGvQzBfU1DNZ4XwPA3LD) | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| @@ -107,7 +111,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) 1 Y or RGB denotes the evaluation on Y (luminance) or RGB channels. #### Stage 2 models for the NTIRE19 Competition -[Download Models from Google Drive](https://drive.google.com/drive/folders/1PMoy1cKlIYWly6zY0tG2Q4YAH7V_HZns?usp=sharing) | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| @@ -119,7 +122,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) ## DUF The models are converted from the [officially released models](https://github.com/yhjo09/VSR-DUF).
-[Download Models from Google Drive](https://drive.google.com/open?id=1seY9nclMuwk_SpqKQhx1ItTcQShM5R50) | Model name | [Test Set] PSNR/SSIM1 | Official Results2 | |:----------:|:----------:|:----------:| @@ -136,7 +138,6 @@ The models are converted from the [officially released models](https://github.co ## TOF The models are converted from the [officially released models](https://github.com/anchen1011/toflow).
-[Download Models from Google Drive](https://drive.google.com/open?id=18kJcxPLeNK8e0kYEiwmsnu9wVmhdMFFG) | Model name | [Test Set] PSNR/SSIM | Official Results1 | |:----------:|:----------:|:----------:| diff --git a/docs/ModelZoo_CN.md b/docs/ModelZoo_CN.md index 9290a91..b192ee1 100644 --- a/docs/ModelZoo_CN.md +++ b/docs/ModelZoo_CN.md @@ -2,6 +2,11 @@ [English](ModelZoo.md) **|** [简体中文](ModelZoo_CN.md) +:arrow_double_down: 百度网盘: [预训练模型](https://pan.baidu.com/s/1R6Nc4v3cl79XPAiK0Toe7g) **|** [复现实验](https://pan.baidu.com/s/1UElD6q8sVAgn_cxeBDOlvQ) +:arrow_double_down: Google Drive: [Pretrained Models](https://drive.google.com/drive/folders/15DgDtfaLASQ3iAPJEVHQF49g9msexECG?usp=sharing) **|** [Reproduced Experiments](https://drive.google.com/drive/folders/1XN4WXKJ53KQ0Cu0Yv-uCt8DZWq6uufaP?usp=sharing) + +--- + 我们提供了: 1. 官方的模型, 它们是从官方release的models直接转化过来的 @@ -92,8 +97,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) - **L** (Large): # of channels = 128, # of back residual blocks = 40. This setting is used in our competition submission. - **M** (Moderate): # of channels = 64, # of back residual blocks = 10. -[Download Models from Google Drive](https://drive.google.com/open?id=1WfROVUqKOBS5gGvQzBfU1DNZ4XwPA3LD) - | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| | EDVR_Vimeo90K_SR_L | [Vid4] (Y1) 27.35/0.8264 [[↓Results]](https://drive.google.com/open?id=14nozpSfe9kC12dVuJ9mspQH5ZqE4mT9K)
(RGB) 25.83/0.8077| @@ -106,7 +109,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) 1 Y or RGB denotes the evaluation on Y (luminance) or RGB channels. #### Stage 2 models for the NTIRE19 Competition -[Download Models from Google Drive](https://drive.google.com/drive/folders/1PMoy1cKlIYWly6zY0tG2Q4YAH7V_HZns?usp=sharing) | Model name |[Test Set] PSNR/SSIM | |:----------:|:----------:| @@ -118,7 +120,6 @@ EDVR\_(training dataset)\_(track name)\_(model complexity) ## DUF The models are converted from the [officially released models](https://github.com/yhjo09/VSR-DUF).
-[Download Models from Google Drive](https://drive.google.com/open?id=1seY9nclMuwk_SpqKQhx1ItTcQShM5R50) | Model name | [Test Set] PSNR/SSIM1 | Official Results2 | |:----------:|:----------:|:----------:| @@ -135,7 +136,6 @@ The models are converted from the [officially released models](https://github.co ## TOF The models are converted from the [officially released models](https://github.com/anchen1011/toflow).
-[Download Models from Google Drive](https://drive.google.com/open?id=18kJcxPLeNK8e0kYEiwmsnu9wVmhdMFFG) | Model name | [Test Set] PSNR/SSIM | Official Results1 | |:----------:|:----------:|:----------:| diff --git a/inference/inference_dfdnet.py b/inference/inference_dfdnet.py new file mode 100644 index 0000000..982c524 --- /dev/null +++ b/inference/inference_dfdnet.py @@ -0,0 +1,210 @@ +import argparse +import glob +import numpy as np +import os +import torch +import torchvision.transforms as transforms +from skimage import io + +from basicsr.models.archs.dfdnet_arch import DFDNet +from basicsr.utils import imwrite, tensor2img +from basicsr.utils.face_util import FaceRestorationHelper + + +def get_part_location(landmarks): + """Get part locations from landmarks.""" + map_left_eye = list(np.hstack((range(17, 22), range(36, 42)))) + map_right_eye = list(np.hstack((range(22, 27), range(42, 48)))) + map_nose = list(range(29, 36)) + map_mouth = list(range(48, 68)) + + # left eye + mean_left_eye = np.mean(landmarks[map_left_eye], 0) # (x, y) + half_len_left_eye = np.max((np.max( + np.max(landmarks[map_left_eye], 0) - + np.min(landmarks[map_left_eye], 0)) / 2, 16)) # A number + loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, + mean_left_eye + half_len_left_eye)).astype(int) + loc_left_eye = torch.from_numpy(loc_left_eye).unsqueeze(0) + # (1, 4), the four numbers forms two coordinates in the diagonal + + # right eye + mean_right_eye = np.mean(landmarks[map_right_eye], 0) + half_len_right_eye = np.max((np.max( + np.max(landmarks[map_right_eye], 0) - + np.min(landmarks[map_right_eye], 0)) / 2, 16)) + loc_right_eye = np.hstack( + (mean_right_eye - half_len_right_eye + 1, + mean_right_eye + half_len_right_eye)).astype(int) + loc_right_eye = torch.from_numpy(loc_right_eye).unsqueeze(0) + # nose + mean_nose = np.mean(landmarks[map_nose], 0) + half_len_nose = np.max((np.max( + np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2, + 16)) # noqa: E126 + loc_nose = np.hstack( + (mean_nose - half_len_nose + 1, mean_nose + half_len_nose)).astype(int) + loc_nose = torch.from_numpy(loc_nose).unsqueeze(0) + # mouth + mean_mouth = np.mean(landmarks[map_mouth], 0) + half_len_mouth = np.max((np.max( + np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2, + 16)) # noqa: E126 + loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, + mean_mouth + half_len_mouth)).astype(int) + loc_mouth = torch.from_numpy(loc_mouth).unsqueeze(0) + + return loc_left_eye, loc_right_eye, loc_nose, loc_mouth + + +if __name__ == '__main__': + """We try to align to the official codes. But there are still slight + differences: 1) we use dlib for 68 landmark detection; 2) the used image + package are different (especially for reading and writing.) 
+ """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + parser = argparse.ArgumentParser() + + parser.add_argument('--upscale_factor', type=int, default=2) + parser.add_argument( + '--model_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth') + parser.add_argument( + '--dict_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth') + parser.add_argument('--test_path', type=str, default='datasets/TestWhole') + parser.add_argument('--upsample_num_times', type=int, default=1) + parser.add_argument('--save_inverse_affine', action='store_true') + parser.add_argument('--only_keep_largest', action='store_true') + # The official codes use skimage.io to read the cropped images from disk + # instead of directly using the intermediate results in the memory (as we + # do). Such a different operation brings slight differences due to + # skimage.io. For aligning with the official results, we could set the + # official_adaption to True. + parser.add_argument('--official_adaption', type=bool, default=True) + + # The following are the paths for dlib models + parser.add_argument( + '--detection_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat' # noqa: E501 + ) + parser.add_argument( + '--landmark5_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat' # noqa: E501 + ) + parser.add_argument( + '--landmark68_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat' # noqa: E501 + ) + + args = parser.parse_args() + if args.test_path.endswith('/'): # solve when path ends with / + args.test_path = args.test_path[:-1] + result_root = f'results/DFDNet/{os.path.basename(args.test_path)}' + + # set up the DFDNet + net = DFDNet(64, dict_path=args.dict_path).to(device) + checkpoint = torch.load( + args.model_path, map_location=lambda storage, loc: storage) + net.load_state_dict(checkpoint['params']) + net.eval() + + save_crop_root = os.path.join(result_root, 'cropped_faces') + save_inverse_affine_root = os.path.join(result_root, 'inverse_affine') + os.makedirs(save_inverse_affine_root, exist_ok=True) + save_restore_root = os.path.join(result_root, 'restored_faces') + save_final_root = os.path.join(result_root, 'final_results') + + face_helper = FaceRestorationHelper(args.upscale_factor, face_size=512) + + # scan all the jpg and png images + for img_path in sorted( + glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))): + img_name = os.path.basename(img_path) + print(f'Processing {img_name} image ...') + save_crop_path = os.path.join(save_crop_root, img_name) + if args.save_inverse_affine: + save_inverse_affine_path = os.path.join(save_inverse_affine_root, + img_name) + else: + save_inverse_affine_path = None + + face_helper.init_dlib(args.detection_path, args.landmark5_path, + args.landmark68_path) + # detect faces + num_det_faces = face_helper.detect_faces( + img_path, + upsample_num_times=args.upsample_num_times, + only_keep_largest=args.only_keep_largest) + # get 5 face landmarks for each face + num_landmarks = face_helper.get_face_landmarks_5() + print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.') + # warp and crop each face + face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path) + + if args.official_adaption: + path, ext = 
os.path.splitext(save_crop_path) + pathes = sorted(glob.glob(f'{path}_[0-9]*.png')) + cropped_faces = [io.imread(path) for path in pathes] + else: + cropped_faces = face_helper.cropped_faces + + # get 68 landmarks for each cropped face + num_landmarks = face_helper.get_face_landmarks_68() + print(f'\tDetect {num_landmarks} faces for 68 landmarks.') + + face_helper.free_dlib_gpu_memory() + + print('\tFace restoration ...') + # face restoration for each cropped face + assert len(cropped_faces) == len(face_helper.all_landmarks_68) + for idx, (cropped_face, landmarks) in enumerate( + zip(cropped_faces, face_helper.all_landmarks_68)): + if landmarks is None: + print(f'Landmarks is None, skip cropped faces with idx {idx}.') + # just copy the cropped faces to the restored faces + restored_face = cropped_face + else: + # prepare data + part_locations = get_part_location(landmarks) + cropped_face = transforms.ToTensor()(cropped_face) + cropped_face = transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))( + cropped_face) + cropped_face = cropped_face.unsqueeze(0).to(device) + + try: + with torch.no_grad(): + output = net(cropped_face, part_locations) + restored_face = tensor2img(output, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except Exception as e: + print(f'DFDNet inference fail: {e}') + restored_face = tensor2img(cropped_face, min_max=(-1, 1)) + + path = os.path.splitext(os.path.join(save_restore_root, + img_name))[0] + save_path = f'{path}_{idx:02d}.png' + imwrite(restored_face, save_path) + face_helper.add_restored_face(restored_face) + + print('\tGenerate the final result ...') + # paste each restored face to the input image + face_helper.paste_faces_to_input_image( + os.path.join(save_final_root, img_name)) + + # clean all the intermediate results to process the next image + face_helper.clean_all() + + print(f'\nAll results are saved in {result_root}') diff --git a/inference/inference_esrgan.py b/inference/inference_esrgan.py new file mode 100644 index 0000000..8c64966 --- /dev/null +++ b/inference/inference_esrgan.py @@ -0,0 +1,55 @@ +import argparse +import cv2 +import glob +import numpy as np +import os +import torch + +from basicsr.models.archs.rrdbnet_arch import RRDBNet + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_path', + type=str, + default= # noqa: E251 + 'experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth' # noqa: E501 + ) + parser.add_argument( + '--folder', + type=str, + default='datasets/Set14/LRbicx4', + help='input test image folder') + args = parser.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # set up model + model = RRDBNet( + num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32) + model.load_state_dict(torch.load(args.model_path)['params'], strict=True) + model.eval() + model = model.to(device) + + os.makedirs('results/ESRGAN', exist_ok=True) + for idx, path in enumerate( + sorted(glob.glob(os.path.join(args.folder, '*')))): + imgname = os.path.splitext(os.path.basename(path))[0] + print('Testing', idx, imgname) + # read image + img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255. 
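+        # the image is float32 HWC BGR in [0, 1]; convert to an RGB CHW
+        # tensor and add a batch dimension before feeding the network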
+ img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], + (2, 0, 1))).float() + img = img.unsqueeze(0).to(device) + # inference + with torch.no_grad(): + output = model(img) + # save image + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + cv2.imwrite(f'results/ESRGAN/{imgname}_ESRGAN.png', output) + + +if __name__ == '__main__': + main() diff --git a/tests/test_stylegan2.py b/inference/inference_stylegan2.py similarity index 89% rename from tests/test_stylegan2.py rename to inference/inference_stylegan2.py index c166e64..47bbe47 100644 --- a/tests/test_stylegan2.py +++ b/inference/inference_stylegan2.py @@ -1,6 +1,6 @@ import argparse import math -import mmcv +import os import torch from torchvision import utils @@ -30,7 +30,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise): if __name__ == '__main__': - device = 'cuda' + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') parser = argparse.ArgumentParser() @@ -43,7 +43,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise): '--ckpt', type=str, default= # noqa: E251 - 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official-b09c3668.pth' # noqa: E501 + 'experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-b09c3668.pth' # noqa: E501 ) parser.add_argument('--channel_multiplier', type=int, default=2) parser.add_argument('--randomize_noise', type=bool, default=True) @@ -52,7 +52,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise): args.latent = 512 args.n_mlp = 8 - mmcv.mkdir_or_exist('samples') + os.makedirs('samples', exist_ok=True) set_random_seed(2020) g_ema = StyleGAN2Generator( diff --git a/make.sh b/make.sh deleted file mode 100644 index 1990c6b..0000000 --- a/make.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -# You may need to modify the following paths before compiling -CUDA_HOME=/usr/local/cuda \ -CUDNN_INCLUDE_DIR=/usr/local/cuda \ -CUDNN_LIB_DIR=/usr/local/cuda \ -python setup.py develop diff --git a/options/test/DUF/test_DUF_official.yml b/options/test/DUF/test_DUF_official.yml index 5d16dd2..d0bc81c 100644 --- a/options/test/DUF/test_DUF_official.yml +++ b/options/test/DUF/test_DUF_official.yml @@ -28,8 +28,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/DUF_x4_52L_official-483d2c78.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/DUF/DUF_x4_52L_official-483d2c78.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Lx2.yml b/options/test/EDSR/test_EDSR_Lx2.yml index 05a1398..82dcb49 100644 --- a/options/test/EDSR/test_EDSR_Lx2.yml +++ b/options/test/EDSR/test_EDSR_Lx2.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Lx3.yml b/options/test/EDSR/test_EDSR_Lx3.yml index c7c951c..6053ba6 100644 --- a/options/test/EDSR/test_EDSR_Lx3.yml +++ b/options/test/EDSR/test_EDSR_Lx3.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth - strict_load: true + pretrain_network_g: 
experiments/pretrained_models/EDSR/EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Lx4.yml b/options/test/EDSR/test_EDSR_Lx4.yml index e9a55e0..37bb209 100644 --- a/options/test/EDSR/test_EDSR_Lx4.yml +++ b/options/test/EDSR/test_EDSR_Lx4.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Mx2.yml b/options/test/EDSR/test_EDSR_Mx2.yml index f18dae7..b6ab304 100644 --- a/options/test/EDSR/test_EDSR_Mx2.yml +++ b/options/test/EDSR/test_EDSR_Mx2.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Mx3.yml b/options/test/EDSR/test_EDSR_Mx3.yml index 612f213..c799603 100644 --- a/options/test/EDSR/test_EDSR_Mx3.yml +++ b/options/test/EDSR/test_EDSR_Mx3.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDSR/test_EDSR_Mx4.yml b/options/test/EDSR/test_EDSR_Mx4.yml index 0d52ef1..2686861 100644 --- a/options/test/EDSR/test_EDSR_Mx4.yml +++ b/options/test/EDSR/test_EDSR_Mx4.yml @@ -43,8 +43,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDSR/EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml index 1576fc1..6982ab8 100644 --- a/options/test/EDVR/test_EDVR_L_deblur_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_deblur_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_deblur_REDS_official-ca46bd8c.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_deblur_REDS_official-ca46bd8c.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml index fbb243d..4108a2a 100644 --- a/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_deblurcomp_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_deblurcomp_REDS_official-0e988e5c.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml index bd75815..768c173 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - 
pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_REDS_official-9f5f5039.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_REDS_official-9f5f5039.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml index 7428355..9929067 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vid4.yml @@ -34,8 +34,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml index 21cf0bf..dff07d8 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SR_Vimeo90K.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml index ed4ed55..fbe2b1b 100644 --- a/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml +++ b/options/test/EDVR/test_EDVR_L_x4_SRblur_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth + strict_load_g: true # validation settings val: diff --git a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml index 95271f8..773286d 100644 --- a/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml +++ b/options/test/EDVR/test_EDVR_M_x4_SR_REDS.yml @@ -35,8 +35,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/EDVR_M_x4_SR_REDS_official-32075921.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/EDVR/EDVR_M_x4_SR_REDS_official-32075921.pth + strict_load_g: true # validation settings val: diff --git a/options/test/ESRGAN/test_ESRGAN_x4.yml b/options/test/ESRGAN/test_ESRGAN_x4.yml index 1d23fb8..845789c 100644 --- a/options/test/ESRGAN/test_ESRGAN_x4.yml +++ b/options/test/ESRGAN/test_ESRGAN_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + strict_load_g: true # validation settings val: diff --git a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml index 997381d..d428740 100644 --- a/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml +++ b/options/test/ESRGAN/test_ESRGAN_x4_woGT.yml @@ -29,8 +29,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + strict_load_g: true # validation settings val: diff --git 
a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml index 7c39a50..9636f22 100644 --- a/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml +++ b/options/test/ESRGAN/test_RRDBNet_PSNR_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/pretrained_models/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth + strict_load_g: true # validation settings val: diff --git a/options/test/RCAN/test_RCAN.yml b/options/test/RCAN/test_RCAN.yml index 7734d91..3f22dd2 100644 --- a/options/test/RCAN/test_RCAN.yml +++ b/options/test/RCAN/test_RCAN.yml @@ -49,5 +49,5 @@ save_img: true # path path: - pretrain_model_g: ./experiments/pretrained_models/RCAN_BIX4-official.pth - strict_load: true + pretrain_network_g: ./experiments/pretrained_models/RCAN/RCAN_BIX4-official.pth + strict_load_g: true diff --git a/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml b/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml index 517150c..5fef091 100644 --- a/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml +++ b/options/test/SRResNet_SRGAN/test_MSRGAN_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb/models/net_g_400000.pth - strict_load: true + pretrain_network_g: experiments/004_MSRGAN_x4_f64b16_DIV2K_400k_B16G1_wandb/models/net_g_400000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml index 29c09f8..d76411d 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x2.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/002_MSRResNet_x2_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml index 91b4e7f..0e8dc78 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x3.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/003_MSRResNet_x3_f64b16_DIV2K_1000k_B16G1_001pretrain_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml index c5b0e32..ce5e1cf 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml +++ b/options/test/SRResNet_SRGAN/test_MSRResNet_x4.yml @@ -40,8 +40,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml b/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml index 8e499cf..cdc8ea7 100644 --- a/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml +++ 
b/options/test/SRResNet_SRGAN/test_MSRResNet_x4_woGT.yml @@ -29,8 +29,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true # validation settings val: diff --git a/options/test/TOF/test_TOF_official.yml b/options/test/TOF/test_TOF_official.yml index ab916c7..f61dbaf 100644 --- a/options/test/TOF/test_TOF_official.yml +++ b/options/test/TOF/test_TOF_official.yml @@ -26,8 +26,8 @@ save_img: true # path path: - pretrain_model_g: experiments/pretrained_models/tof_official-e81c455f.pth - strict_load: true + pretrain_network_g: experiments/pretrained_models/TOF/tof_official-e81c455f.pth + strict_load_g: true # validation settings val: diff --git a/options/train/EDSR/train_EDSR_Lx2.yml b/options/train/EDSR/train_EDSR_Lx2.yml index da645b7..bb3167e 100644 --- a/options/train/EDSR/train_EDSR_Lx2.yml +++ b/options/train/EDSR/train_EDSR_Lx2.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Lx3.yml b/options/train/EDSR/train_EDSR_Lx3.yml index 7b6ae45..326d95e 100644 --- a/options/train/EDSR/train_EDSR_Lx3.yml +++ b/options/train/EDSR/train_EDSR_Lx3.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Lx4.yml b/options/train/EDSR/train_EDSR_Lx4.yml index 6fe945c..ffd3a60 100644 --- a/options/train/EDSR/train_EDSR_Lx4.yml +++ b/options/train/EDSR/train_EDSR_Lx4.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/204_EDSR_Lx2_f256b32_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Mx2.yml b/options/train/EDSR/train_EDSR_Mx2.yml index 37410f0..b8c81f9 100644 --- a/options/train/EDSR/train_EDSR_Mx2.yml +++ b/options/train/EDSR/train_EDSR_Mx2.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Mx3.yml b/options/train/EDSR/train_EDSR_Mx3.yml index 7f473a0..bd44e87 100644 --- a/options/train/EDSR/train_EDSR_Mx3.yml +++ b/options/train/EDSR/train_EDSR_Mx3.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDSR/train_EDSR_Mx4.yml b/options/train/EDSR/train_EDSR_Mx4.yml index aa12b57..0f5e583 100644 --- a/options/train/EDSR/train_EDSR_Mx4.yml +++ b/options/train/EDSR/train_EDSR_Mx4.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: 
experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth - strict_load: false + pretrain_network_g: experiments/201_EDSR_Mx2_f64b16_DIV2K_300k_B16G1_wandb/models/net_g_300000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml index 0623d4c..ec5a78e 100644 --- a/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml +++ b/options/train/EDVR/train_EDVRM_woTSA_GAN_TODO.yml @@ -71,8 +71,8 @@ network_d: # path path: - pretrain_model_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth - strict_load: true + pretrain_network_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth + strict_load_g: true resume_state: ~ # training settings @@ -107,9 +107,9 @@ train: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true + range_norm: false perceptual_weight: 1.0 style_weight: 0 - norm_img: false criterion: l1 gan_opt: type: GANLoss diff --git a/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml b/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml index d0bb472..bcc6418 100644 --- a/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml +++ b/options/train/EDVR/train_EDVR_L_x4_SR_REDS.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: experiments/103_EDVR_L_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth - strict_load: false + pretrain_network_g: experiments/103_EDVR_L_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml b/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml index becefe2..32645a4 100644 --- a/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml +++ b/options/train/EDVR/train_EDVR_L_x4_SR_REDS_woTSA.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml b/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml index c463310..d79c8d3 100644 --- a/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml +++ b/options/train/EDVR/train_EDVR_M_x4_SR_REDS.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth - strict_load: false + pretrain_network_g: experiments/101_EDVR_M_x4_SR_REDS_woTSA_600k_B4G8_valREDS4_wandb/models/net_g_600000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml b/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml index bb8dba3..75552e9 100644 --- a/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml +++ b/options/train/EDVR/train_EDVR_M_x4_SR_REDS_woTSA.yml @@ -63,8 +63,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/ESRGAN/train_ESRGAN_x4.yml b/options/train/ESRGAN/train_ESRGAN_x4.yml index 23acff2..057de06 100644 --- a/options/train/ESRGAN/train_ESRGAN_x4.yml +++ b/options/train/ESRGAN/train_ESRGAN_x4.yml @@ -55,8 +55,8 @@ network_d: # path path: - pretrain_model_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: 
experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true resume_state: ~ # training settings @@ -91,9 +91,9 @@ train: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true + range_norm: false perceptual_weight: 1.0 style_weight: 0 - norm_img: false criterion: l1 gan_opt: type: GANLoss diff --git a/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml b/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml index a4ede70..c5882c8 100644 --- a/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml +++ b/options/train/ESRGAN/train_RRDBNet_PSNR_x4.yml @@ -51,8 +51,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/RCAN/train_RCAN_x2.yml b/options/train/RCAN/train_RCAN_x2.yml index c525c0d..531b142 100644 --- a/options/train/RCAN/train_RCAN_x2.yml +++ b/options/train/RCAN/train_RCAN_x2.yml @@ -57,8 +57,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml index 978b28e..e3681f2 100644 --- a/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml +++ b/options/train/SRResNet_SRGAN/train_MSRGAN_x4.yml @@ -60,8 +60,8 @@ network_d: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: true + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: true resume_state: ~ # training settings @@ -96,9 +96,9 @@ train: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true + scale: false perceptual_weight: 1.0 style_weight: 0 - norm_img: false criterion: l1 gan_opt: type: GANLoss diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml index f7e6014..3688a1a 100644 --- a/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml +++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x2.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: false + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml index 9b94d29..5c414ad 100644 --- a/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml +++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x3.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth - strict_load: false + pretrain_network_g: experiments/001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth + strict_load_g: false resume_state: ~ # training settings diff --git a/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml b/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml index 647b334..1fa782f 100644 --- a/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml +++ b/options/train/SRResNet_SRGAN/train_MSRResNet_x4.yml @@ -54,8 +54,8 @@ network_g: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git 
a/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml b/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml index 00b77ba..e112d44 100644 --- a/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml +++ b/options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml @@ -42,8 +42,8 @@ network_d: # path path: - pretrain_model_g: ~ - strict_load: true + pretrain_network_g: ~ + strict_load_g: true resume_state: ~ # training settings diff --git a/requirements.txt b/requirements.txt index 8202611..c014cdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,15 @@ addict future lmdb -matplotlib -mmcv>=0.6 numpy opencv-python +Pillow pyyaml +requests scikit-image scipy tb-nightly torch>=1.3 torchvision +tqdm yapf diff --git a/scripts/create_lmdb.py b/scripts/data_preparation/create_lmdb.py similarity index 87% rename from scripts/create_lmdb.py rename to scripts/data_preparation/create_lmdb.py index 4fa359b..e8eec3b 100644 --- a/scripts/create_lmdb.py +++ b/scripts/data_preparation/create_lmdb.py @@ -1,7 +1,8 @@ -import mmcv +import argparse from os import path as osp -from basicsr.utils.lmdb import make_lmdb_from_imgs +from basicsr.utils import scandir +from basicsr.utils.lmdb_util import make_lmdb_from_imgs def create_lmdb_for_div2k(): @@ -53,7 +54,7 @@ def prepare_keys_div2k(folder_path): """ print('Reading image path list ...') img_path_list = sorted( - list(mmcv.scandir(folder_path, suffix='png', recursive=False))) + list(scandir(folder_path, suffix='png', recursive=False))) keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] return img_path_list, keys @@ -96,7 +97,7 @@ def prepare_keys_reds(folder_path): """ print('Reading image path list ...') img_path_list = sorted( - list(mmcv.scandir(folder_path, suffix='png', recursive=True))) + list(scandir(folder_path, suffix='png', recursive=True))) keys = [v.split('.png')[0] for v in img_path_list] # example: 000/00000000 return img_path_list, keys @@ -160,6 +161,22 @@ def prepare_keys_vimeo90k(folder_path, train_list_path, mode): if __name__ == '__main__': - create_lmdb_for_div2k() - # create_lmdb_for_reds() - # create_lmdb_for_vimeo90k() + parser = argparse.ArgumentParser() + + parser.add_argument( + '--dataset', + type=str, + help=( + "Options: 'DIV2K', 'REDS', 'Vimeo90K' " + 'You may need to modify the corresponding configurations in codes.' + )) + args = parser.parse_args() + dataset = args.dataset.lower() + if dataset == 'div2k': + create_lmdb_for_div2k() + elif dataset == 'reds': + create_lmdb_for_reds() + elif dataset == 'vimeo90k': + create_lmdb_for_vimeo90k() + else: + raise ValueError('Wrong dataset.') diff --git a/scripts/data_preparation/download_datasets.py b/scripts/data_preparation/download_datasets.py new file mode 100644 index 0000000..215e3c8 --- /dev/null +++ b/scripts/data_preparation/download_datasets.py @@ -0,0 +1,71 @@ +import argparse +import glob +import os +from os import path as osp + +from basicsr.utils.download_util import download_file_from_google_drive + + +def download_dataset(dataset, file_ids): + save_path_root = './datasets/' + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input( + f'{file_name} already exist. Do you want to cover it? 
Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accpets Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + download_file_from_google_drive(file_id, save_path) + + # unzip + if save_path.endswith('.zip'): + extracted_path = save_path.replace('.zip', '') + print(f'Extract {save_path} to {extracted_path}') + import zipfile + with zipfile.ZipFile(save_path, 'r') as zip_ref: + zip_ref.extractall(extracted_path) + + file_name = file_name.replace('.zip', '') + subfolder = osp.join(extracted_path, file_name) + if osp.isdir(subfolder): + print(f'Move {subfolder} to {extracted_path}') + import shutil + for path in glob.glob(osp.join(subfolder, '*')): + shutil.move(path, extracted_path) + shutil.rmtree(subfolder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument( + 'dataset', + type=str, + help=("Options: 'Set5', 'Set14'. " + "Set to 'all' if you want to download all the dataset.")) + args = parser.parse_args() + + file_ids = { + 'Set5': { + 'Set5.zip': # file name + '1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9', # file id + }, + 'Set14': { + 'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E', + } + } + + if args.dataset == 'all': + for dataset in file_ids.keys(): + download_dataset(dataset, file_ids[dataset]) + else: + download_dataset(args.dataset, file_ids[args.dataset]) diff --git a/scripts/data_preparation/extract_images_from_tfrecords.py b/scripts/data_preparation/extract_images_from_tfrecords.py new file mode 100644 index 0000000..14a4f67 --- /dev/null +++ b/scripts/data_preparation/extract_images_from_tfrecords.py @@ -0,0 +1,235 @@ +import argparse +import cv2 +import glob +import numpy as np +import os + +from basicsr.utils.lmdb_util import LmdbMaker + + +def convert_celeba_tfrecords(tf_file, + log_resolution, + save_root, + save_type='img', + compress_level=1): + """Convert CelebA tfrecords to images or lmdb files. + + Args: + tf_file (str): Input tfrecords file in glob pattern. + Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords' # noqa:E501 + log_resolution (int): Log scale of resolution. + save_root (str): Path root to save. + save_type (str): Save type. Options: img | lmdb. Default: img. + compress_level (int): Compress level when encoding images. Default: 1. 
+ """ + if 'validation' in tf_file: + phase = 'validation' + else: + phase = 'train' + if save_type == 'lmdb': + save_path = os.path.join(save_root, + f'celeba_{2**log_resolution}_{phase}.lmdb') + lmdb_maker = LmdbMaker(save_path) + elif save_type == 'img': + save_path = os.path.join(save_root, + f'celeba_{2**log_resolution}_{phase}') + else: + raise ValueError('Wrong save type.') + + os.makedirs(save_path, exist_ok=True) + + idx = 0 + for record in sorted(glob.glob(tf_file)): + print('Processing record: ', record) + record_iterator = tf.python_io.tf_record_iterator(record) + for string_record in record_iterator: + example = tf.train.Example() + example.ParseFromString(string_record) + + # label = example.features.feature['label'].int64_list.value[0] + # attr = example.features.feature['attr'].int64_list.value + # male = attr[20] + # young = attr[39] + + shape = example.features.feature['shape'].int64_list.value + h, w, c = shape + img_str = example.features.feature['data'].bytes_list.value[0] + img = np.fromstring(img_str, dtype=np.uint8).reshape((h, w, c)) + + img = img[:, :, [2, 1, 0]] + + if save_type == 'img': + cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img) + elif save_type == 'lmdb': + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + key = f'{idx:08d}/r{log_resolution:02d}' + lmdb_maker.put(img_byte, key, (h, w, c)) + + idx += 1 + print(idx) + + if save_type == 'lmdb': + lmdb_maker.close() + + +def convert_ffhq_tfrecords(tf_file, + log_resolution, + save_root, + save_type='img', + compress_level=1): + """Convert FFHQ tfrecords to images or lmdb files. + + Args: + tf_file (str): Input tfrecords file. + log_resolution (int): Log scale of resolution. + save_root (str): Path root to save. + save_type (str): Save type. Options: img | lmdb. Default: img. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + if save_type == 'lmdb': + save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}.lmdb') + lmdb_maker = LmdbMaker(save_path) + elif save_type == 'img': + save_path = os.path.join(save_root, f'ffhq_{2**log_resolution}') + else: + raise ValueError('Wrong save type.') + + os.makedirs(save_path, exist_ok=True) + + idx = 0 + for record in sorted(glob.glob(tf_file)): + print('Processing record: ', record) + record_iterator = tf.python_io.tf_record_iterator(record) + for string_record in record_iterator: + example = tf.train.Example() + example.ParseFromString(string_record) + + shape = example.features.feature['shape'].int64_list.value + c, h, w = shape + img_str = example.features.feature['data'].bytes_list.value[0] + img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w)) + + img = img.transpose(1, 2, 0) + img = img[:, :, [2, 1, 0]] + if save_type == 'img': + cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img) + elif save_type == 'lmdb': + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + key = f'{idx:08d}/r{log_resolution:02d}' + lmdb_maker.put(img_byte, key, (h, w, c)) + + idx += 1 + print(idx) + + if save_type == 'lmdb': + lmdb_maker.close() + + +def make_ffhq_lmdb_from_imgs(folder_path, + log_resolution, + save_root, + save_type='lmdb', + compress_level=1): + """Make FFHQ lmdb from images. + + Args: + folder_path (str): Folder path. + log_resolution (int): Log scale of resolution. + save_root (str): Path root to save. + save_type (str): Save type. Options: img | lmdb. Default: img. 
+ compress_level (int): Compress level when encoding images. Default: 1. + """ + + if save_type == 'lmdb': + save_path = os.path.join(save_root, + f'ffhq_{2**log_resolution}_crop1.2.lmdb') + lmdb_maker = LmdbMaker(save_path) + else: + raise ValueError('Wrong save type.') + + os.makedirs(save_path, exist_ok=True) + + img_list = sorted(glob.glob(os.path.join(folder_path, '*'))) + for idx, img_path in enumerate(img_list): + print(f'Processing {idx}: ', img_path) + img = cv2.imread(img_path) + h, w, c = img.shape + + if save_type == 'lmdb': + _, img_byte = cv2.imencode( + '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + key = f'{idx:08d}/r{log_resolution:02d}' + lmdb_maker.put(img_byte, key, (h, w, c)) + + if save_type == 'lmdb': + lmdb_maker.close() + + +if __name__ == '__main__': + """Read tfrecords w/o define a graph. + + We have tested it on TensorFlow 1.15 + + Ref: + http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ + """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--dataset', + type=str, + default='ffhq', + help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.") + parser.add_argument( + '--tf_file', + type=str, + default='datasets/ffhq/ffhq-r10.tfrecords', + help=( + 'Input tfrecords file. For celeba, it should be glob pattern. ' + 'Put quotes around the wildcard argument to prevent the shell ' + 'from expanding it.' + "Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'" # noqa:E501 + )) + parser.add_argument( + '--log_resolution', + type=int, + default=10, + help='Log scale of resolution.') + parser.add_argument( + '--save_root', + type=str, + default='datasets/ffhq/', + help='Save root path.') + parser.add_argument( + '--save_type', + type=str, + default='img', + help="Save type. Options: 'img' | 'lmdb'. Default: 'img'.") + parser.add_argument( + '--compress_level', + type=int, + default=1, + help='Compress level when encoding images. Default: 1.') + args = parser.parse_args() + + try: + import tensorflow as tf + except Exception: + raise ImportError('You need to install tensorflow to read tfrecords.') + + if args.dataset == 'ffhq': + convert_ffhq_tfrecords( + args.tf_file, + args.log_resolution, + args.save_root, + save_type=args.save_type, + compress_level=args.compress_level) + else: + convert_celeba_tfrecords( + args.tf_file, + args.log_resolution, + args.save_root, + save_type=args.save_type, + compress_level=args.compress_level) diff --git a/scripts/extract_subimages.py b/scripts/data_preparation/extract_subimages.py similarity index 96% rename from scripts/extract_subimages.py rename to scripts/data_preparation/extract_subimages.py index 6cf06b1..6424e8d 100644 --- a/scripts/extract_subimages.py +++ b/scripts/data_preparation/extract_subimages.py @@ -1,12 +1,12 @@ import cv2 -import mmcv import numpy as np import os import sys from multiprocessing import Pool from os import path as osp +from tqdm import tqdm -from basicsr.utils.util import ProgressBar +from basicsr.utils import scandir def main(): @@ -94,16 +94,16 @@ def extract_subimages(opt): print(f'Folder {save_folder} already exists. 
Exit.') sys.exit(1) - img_list = list(mmcv.scandir(input_folder)) - img_list = [osp.join(input_folder, v) for v in img_list] + img_list = list(scandir(input_folder, full_path=True)) - pbar = ProgressBar(len(img_list)) + pbar = tqdm(total=len(img_list), unit='image', desc='Extract') pool = Pool(opt['n_thread']) for path in img_list: pool.apply_async( - worker, args=(path, opt), callback=lambda arg: pbar.update(arg)) + worker, args=(path, opt), callback=lambda arg: pbar.update(1)) pool.close() pool.join() + pbar.close() print('All processes done.') diff --git a/scripts/generate_meta_info.py b/scripts/data_preparation/generate_meta_info.py similarity index 91% rename from scripts/generate_meta_info.py rename to scripts/data_preparation/generate_meta_info.py index 22d851e..7bb1aed 100644 --- a/scripts/generate_meta_info.py +++ b/scripts/data_preparation/generate_meta_info.py @@ -1,7 +1,8 @@ -import mmcv from os import path as osp from PIL import Image +from basicsr.utils import scandir + def generate_meta_info_div2k(): """Generate meta info for DIV2K dataset. @@ -10,7 +11,7 @@ def generate_meta_info_div2k(): gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/' meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt' - img_list = sorted(list(mmcv.scandir(gt_folder))) + img_list = sorted(list(scandir(gt_folder))) with open(meta_info_txt, 'w') as f: for idx, img_path in enumerate(img_list): diff --git a/scripts/regroup_reds_dataset.py b/scripts/data_preparation/regroup_reds_dataset.py similarity index 86% rename from scripts/regroup_reds_dataset.py rename to scripts/data_preparation/regroup_reds_dataset.py index 3ce71fa..7d3ddbf 100644 --- a/scripts/regroup_reds_dataset.py +++ b/scripts/data_preparation/regroup_reds_dataset.py @@ -18,8 +18,9 @@ def regroup_reds_dataset(train_path, val_path): # move the validation data to the train folder val_folders = glob.glob(os.path.join(val_path, '*')) for folder in val_folders: - new_folder_idx = int(folder.split(' / ')[-1]) + 240 - os.system(f'cp -r {folder} {os.path.join(train_path, new_folder_idx)}') + new_folder_idx = int(folder.split('/')[-1]) + 240 + os.system( + f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}') if __name__ == '__main__': diff --git a/scripts/download_gdrive.py b/scripts/download_gdrive.py new file mode 100644 index 0000000..c3e34c7 --- /dev/null +++ b/scripts/download_gdrive.py @@ -0,0 +1,12 @@ +import argparse + +from basicsr.utils.download_util import download_file_from_google_drive + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--id', type=str, help='File id') + parser.add_argument('--output', type=str, help='Save path') + args = parser.parse_args() + + download_file_from_google_drive(args.id, args.save_path) diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py index cc26218..3eb6911 100644 --- a/scripts/download_pretrained_models.py +++ b/scripts/download_pretrained_models.py @@ -1,19 +1,19 @@ import argparse -import mmcv +import os from os import path as osp -from basicsr.utils.download import download_file_from_google_drive +from basicsr.utils.download_util import download_file_from_google_drive def download_pretrained_models(method, file_ids): save_path_root = f'./experiments/pretrained_models/{method}' - mmcv.mkdir_or_exist(save_path_root) + os.makedirs(save_path_root, exist_ok=True) for file_name, file_id in file_ids.items(): save_path = osp.abspath(osp.join(save_path_root, file_name)) if osp.exists(save_path): 
user_response = input( - f'{file_name} already exist. Do you want to cover it? Y/N') + f'{file_name} already exist. Do you want to cover it? Y/N\n') if user_response.lower() == 'y': print(f'Covering {file_name} to {save_path}') download_file_from_google_drive(file_id, save_path) @@ -112,9 +112,7 @@ def download_pretrained_models(method, file_ids): 'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79', 'DFDNet_official-d1fa5650.pth': - '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe', - 'FFHQ_5_landmarks_template_1024-90a00515.npy': - '1IQdQcq9QnpW6YzRwDaNbpV-rJ1Cq7RUq' + '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe' }, 'dlib': { 'mmod_human_face_detector-4cb19393.dat': diff --git a/scripts/extract_images_from_tfrecords.py b/scripts/extract_images_from_tfrecords.py deleted file mode 100644 index 3ee902a..0000000 --- a/scripts/extract_images_from_tfrecords.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Read tfrecords w/o define a graph. - -Ref: -http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ -""" - -import cv2 -import glob -import numpy as np -import os - -from basicsr.utils.lmdb import LmdbMaker - - -def celeba_tfrecords(): - # Configurations - file_pattern = '/home/xtwang/datasets/CelebA_tfrecords/celeba-full-tfr/train/train-r08-s-*-of-*.tfrecords' # noqa:E501 - # r08: resolution 2^8 = 256 - resolution = 128 - save_path = f'/home/xtwang/datasets/CelebA_tfrecords/tmptrain_{resolution}' - - save_all_path = os.path.join(save_path, f'all_{resolution}') - os.makedirs(save_all_path) - - idx = 0 - print(glob.glob(file_pattern)) - for record in glob.glob(file_pattern): - record_iterator = tf.python_io.tf_record_iterator(record) - for string_record in record_iterator: - example = tf.train.Example() - example.ParseFromString(string_record) - # label = example.features.feature['label'].int64_list.value[0] - - # attr = example.features.feature['attr'].int64_list.value - # male = attr[20] - # young = attr[39] - - shape = example.features.feature['shape'].int64_list.value - h, w, c = shape - img_str = example.features.feature['data'].bytes_list.value[0] - img = np.fromstring(img_str, dtype=np.uint8).reshape((h, w, c)) - - # save image - img = img[:, :, [2, 1, 0]] - cv2.imwrite(os.path.join(save_all_path, f'{idx:08d}.png'), img) - - idx += 1 - print(idx) - - -def ffhq_tfrecords(): - # Configurations - file_pattern = '/home/xtwang/datasets/ffhq/ffhq-r10.tfrecords' - resolution = 1024 - save_path = f'/home/xtwang/datasets/ffhq/ffhq_imgs/ffhq_{resolution}' - - os.makedirs(save_path, exist_ok=True) - idx = 0 - print(glob.glob(file_pattern)) - for record in glob.glob(file_pattern): - record_iterator = tf.python_io.tf_record_iterator(record) - for string_record in record_iterator: - example = tf.train.Example() - example.ParseFromString(string_record) - - shape = example.features.feature['shape'].int64_list.value - c, h, w = shape - img_str = example.features.feature['data'].bytes_list.value[0] - img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w)) - - # save image - img = img.transpose(1, 2, 0) - img = img[:, :, [2, 1, 0]] - cv2.imwrite(os.path.join(save_path, f'{idx:08d}.png'), img) - - idx += 1 - print(idx) - - -def ffhq_tfrecords_to_lmdb(): - # Configurations - file_pattern = '/home/xtwang/datasets/ffhq/ffhq-r10.tfrecords' - log_resolution = 10 - compress_level = 1 - lmdb_path = f'/home/xtwang/datasets/ffhq/ffhq_{2**log_resolution}.lmdb' - - idx = 0 - print(glob.glob(file_pattern)) - - lmdb_maker = LmdbMaker(lmdb_path) - for record in glob.glob(file_pattern): - record_iterator = 
tf.python_io.tf_record_iterator(record) - for string_record in record_iterator: - example = tf.train.Example() - example.ParseFromString(string_record) - - shape = example.features.feature['shape'].int64_list.value - c, h, w = shape - img_str = example.features.feature['data'].bytes_list.value[0] - img = np.fromstring(img_str, dtype=np.uint8).reshape((c, h, w)) - - # write image to lmdb - img = img.transpose(1, 2, 0) - img = img[:, :, [2, 1, 0]] - _, img_byte = cv2.imencode( - '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) - key = f'{idx:08d}/r{log_resolution:02d}' - lmdb_maker.put(img_byte, key, (h, w, c)) - - idx += 1 - print(key) - lmdb_maker.close() - - -if __name__ == '__main__': - # we have test on TensorFlow 1.15 - try: - import tensorflow as tf - except Exception: - raise ImportError('You need to install tensorflow to read tfrecords.') - # celeba_tfrecords() - # ffhq_tfrecords() - ffhq_tfrecords_to_lmdb() diff --git a/scripts/metrics/calculate_fid_folder.py b/scripts/metrics/calculate_fid_folder.py new file mode 100644 index 0000000..b903160 --- /dev/null +++ b/scripts/metrics/calculate_fid_folder.py @@ -0,0 +1,83 @@ +import argparse +import math +import numpy as np +import torch +from torch.utils.data import DataLoader + +from basicsr.data import create_dataset +from basicsr.metrics.fid import (calculate_fid, extract_inception_features, + load_patched_inception_v3) + + +def calculate_fid_folder(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser() + parser.add_argument('folder', type=str, help='Path to the folder.') + parser.add_argument( + '--fid_stats', type=str, help='Path to the dataset fid statistics.') + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--num_sample', type=int, default=50000) + parser.add_argument('--num_workers', type=int, default=4) + parser.add_argument( + '--backend', + type=str, + default='disk', + help='io backend for dataset. 
Option: disk, lmdb') + args = parser.parse_args() + + # inception model + inception = load_patched_inception_v3(device) + + # create dataset + opt = {} + opt['name'] = 'SingleImageDataset' + opt['type'] = 'SingleImageDataset' + opt['dataroot_lq'] = args.folder + opt['io_backend'] = dict(type=args.backend) + opt['mean'] = [0.5, 0.5, 0.5] + opt['std'] = [0.5, 0.5, 0.5] + dataset = create_dataset(opt) + + # create dataloader + data_loader = DataLoader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + sampler=None, + drop_last=False) + args.num_sample = min(args.num_sample, len(dataset)) + total_batch = math.ceil(args.num_sample / args.batch_size) + + def data_generator(data_loader, total_batch): + for idx, data in enumerate(data_loader): + if idx >= total_batch: + break + else: + yield data['lq'] + + features = extract_inception_features( + data_generator(data_loader, total_batch), inception, total_batch, + device) + features = features.numpy() + total_len = features.shape[0] + features = features[:args.num_sample] + print(f'Extracted {total_len} features, ' + f'use the first {features.shape[0]} features to calculate stats.') + + sample_mean = np.mean(features, 0) + sample_cov = np.cov(features, rowvar=False) + + # load the dataset stats + stats = torch.load(args.fid_stats) + real_mean = stats['mean'] + real_cov = stats['cov'] + + # calculate FID metric + fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) + print('fid:', fid) + + +if __name__ == '__main__': + calculate_fid_folder() diff --git a/scripts/metrics/calculate_fid_stats_from_datasets.py b/scripts/metrics/calculate_fid_stats_from_datasets.py new file mode 100644 index 0000000..8b61f5c --- /dev/null +++ b/scripts/metrics/calculate_fid_stats_from_datasets.py @@ -0,0 +1,72 @@ +import argparse +import math +import numpy as np +import torch +from torch.utils.data import DataLoader + +from basicsr.data import create_dataset +from basicsr.metrics.fid import (extract_inception_features, + load_patched_inception_v3) + + +def calculate_stats_from_dataset(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser() + parser.add_argument('--num_sample', type=int, default=50000) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--size', type=int, default=512) + parser.add_argument('--dataroot', type=str, default='datasets/ffhq') + args = parser.parse_args() + + # inception model + inception = load_patched_inception_v3(device) + + # create dataset + opt = {} + opt['name'] = 'FFHQ' + opt['type'] = 'FFHQDataset' + opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb' + opt['io_backend'] = dict(type='lmdb') + opt['use_hflip'] = False + opt['mean'] = [0.5, 0.5, 0.5] + opt['std'] = [0.5, 0.5, 0.5] + dataset = create_dataset(opt) + + # create dataloader + data_loader = DataLoader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=4, + sampler=None, + drop_last=False) + total_batch = math.ceil(args.num_sample / args.batch_size) + + def data_generator(data_loader, total_batch): + for idx, data in enumerate(data_loader): + if idx >= total_batch: + break + else: + yield data['gt'] + + features = extract_inception_features( + data_generator(data_loader, total_batch), inception, total_batch, + device) + features = features.numpy() + total_len = features.shape[0] + features = features[:args.num_sample] + print(f'Extracted {total_len} features, ' + f'use the first 
{features.shape[0]} features to calculate stats.') + mean = np.mean(features, 0) + cov = np.cov(features, rowvar=False) + + save_path = f'inception_{opt["name"]}_{args.size}.pth' + torch.save( + dict(name=opt['name'], size=args.size, mean=mean, cov=cov), + save_path, + _use_new_zipfile_serialization=False) + + +if __name__ == '__main__': + calculate_stats_from_dataset() diff --git a/scripts/metrics/calculate_lpips.py b/scripts/metrics/calculate_lpips.py new file mode 100644 index 0000000..d9fbd3c --- /dev/null +++ b/scripts/metrics/calculate_lpips.py @@ -0,0 +1,56 @@ +import cv2 +import glob +import numpy as np +import os.path as osp +from torchvision.transforms.functional import normalize + +from basicsr.utils import img2tensor + +try: + import lpips +except ImportError: + print('Please install lpips: pip install lpips') + + +def main(): + # Configurations + # ------------------------------------------------------------------------- + folder_gt = 'datasets/celeba/celeba_512_validation' + folder_restored = 'datasets/celeba/celeba_512_validation_lq' + # crop_border = 4 + suffix = '' + # ------------------------------------------------------------------------- + loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1] + lpips_all = [] + img_list = sorted(glob.glob(osp.join(folder_gt, '*'))) + + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + for i, img_path in enumerate(img_list): + basename, ext = osp.splitext(osp.basename(img_path)) + img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype( + np.float32) / 255. + img_restored = cv2.imread( + osp.join(folder_restored, basename + suffix + ext), + cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + + img_gt, img_restored = img2tensor([img_gt, img_restored], + bgr2rgb=True, + float32=True) + # norm to [-1, 1] + normalize(img_gt, mean, std, inplace=True) + normalize(img_restored, mean, std, inplace=True) + + # calculate lpips + lpips_val = loss_fn_vgg( + img_restored.unsqueeze(0).cuda(), + img_gt.unsqueeze(0).cuda()) + + print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.') + lpips_all.append(lpips_val) + + print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}') + + +if __name__ == '__main__': + main() diff --git a/scripts/calculate_psnr_ssim.py b/scripts/metrics/calculate_psnr_ssim.py similarity index 57% rename from scripts/calculate_psnr_ssim.py rename to scripts/metrics/calculate_psnr_ssim.py index 7e802d1..1a14af5 100644 --- a/scripts/calculate_psnr_ssim.py +++ b/scripts/metrics/calculate_psnr_ssim.py @@ -1,8 +1,10 @@ -import mmcv +import cv2 import numpy as np from os import path as osp from basicsr.metrics import calculate_psnr, calculate_ssim +from basicsr.utils import scandir +from basicsr.utils.matlab_functions import bgr2ycbcr def main(): @@ -23,11 +25,12 @@ def main(): crop_border = 4 suffix = '_expname' test_y_channel = False + correct_mean_var = False # ------------------------------------------------------------------------- psnr_all = [] ssim_all = [] - img_list = sorted(mmcv.scandir(folder_gt, recursive=True)) + img_list = sorted(scandir(folder_gt, recursive=True, full_path=True)) if test_y_channel: print('Testing Y channel.') @@ -36,16 +39,35 @@ def main(): for i, img_path in enumerate(img_list): basename, ext = osp.splitext(osp.basename(img_path)) - img_gt = mmcv.imread( - osp.join(folder_gt, img_path), flag='unchanged').astype( - np.float32) / 255. - img_restored = mmcv.imread( + img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype( + np.float32) / 255. 
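(Editor's note, not part of the patch.) The new scripts/metrics/calculate_lpips.py introduced above follows the same image convention as the other new metric scripts in this patch: images are read with OpenCV in BGR order, converted to RGB tensors, and normalized to [-1, 1] before the metric is applied. Below is a minimal usage sketch, assuming the `lpips` package is installed; it uses only helpers that appear in the patch (`img2tensor` from basicsr.utils and torchvision's `normalize`), and the function name and arguments are illustrative only.

```python
# Editorial sketch mirroring calculate_lpips.py; not part of the diff.
import cv2
import lpips
import torch
from torchvision.transforms.functional import normalize

from basicsr.utils import img2tensor

loss_fn_vgg = lpips.LPIPS(net='vgg')  # expects RGB tensors normalized to [-1, 1]


def lpips_distance(gt_path, restored_path):
    tensors = []
    for path in (gt_path, restored_path):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype('float32') / 255.
        img = img2tensor(img, bgr2rgb=True, float32=True)  # BGR HWC -> RGB CHW
        normalize(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)  # to [-1, 1]
        tensors.append(img.unsqueeze(0))
    with torch.no_grad():
        return loss_fn_vgg(tensors[0], tensors[1]).item()
```

The same mean/std of 0.5 shows up in the FID scripts above (calculate_fid_folder.py and calculate_fid_stats_from_datasets.py), so all of the new metric scripts feed the networks inputs in the same [-1, 1] range.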
+ img_restored = cv2.imread( osp.join(folder_restored, basename + suffix + ext), - flag='unchanged').astype(np.float32) / 255. + cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + + if correct_mean_var: + mean_l = [] + std_l = [] + for j in range(3): + mean_l.append(np.mean(img_gt[:, :, j])) + std_l.append(np.std(img_gt[:, :, j])) + for j in range(3): + # correct twice + mean = np.mean(img_restored[:, :, j]) + img_restored[:, :, + j] = img_restored[:, :, j] - mean + mean_l[j] + std = np.std(img_restored[:, :, j]) + img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j] + + mean = np.mean(img_restored[:, :, j]) + img_restored[:, :, + j] = img_restored[:, :, j] - mean + mean_l[j] + std = np.std(img_restored[:, :, j]) + img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j] if test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3: - img_gt = mmcv.bgr2ycbcr(img_gt, y_only=True) - img_restored = mmcv.bgr2ycbcr(img_restored, y_only=True) + img_gt = bgr2ycbcr(img_gt, y_only=True) + img_restored = bgr2ycbcr(img_restored, y_only=True) # calculate PSNR and SSIM psnr = calculate_psnr( @@ -62,6 +84,8 @@ def main(): f'\tSSIM: {ssim:.6f}') psnr_all.append(psnr) ssim_all.append(ssim) + print(folder_gt) + print(folder_restored) print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, ' f'SSIM: {sum(ssim_all) / len(ssim_all):.6f}') diff --git a/scripts/metrics/calculate_stylegan2_fid.py b/scripts/metrics/calculate_stylegan2_fid.py new file mode 100644 index 0000000..bd3acb1 --- /dev/null +++ b/scripts/metrics/calculate_stylegan2_fid.py @@ -0,0 +1,79 @@ +import argparse +import math +import numpy as np +import torch +from torch import nn + +from basicsr.metrics.fid import (calculate_fid, extract_inception_features, + load_patched_inception_v3) +from basicsr.models.archs.stylegan2_arch import StyleGAN2Generator + + +def calculate_stylegan2_fid(): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser() + parser.add_argument( + 'ckpt', type=str, help='Path to the stylegan2 checkpoint.') + parser.add_argument( + 'fid_stats', type=str, help='Path to the dataset fid statistics.') + parser.add_argument('--size', type=int, default=256) + parser.add_argument('--channel_multiplier', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--num_sample', type=int, default=50000) + parser.add_argument('--truncation', type=float, default=1) + parser.add_argument('--truncation_mean', type=int, default=4096) + args = parser.parse_args() + + # create stylegan2 model + generator = StyleGAN2Generator( + out_size=args.size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=args.channel_multiplier, + resample_kernel=(1, 3, 3, 1)) + generator.load_state_dict(torch.load(args.ckpt)['params_ema']) + generator = nn.DataParallel(generator).eval().to(device) + + if args.truncation < 1: + with torch.no_grad(): + truncation_latent = generator.mean_latent(args.truncation_mean) + else: + truncation_latent = None + + # inception model + inception = load_patched_inception_v3(device) + + total_batch = math.ceil(args.num_sample / args.batch_size) + + def sample_generator(total_batch): + for i in range(total_batch): + with torch.no_grad(): + latent = torch.randn(args.batch_size, 512, device=device) + samples, _ = generator([latent], + truncation=args.truncation, + truncation_latent=truncation_latent) + yield samples + + features = extract_inception_features( + sample_generator(total_batch), inception, 
total_batch, device) + features = features.numpy() + total_len = features.shape[0] + features = features[:args.num_sample] + print(f'Extracted {total_len} features, ' + f'use the first {features.shape[0]} features to calculate stats.') + sample_mean = np.mean(features, 0) + sample_cov = np.cov(features, rowvar=False) + + # load the dataset stats + stats = torch.load(args.fid_stats) + real_mean = stats['mean'] + real_cov = stats['cov'] + + # calculate FID metric + fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) + print('fid:', fid) + + +if __name__ == '__main__': + calculate_stylegan2_fid() diff --git a/scripts/convert_dfdnet.py b/scripts/model_conversion/convert_dfdnet.py similarity index 100% rename from scripts/convert_dfdnet.py rename to scripts/model_conversion/convert_dfdnet.py diff --git a/scripts/convert_models.py b/scripts/model_conversion/convert_models.py similarity index 100% rename from scripts/convert_models.py rename to scripts/model_conversion/convert_models.py diff --git a/scripts/convert_stylegan.py b/scripts/model_conversion/convert_stylegan.py similarity index 100% rename from scripts/convert_stylegan.py rename to scripts/model_conversion/convert_stylegan.py diff --git a/scripts/publish_models.py b/scripts/publish_models.py index ea2b5f4..ea4ae79 100644 --- a/scripts/publish_models.py +++ b/scripts/publish_models.py @@ -53,6 +53,7 @@ def convert_to_backward_compatible_models(paths): if __name__ == '__main__': - paths = glob.glob('experiments/pretrained_models/*.pth') + paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob( + 'experiments/pretrained_models/**/*.pth') convert_to_backward_compatible_models(paths) update_sha(paths) diff --git a/setup.cfg b/setup.cfg index dccb00b..ae5a6eb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,6 +16,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = basicsr -known_third_party = PIL,cv2,lmdb,matplotlib,mmcv,numpy,requests,scipy,skimage,torch,torchvision,yaml +known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/setup.py b/setup.py index 0a339ff..621007f 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ import os import subprocess +import sys import time import torch from torch.utils.cpp_extension import (BuildExtension, CppExtension, @@ -85,8 +86,9 @@ def get_version(): return locals()['__version__'] -def make_cuda_ext(name, module, sources, sources_cuda=[]): - +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] define_macros = [] extra_compile_args = {'cxx': []} @@ -118,6 +120,31 @@ def get_requirements(filename='requirements.txt'): if __name__ == '__main__': + if '--no_cuda_ext' in sys.argv: + ext_modules = [] + sys.argv.remove('--no_cuda_ext') + else: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='basicsr.models.ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=[ + 'src/deform_conv_cuda.cpp', + 'src/deform_conv_cuda_kernel.cu' + ]), + make_cuda_ext( + name='fused_act_ext', + module='basicsr.models.ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='basicsr.models.ops.upfirdn2d', + sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + write_version_py() setup( name='basicsr', @@ -142,25 +169,6 @@ def 
get_requirements(filename='requirements.txt'): license='Apache License 2.0', setup_requires=['cython', 'numpy'], install_requires=get_requirements(), - ext_modules=[ - make_cuda_ext( - name='deform_conv_ext', - module='basicsr.models.ops.dcn', - sources=['src/deform_conv_ext.cpp'], - sources_cuda=[ - 'src/deform_conv_cuda.cpp', - 'src/deform_conv_cuda_kernel.cu' - ]), - make_cuda_ext( - name='fused_act_ext', - module='basicsr.models.ops.fused_act', - sources=['src/fused_bias_act.cpp'], - sources_cuda=['src/fused_bias_act_kernel.cu']), - make_cuda_ext( - name='upfirdn2d_ext', - module='basicsr.models.ops.upfirdn2d', - sources=['src/upfirdn2d.cpp'], - sources_cuda=['src/upfirdn2d_kernel.cu']), - ], + ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension}, zip_safe=False) diff --git a/tests/test_face_dfdnet.py b/tests/test_face_dfdnet.py deleted file mode 100644 index ab2d6db..0000000 --- a/tests/test_face_dfdnet.py +++ /dev/null @@ -1,353 +0,0 @@ -import argparse -import cv2 -import glob -import mmcv -import numpy as np -import os -import torch -import torchvision.transforms as transforms -from skimage import io -from skimage import transform as trans - -from basicsr.models.archs.dfdnet_arch import DFDNet -from basicsr.utils import tensor2img - -try: - import dlib -except ImportError: - print('Please install dlib before testing face restoration.' - 'Reference: https://github.com/davisking/dlib') - - -class FaceRestorationHelper(object): - """Helper for the face restoration pipeline.""" - - def __init__(self, upscale_factor, face_template_path, out_size=512): - self.upscale_factor = upscale_factor - self.out_size = (out_size, out_size) - - # standard 5 landmarks for FFHQ faces with 1024 x 1024 - self.face_template = np.load(face_template_path) / (1024 // out_size) - # for estimation the 2D similarity transformation - self.similarity_trans = trans.SimilarityTransform() - - self.all_landmarks_5 = [] - self.all_landmarks_68 = [] - self.affine_matrices = [] - self.inverse_affine_matrices = [] - self.cropped_faces = [] - self.restored_faces = [] - - def init_dlib(self, detection_path, landmark5_path, landmark68_path): - """Initialize the dlib detectors and predictors.""" - self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) - self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) - self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) - - def free_dlib_gpu_memory(self): - del self.face_detector - del self.shape_predictor_5 - del self.shape_predictor_68 - - def read_input_image(self, img_path): - # self.input_img is Numpy array, (h, w, c) with RGB order - self.input_img = dlib.load_rgb_image(img_path) - - def detect_faces(self, img_path, upsample_num_times=1): - """ - Args: - img_path (str): Image path. - upsample_num_times (int): Upsamples the image before running the - face detector - - Returns: - int: Number of detected faces. - """ - self.read_input_image(img_path) - self.det_faces = self.face_detector(self.input_img, upsample_num_times) - if len(self.det_faces) == 0: - print('No face detected. Try to increase upsample_num_times.') - return len(self.det_faces) - - def get_face_landmarks_5(self): - for face in self.det_faces: - shape = self.shape_predictor_5(self.input_img, face.rect) - landmark = np.array([[part.x, part.y] for part in shape.parts()]) - self.all_landmarks_5.append(landmark) - return len(self.all_landmarks_5) - - def get_face_landmarks_68(self): - """Get 68 densemarks for cropped images. 
- - Should only have one face at most in the cropped image. - """ - num_detected_face = 0 - for idx, face in enumerate(self.cropped_faces): - # face detection - det_face = self.face_detector(face, 1) # TODO: can we remove it - if len(det_face) == 0: - print(f'Cannot find faces in cropped image with index {idx}.') - self.all_landmarks_68.append(None) - elif len(det_face) == 1: - shape = self.shape_predictor_68(face, det_face[0].rect) - landmark = np.array([[part.x, part.y] - for part in shape.parts()]) - self.all_landmarks_68.append(landmark) - num_detected_face += 1 - else: - print('Should only have one face at most.') - return num_detected_face - - def warp_crop_faces(self, save_cropped_path=None): - """Get affine matrix, warp and cropped faces. - - Also get inverse affine matrix for post-processing. - """ - for idx, landmark in enumerate(self.all_landmarks_5): - # use 5 landmarks to get affine matrix - self.similarity_trans.estimate(landmark, self.face_template) - affine_matrix = self.similarity_trans.params[0:2, :] - self.affine_matrices.append(affine_matrix) - # warp and crop faces - cropped_face = cv2.warpAffine(self.input_img, affine_matrix, - self.out_size) - self.cropped_faces.append(cropped_face) - # save the cropped face - if save_cropped_path is not None: - path, ext = os.path.splitext(save_cropped_path) - save_path = f'{path}_{idx:02d}{ext}' - mmcv.imwrite(mmcv.rgb2bgr(cropped_face), save_path) - - # get inverse affine matrix - self.similarity_trans.estimate(self.face_template, - landmark * self.upscale_factor) - inverse_affine = self.similarity_trans.params[0:2, :] - self.inverse_affine_matrices.append(inverse_affine) - - def add_restored_face(self, face): - self.restored_faces.append(face) - - def paste_faces_to_input_image(self, save_path): - # operate in the BGR order - input_img = mmcv.rgb2bgr(self.input_img) - h, w, _ = input_img.shape - h_up, w_up = h * self.upscale_factor, w * self.upscale_factor - # simply resize the background - upsample_img = cv2.resize(input_img, (w_up, h_up)) - for restored_face, inverse_affine in zip(self.restored_faces, - self.inverse_affine_matrices): - inv_restored = cv2.warpAffine(restored_face, inverse_affine, - (w_up, h_up)) - mask = np.ones((*self.out_size, 3), dtype=np.float32) - inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) - # remove the black borders - inv_mask_erosion = cv2.erode( - inv_mask, - np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), - np.uint8)) - inv_restored_remove_border = inv_mask_erosion * inv_restored - total_face_area = np.sum(inv_mask_erosion) // 3 - # compute the fusion edge based on the area of face - w_edge = int(total_face_area**0.5) // 20 - erosion_radius = w_edge * 2 - inv_mask_center = cv2.erode( - inv_mask_erosion, - np.ones((erosion_radius, erosion_radius), np.uint8)) - blur_size = w_edge * 2 - inv_soft_mask = cv2.GaussianBlur(inv_mask_center, - (blur_size + 1, blur_size + 1), 0) - upsample_img = inv_soft_mask * inv_restored_remove_border + ( - 1 - inv_soft_mask) * upsample_img - mmcv.imwrite(upsample_img.astype(np.uint8), save_path) - - def clean_all(self): - self.all_landmarks_5 = [] - self.all_landmarks_68 = [] - self.restored_faces = [] - self.affine_matrices = [] - self.cropped_faces = [] - self.inverse_affine_matrices = [] - - -def get_part_location(landmarks): - """Get part locations from landmarks.""" - map_left_eye = list(np.hstack((range(17, 22), range(36, 42)))) - map_right_eye = list(np.hstack((range(22, 27), range(42, 48)))) - map_nose = list(range(29, 36)) - map_mouth 
= list(range(48, 68)) - - # left eye - mean_left_eye = np.mean(landmarks[map_left_eye], 0) # (x, y) - half_len_left_eye = np.max((np.max( - np.max(landmarks[map_left_eye], 0) - - np.min(landmarks[map_left_eye], 0)) / 2, 16)) # A number - loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, - mean_left_eye + half_len_left_eye)).astype(int) - loc_left_eye = torch.from_numpy(loc_left_eye).unsqueeze(0) - # (1, 4), the four numbers forms two coordinates in the diagonal - - # right eye - mean_right_eye = np.mean(landmarks[map_right_eye], 0) - half_len_right_eye = np.max((np.max( - np.max(landmarks[map_right_eye], 0) - - np.min(landmarks[map_right_eye], 0)) / 2, 16)) - loc_right_eye = np.hstack( - (mean_right_eye - half_len_right_eye + 1, - mean_right_eye + half_len_right_eye)).astype(int) - loc_right_eye = torch.from_numpy(loc_right_eye).unsqueeze(0) - # nose - mean_nose = np.mean(landmarks[map_nose], 0) - half_len_nose = np.max((np.max( - np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2, - 16)) # noqa: E126 - loc_nose = np.hstack( - (mean_nose - half_len_nose + 1, mean_nose + half_len_nose)).astype(int) - loc_nose = torch.from_numpy(loc_nose).unsqueeze(0) - # mouth - mean_mouth = np.mean(landmarks[map_mouth], 0) - half_len_mouth = np.max((np.max( - np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2, - 16)) # noqa: E126 - loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, - mean_mouth + half_len_mouth)).astype(int) - loc_mouth = torch.from_numpy(loc_mouth).unsqueeze(0) - - return loc_left_eye, loc_right_eye, loc_nose, loc_mouth - - -if __name__ == '__main__': - """We try to align to the official codes. But there are still slight - differences: 1) we use dlib for 68 landmark detection; 2) the used image - package are different (especially for reading and writing.) - """ - device = 'cuda' - parser = argparse.ArgumentParser() - - parser.add_argument('--upscale_factor', type=int, default=2) - parser.add_argument( - '--model_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth') - parser.add_argument( - '--dict_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth') - parser.add_argument('--test_path', type=str, default='datasets/TestWhole') - parser.add_argument('--upsample_num_times', type=int, default=1) - # The official codes use skimage.io to read the cropped images from disk - # instead of directly using the intermediate results in the memory (as we - # do). Such a different operation brings slight differences due to - # skimage.io. For aligning with the official results, we could set the - # official_adaption to True. 
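(Editor's note, not part of the patch.) The test script being deleted here uses the same [-1, 1] convention as the new metric scripts: each cropped face is normalized with mean/std 0.5 before the DFDNet forward pass, and the network output is mapped back to an image with `tensor2img(output, min_max=(-1, 1))` in the restoration loop below. A minimal sketch of that pre/post-processing round trip (function names are illustrative):

```python
# Editorial sketch of the pre/post-processing used in the removed test script.
import torchvision.transforms as transforms

from basicsr.utils import tensor2img


def face_to_input(cropped_face):
    """uint8 HWC RGB face -> (1, 3, H, W) tensor in [-1, 1]."""
    t = transforms.ToTensor()(cropped_face)                        # [0, 1]
    t = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(t)  # [-1, 1]
    return t.unsqueeze(0)


def output_to_face(network_output):
    """Map the [-1, 1] network output back to a uint8 BGR image for writing."""
    return tensor2img(network_output, min_max=(-1, 1))
```

The BGR output matches what the original script writes to disk with mmcv.imwrite in the restoration loop below.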
- parser.add_argument('--official_adaption', type=bool, default=True) - - # The following are the paths for face template and dlib models - parser.add_argument( - '--face_template_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/DFDNet/FFHQ_5_landmarks_template_1024-90a00515.npy' # noqa: E501 - ) - parser.add_argument( - '--detection_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat' # noqa: E501 - ) - parser.add_argument( - '--landmark5_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat' # noqa: E501 - ) - parser.add_argument( - '--landmark68_path', - type=str, - default= # noqa: E251 - 'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat' # noqa: E501 - ) - - args = parser.parse_args() - result_root = f'results/DFDNet/{args.test_path.split("/")[-1]}' - - # set up the DFDNet - net = DFDNet(64, dict_path=args.dict_path).to(device) - checkpoint = torch.load( - args.model_path, map_location=lambda storage, loc: storage) - net.load_state_dict(checkpoint['params']) - net.eval() - - save_crop_root = os.path.join(result_root, 'cropped_faces') - save_restore_root = os.path.join(result_root, 'restored_faces') - save_final_root = os.path.join(result_root, 'final_results') - - face_helper = FaceRestorationHelper( - args.upscale_factor, args.face_template_path, out_size=512) - - # scan all the jpg and png images - for img_path in glob.glob(os.path.join(args.test_path, '*.[jp][pn]g')): - img_name = os.path.basename(img_path) - print(f'Processing {img_name} image ...') - save_crop_path = os.path.join(save_crop_root, img_name) - - face_helper.init_dlib(args.detection_path, args.landmark5_path, - args.landmark68_path) - # detect faces - num_det_faces = face_helper.detect_faces( - img_path, upsample_num_times=args.upsample_num_times) - # get 5 face landmarks for each face - num_landmarks = face_helper.get_face_landmarks_5() - print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.') - # warp and crop each face - face_helper.warp_crop_faces(save_crop_path) - - if args.official_adaption: - path, ext = os.path.splitext(save_crop_path) - pathes = sorted(glob.glob(f'{path}_[0-9]*{ext}')) - cropped_faces = [io.imread(path) for path in pathes] - else: - cropped_faces = face_helper.cropped_faces - - # get 68 landmarks for each cropped face - num_landmarks = face_helper.get_face_landmarks_68() - print(f'\tDetect {num_landmarks} faces for 68 landmarks.') - - face_helper.free_dlib_gpu_memory() - - print('\tFace restoration ...') - # face restoration for each cropped face - for idx, (cropped_face, landmarks) in enumerate( - zip(cropped_faces, face_helper.all_landmarks_68)): - if landmarks is None: - print(f'Landmarks is None, skip cropped faces with idx {idx}.') - else: - # prepare data - part_locations = get_part_location(landmarks) - cropped_face = transforms.ToTensor()(cropped_face) - cropped_face = transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))( - cropped_face) - cropped_face = cropped_face.unsqueeze(0).to(device) - - with torch.no_grad(): - output = net(cropped_face, part_locations) - im = tensor2img(output, min_max=(-1, 1)) - del output - torch.cuda.empty_cache() - path, ext = os.path.splitext( - os.path.join(save_restore_root, img_name)) - save_path = f'{path}_{idx:02d}{ext}' - mmcv.imwrite(im, save_path) - face_helper.add_restored_face(im) - - print('\tGenerate the final result ...') - # 
paste each restored face to the input image - face_helper.paste_faces_to_input_image( - os.path.join(save_final_root, img_name)) - - # clean all the intermediate results to process the next image - face_helper.clean_all() - - print(f'\nAll results are saved in {result_root}') diff --git a/tests/test_ffhq_dataset.py b/tests/test_ffhq_dataset.py index 5486385..655e402 100644 --- a/tests/test_ffhq_dataset.py +++ b/tests/test_ffhq_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torch import torchvision.utils @@ -29,7 +29,7 @@ def main(): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 9562ffd..b9642d1 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,10 +1,14 @@ -import matplotlib as mpl import torch -from matplotlib import pyplot as plt -from matplotlib import ticker as mtick from basicsr.models.lr_scheduler import CosineAnnealingRestartLR +try: + import matplotlib as mpl + from matplotlib import pyplot as plt + from matplotlib import ticker as mtick +except ImportError: + print('Please install matplotlib.') + mpl.use('Agg') diff --git a/tests/test_paired_image_dataset.py b/tests/test_paired_image_dataset.py index 3c415a3..a133a36 100644 --- a/tests/test_paired_image_dataset.py +++ b/tests/test_paired_image_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torchvision.utils from basicsr.data import create_dataloader, create_dataset @@ -44,7 +44,7 @@ def main(mode='folder'): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( diff --git a/tests/test_reds_dataset.py b/tests/test_reds_dataset.py index 7863fe0..cbf23a6 100644 --- a/tests/test_reds_dataset.py +++ b/tests/test_reds_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torchvision.utils from basicsr.data import create_dataloader, create_dataset @@ -45,7 +45,7 @@ def main(mode='folder'): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader( diff --git a/tests/test_vimeo90k_dataset.py b/tests/test_vimeo90k_dataset.py index 8a9661a..80bb45a 100644 --- a/tests/test_vimeo90k_dataset.py +++ b/tests/test_vimeo90k_dataset.py @@ -1,5 +1,5 @@ import math -import mmcv +import os import torchvision.utils from basicsr.data import create_dataloader, create_dataset @@ -41,7 +41,7 @@ def main(mode='folder'): opt['dataset_enlarge_ratio'] = 1 - mmcv.mkdir_or_exist('tmp') + os.makedirs('tmp', exist_ok=True) dataset = create_dataset(opt) data_loader = create_dataloader(