save memory for zsm_my_video; formatting the code
yulunzhang committed Sep 5, 2021
1 parent 0d5ead5 commit a053e08
Showing 24 changed files with 540 additions and 321 deletions.
79 changes: 52 additions & 27 deletions README.md
@@ -1,9 +1,8 @@
# Zooming-Slow-Mo (CVPR-2020)

By [Xiaoyu Xiang<sup>\*</sup>](https://engineering.purdue.edu/people/xiaoyu.xiang.1), [Yapeng Tian<sup>\*</sup>](http://yapengtian.org/), [Yulun Zhang](http://yulunzhang.com/), [Yun Fu](http://www1.ece.neu.edu/~yunfu/), [Jan P. Allebach<sup>+</sup>](https://engineering.purdue.edu/~allebach/), [Chenliang Xu<sup>+</sup>](https://www.cs.rochester.edu/~cxu22/) (<sup>\*</sup> equal contributions, <sup>+</sup> equal advising)

This is the official PyTorch implementation of _Zooming Slow-Mo: Fast and Accurate One-Stage Space-Time Video Super-Resolution_.

#### [Paper](https://arxiv.org/abs/2002.11616) | [Journal Version](https://arxiv.org/abs/2104.07473) | [Demo (YouTube)](https://youtu.be/8mgD8JxBOus) | [1-min teaser (YouTube)](https://www.youtube.com/watch?v=C1o85AXUNl8) | [1-min teaser (Bilibili)](https://www.bilibili.com/video/BV1GK4y1t7nb/)

@@ -24,30 +23,31 @@
</tr>
</table>

## Updates

- 2020.3.13 Add meta-info of datasets used in this paper
- 2020.3.11 Add new function: video converter
- 2020.3.10 Upload the complete code and pretrained models

## Contents

0. [Introduction](#introduction)
1. [Prerequisites](#Prerequisites)
2. [Get Started](#Get-Started)
- [Installation](#Installation)
- [Training](#Training)
- [Testing](#Testing)
- [Colab Notebook](#Colab-Notebook)
3. [Citations](#citations)
4. [Contact](#Contact)
5. [License](#License)
6. [Acknowledgments](#Acknowledgments)

## Introduction
The repository contains the entire project (including all the preprocessing) for one-stage space-time video super-resolution with Zooming Slow-Mo.

Zooming Slow-Mo is a recently proposed joint video frame interpolation (VFI) and video super-resolution (VSR) method, which directly synthesizes an HR slow-motion video from an LFR, LR video. It is going to be published in [CVPR 2020](http://cvpr2020.thecvf.com/). The most up-to-date paper with supplementary materials can be found at [arXiv](https://arxiv.org/abs/2002.11616).

In Zooming Slow-Mo, we firstly temporally interpolate features of the missing LR frame by the proposed feature temporal interpolation network. Then, we propose a deformable ConvLSTM to align and aggregate temporal information simultaneously. Finally, a deep reconstruction network is adopted to predict HR slow-motion video frames. If our proposed architectures also help your research, please consider citing our paper.
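
To make the pipeline concrete, below is a minimal, illustrative PyTorch sketch of the one-stage design described above. This is **not** the repository's network: plain convolutions stand in for the feature temporal interpolation module, the deformable ConvLSTM, and the reconstruction network, and all module names are invented for clarity.

```python
import torch
import torch.nn as nn

class ZoomingSlowMoSketch(nn.Module):
    """Toy stand-in for the one-stage STVSR pipeline (not the real model)."""

    def __init__(self, nf=64, scale=4):
        super().__init__()
        self.extract = nn.Conv2d(3, nf, 3, 1, 1)      # per-frame feature extraction
        self.interp = nn.Conv2d(2 * nf, nf, 3, 1, 1)  # stand-in: feature temporal interpolation
        self.aggregate = nn.Conv2d(nf, nf, 3, 1, 1)   # stand-in: deformable ConvLSTM
        self.reconstruct = nn.Sequential(             # reconstruction + upsampling
            nn.Conv2d(nf, 3 * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale),
        )

    def forward(self, lr):  # lr: (B, T, 3, h, w) low-frame-rate, low-resolution clip
        feats = [self.extract(f) for f in lr.unbind(dim=1)]
        full = [feats[0]]
        for t in range(len(feats) - 1):
            # synthesize the feature map of the missing intermediate frame
            full.append(self.interp(torch.cat([feats[t], feats[t + 1]], dim=1)))
            full.append(feats[t + 1])
        hr = [self.reconstruct(self.aggregate(f)) for f in full]
        return torch.stack(hr, dim=1)  # (B, 2T-1, 3, scale*h, scale*w)

out = ZoomingSlowMoSketch()(torch.randn(1, 4, 3, 16, 28))
print(out.shape)  # torch.Size([1, 7, 3, 64, 112])
```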

@@ -64,39 +64,47 @@ Zooming Slow-Mo achieves state-of-the-art performance by PSNR and SSIM in Vid4,
- Python packages: `pip install numpy opencv-python lmdb pyyaml pickle5 matplotlib seaborn`

## Get Started

### Installation

Install the required packages: `pip install -r requirements.txt`

First, make sure your machine has a GPU, which is required for the DCNv2 module.

1. Clone the Zooming Slow-Mo repository. We will refer to the directory you cloned it into as `ZOOMING_ROOT`.

```Shell
git clone --recursive https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020.git
```
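
If you cloned without `--recursive`, the repository's submodules will be missing; you can fetch them afterwards with:

```Shell
git submodule update --init --recursive
```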

2. Compile the DCNv2:

```Shell
cd $ZOOMING_ROOT/codes/models/modules/DCNv2
bash make.sh # build
python test.py # run examples and gradient check
```

Please make sure the test script finishes successfully without any errors before running the following experiments.
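
Since DCNv2 is compiled against your local CUDA setup, it can save time to first confirm that PyTorch actually sees the GPU:

```Shell
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```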

### Training

#### Part 1: Data Preparation

1. Download the original training + test set of `Vimeo-septuplet` (82 GB).

```Shell
wget http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip
apt-get install unzip
unzip vimeo_septuplet.zip
```

2. Split the `Vimeo-septuplet` into a training set and a test set. Make sure you change the dataset path in the script to your download path, and run the script separately for the training set and the test set:

```Shell
python $ZOOMING_ROOT/codes/data_scripts/sep_vimeo_list.py
```

This will create `train` and `test` folders inside **`vimeo_septuplet/sequences`**. The folder structure is as follows:

```
...
```

@@ -122,14 +130,17 @@

```Matlab
run $ZOOMING_ROOT/codes/data_scripts/generate_LR_Vimeo90K.m
```

```Shell
python $ZOOMING_ROOT/codes/data_scripts/generate_mod_LR_bic.py
```

4. Create the LMDB files for faster I/O speed. Note that you need to configure your input and output path in the following script:

```Shell
python $ZOOMING_ROOT/codes/data_scripts/create_lmdb_mp.py
```

The structure of the generated LMDB folder is as follows:

```
Vimeo7_train.lmdb
├── data.mdb
...
```
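
If you want to sanity-check the generated LMDB, the following minimal sketch reads one frame back. The key format `{clip}_{sequence}_{frame}` and the 256x448 GT frame size are assumptions taken from the dataset loader changed in this commit, and the key below is hypothetical:

```python
import lmdb
import numpy as np

env = lmdb.open('Vimeo7_train.lmdb', readonly=True, lock=False, readahead=False)
with env.begin(write=False) as txn:
    buf = txn.get(b'00001_0001_4')  # hypothetical key: clip 00001, sequence 0001, frame 4
img = np.frombuffer(buf, dtype=np.uint8).reshape(256, 448, 3)  # H, W, C
print(img.shape, img.dtype)
```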

#### Part 2: Train

**Note:** In this part, we assume you are in the directory **`$ZOOMING_ROOT/codes/`**

1. Configure your training settings, which can be found at [options/train](./codes/options/train). Our training settings used in the paper are in [train_zsm.yml](https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/blob/master/codes/options/train/train_zsm.yml). We'll take this setting as an example to illustrate the following steps.
@@ -147,10 +159,12 @@
```Shell
python train.py -opt options/train/train_zsm.yml
```

After training, your model `xxxx_G.pth`, its training states, and a corresponding log file `train_LunaTokis_scratch_b16p32f5b40n7l1_600k_Vimeo_xxxx.log` are placed in the directory `$ZOOMING_ROOT/experiments/LunaTokis_scratch_b16p32f5b40n7l1_600k_Vimeo/`.
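
To follow training progress, you can tail the log file in the experiment directory above, e.g.:

```Shell
tail -f $ZOOMING_ROOT/experiments/LunaTokis_scratch_b16p32f5b40n7l1_600k_Vimeo/train_*.log
```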

### Testing

We provide the test code for both standard test sets (Vid4, SPMC, etc.) and custom video frames.

#### Pretrained Models

@@ -159,17 +173,20 @@ Our pretrained model can be downloaded via [GitHub](https://github.com/Mukosame/
#### From Video

If you have installed ffmpeg, you can convert any video to a high-resolution and high frame-rate video using [video_to_zsm.py](./codes/video_to_zsm.py). The corresponding commands are:

```Shell
cd $ZOOMING_ROOT/codes
python video_to_zsm.py --video PATH/TO/VIDEO.mp4 --model PATH/TO/PRETRAINED/MODEL.pth --output PATH/TO/OUTPUT.mp4
```

We also provide the above commands as a shell script, so you can directly run:

```Shell
bash zsm_my_video.sh
```

#### From Extracted Frames

As a quick start, we also provide some example images in the [test_example](./test_example) folder. You can test the model with the following commands:

@@ -182,15 +199,20 @@

```Shell
python test.py
```

- Your custom test results will be saved to a folder here: `$ZOOMING_ROOT/results/your_data_name/`.

#### Evaluate on Standard Test Sets

The [test.py](codes/test.py) script also provides modes for evaluation on the following test sets: `Vid4`, `SPMC`, etc. We evaluate PSNR and SSIM on the Y channel in YCrCb color space. The commands are the same as the ones above; all you need to do is change the `data_mode` and the corresponding path of the standard test set.
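
For reference, here is a minimal sketch of Y-channel PSNR as described above. It is not the repository's evaluation code; it assumes 8-bit BGR images as loaded by `cv2.imread`:

```python
import cv2
import numpy as np

def psnr_y(ref_bgr: np.ndarray, out_bgr: np.ndarray) -> float:
    """PSNR on the Y channel of YCrCb, for 8-bit BGR inputs."""
    y_ref = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2YCrCb)[..., 0].astype(np.float64)
    y_out = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2YCrCb)[..., 0].astype(np.float64)
    mse = np.mean((y_ref - y_out) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(255.0 ** 2 / mse)
```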

### Colab Notebook

PyTorch Colab notebook (provided by [@HanClinto](https://github.com/HanClinto)): [HighResSlowMo.ipynb](https://gist.github.com/HanClinto/49219942f76d5f20990b6d048dbacbaf)

## Citations

If you find the code helpful in your research or work, please cite the following papers.

```BibTex
@misc{xiang2021zooming,
title={Zooming SlowMo: An Efficient One-Stage Framework for Space-Time Video Super-Resolution},
author={Xiang, Xiaoyu and Tian, Yapeng and Zhang, Yulun and Fu, Yun and Allebach, Jan P and Xu, Chenliang},
archivePrefix={arXiv},
eprint={2104.07473},
@@ -226,12 +248,15 @@
```

## Contact

[Xiaoyu Xiang](https://engineering.purdue.edu/people/xiaoyu.xiang.1) and [Yapeng Tian](http://yapengtian.org/).

You can also leave your questions as issues in the repository. We will be glad to answer them.

## License

This project is released under the [GNU General Public License v3.0](https://github.com/Mukosame/Zooming-Slow-Mo-CVPR-2020/blob/master/LICENSE).

## Acknowledgments

Our code is inspired by [TDAN-VSR](https://github.com/YapengTian/TDAN-VSR) and [EDVR](https://github.com/xinntao/EDVR).
60 changes: 37 additions & 23 deletions codes/data/Vimeo7_dataset.py
@@ -40,7 +40,7 @@ def __init__(self, opt):
self.half_N_frames = opt['N_frames'] // 2
self.LR_N_frames = 1 + self.half_N_frames
assert self.LR_N_frames > 1, 'Error: Not enough LR frames to interpolate'
# determine the LQ frame list
'''
N | frames
1 | error
Expand All @@ -54,16 +54,17 @@ def __init__(self, opt):

self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
self.data_type = self.opt['data_type']
# low resolution inputs
self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True
# directly load image keys
if opt['cache_keys']:
logger.info('Using cache keys: {}'.format(opt['cache_keys']))
cache_keys = opt['cache_keys']
else:
cache_keys = 'Vimeo7_train_keys.pkl'
logger.info('Using cache keys - {}.'.format(cache_keys))
self.paths_GT = pickle.load(open('./data/{}'.format(cache_keys), 'rb'))

assert self.paths_GT, 'Error: GT path is empty.'

if self.data_type == 'lmdb':
@@ -101,9 +102,12 @@ def _read_img_mc(self, path):

def _read_img_mc_BGR(self, path, name_a, name_b):
''' Read BGR channels separately and then combine for 1M limits in cluster'''
img_B = self._read_img_mc(
osp.join(path + '_B', name_a, name_b + '.png'))
img_G = self._read_img_mc(
osp.join(path + '_G', name_a, name_b + '.png'))
img_R = self._read_img_mc(
osp.join(path + '_R', name_a, name_b + '.png'))
img = cv2.merge((img_B, img_G, img_R))
return img

@@ -121,9 +125,9 @@ def __getitem__(self, index):
key = self.paths_GT['keys'][index]
name_a, name_b = key.split('_')

center_frame_idx = random.randint(2, 6) # 2<= index <=6

# determine the neighbor frames
interval = random.choice(self.interval_list)
if self.opt['border_mode']:
direction = 1 # 1: forward; 0: backward
@@ -160,19 +164,22 @@
neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
len(neighbor_list))

# get the GT image (as the center frame)
img_GT_l = []
for v in neighbor_list:
if self.data_type == 'mc':
img_GT = self._read_img_mc_BGR(
self.GT_root, name_a, name_b, '{}.png'.format(v))
img_GT = img_GT.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_GT = util.read_img(
self.GT_env, key + '_{}'.format(v), (3, 256, 448))
else:
img_GT = util.read_img(None, osp.join(
self.GT_root, name_a, name_b, 'im{}.png'.format(v)))
img_GT_l.append(img_GT)

# get LQ images
LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
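# note: (3, 64, 112) is exactly the (3, 256, 448) GT size at 4x downscaling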
img_LQ_l = []
for v in self.LQ_frames_list:
@@ -181,7 +188,8 @@
osp.join(self.LQ_root, name_a, name_b, '/{}.png'.format(v)))
img_LQ = img_LQ.astype(np.float32) / 255.
elif self.data_type == 'lmdb':
img_LQ = util.read_img(
self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
else:
img_LQ = util.read_img(None,
osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
@@ -194,18 +202,23 @@
LQ_size = GT_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
img_LQ_l = [v[rnd_h:rnd_h + LQ_size,
rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT_l = [v[rnd_h_HR:rnd_h_HR + GT_size,
rnd_w_HR:rnd_w_HR + GT_size, :] for v in img_GT_l]
else:
rnd_h = random.randint(0, max(0, H - GT_size))
rnd_w = random.randint(0, max(0, W - GT_size))
img_LQ_l = [v[rnd_h:rnd_h + GT_size,
rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
img_GT_l = [v[rnd_h:rnd_h + GT_size,
rnd_w:rnd_w + GT_size, :] for v in img_GT_l]

# augmentation - flip, rotate
img_LQ_l = img_LQ_l + img_GT_l
rlt = util.augment(
img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
img_LQ_l = rlt[0:-N_frames]
img_GT_l = rlt[-N_frames:]

Expand All @@ -215,7 +228,8 @@ def __getitem__(self, index):
# BGR to RGB, HWC to CHW, numpy to tensor
img_GTs = img_GTs[:, :, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GTs = torch.from_numpy(np.ascontiguousarray(
np.transpose(img_GTs, (0, 3, 1, 2)))).float()
img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
(0, 3, 1, 2)))).float()
return {'LQs': img_LQs, 'GT': img_GTs, 'key': key}
3 changes: 2 additions & 1 deletion codes/data/__init__.py
@@ -30,7 +30,8 @@ def create_dataset(dataset_opt):
if mode == 'Vimeo7':
from data.Vimeo7_dataset import Vimeo7Dataset as D
else:
raise NotImplementedError(
'Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)

logger = logging.getLogger('base')
9 changes: 6 additions & 3 deletions codes/data/data_sampler.py
@@ -30,17 +30,20 @@ class DistIterSampler(Sampler):
def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(
math.ceil(len(self.dataset) * ratio / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas

def __iter__(self):
    ...
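
Illustrative usage of `DistIterSampler` (not part of this commit): the sampler feeds a `DataLoader` in distributed training, and `ratio` stretches one "epoch" to roughly `ratio / num_replicas` passes over the data. The wiring below, including `train_set`, is an assumption shown for context only.

import torch.distributed as dist
from torch.utils.data import DataLoader

sampler = DistIterSampler(train_set, num_replicas=dist.get_world_size(),
                          rank=dist.get_rank(), ratio=200)
train_loader = DataLoader(train_set, batch_size=16, sampler=sampler,
                          num_workers=4, pin_memory=True)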