Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the code #58

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ python demo.py --config config/dataset_name.yaml --driving_video path/to/drivin
```
The result will be stored in ```result.mp4```. To use Animation via Disentaglemet add ```--mode avd```, for standard animation add ```--mode standard``` instead.

Checkpoints from google drive:
https://drive.google.com/drive/folders/1jCeFPqfU_wKNYwof0ONICwsj3xHlr_tb

### Colab Demo
We prepared a demo runnable in google-colab, see: ```demo.ipynb```.

Expand Down
105 changes: 105 additions & 0 deletions config/tiktok.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.
#No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,
#publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.
#Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,
#title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.
#In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.


dataset_params:
root_dir: /home/jupyter/data-lake/downloaded-videos
frame_shape: null
id_sampling: False
pairs_list: null
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1

model_params:
num_regions: 20
num_channels: 3
estimate_affine: True
revert_axis_swap: True

bg_predictor_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
bg_type: 'affine'
region_predictor_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
pca_based: True
pad: 0
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
skips: True
pixelwise_flow_predictor_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
use_deformed_source: True
use_covar_heatmap: True
estimate_occlusion_map: True
avd_network_params:
id_bottle_size: 64
pose_bottle_size: 64

train_params:
num_epochs: 100
num_repeats: 150
epoch_milestones: [60, 90]
lr: 2.0e-4
batch_size: 20
dataloader_workers: 6
checkpoint_freq: 50
use_sync_bn: False
scales: [1, 0.5, 0.25, 0.125]
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
perceptual: [10, 10, 10, 10, 10]
equivariance_shift: 10
equivariance_affine: 10

train_avd_params:
num_epochs: 30
num_repeats: 500
batch_size: 256
dataloader_workers: 12
checkpoint_freq: 50
epoch_milestones: [20, 25]
lr: 1.0e-3
lambda_shift: 1
lambda_affine: 1
random_scale: 0.25

reconstruction_params:
num_videos: 1000
format: '.mp4'

animate_params:
num_pairs: 50
format: '.mp4'
mode: 'avd'

visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
region_bg_color: [1, 1, 1]
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def make_animation(source_image, driving_video, generator, region_predictor, avd

def main(opt):
source_image = imageio.imread(opt.source_image)
reader = imageio.get_reader(opt.driving_video)
reader = imageio.get_reader(opt.driving_video, memtest=False)
fps = reader.get_meta_data()['fps']
reader.close()
driving_video = imageio.mimread(opt.driving_video, memtest=False)
Expand Down
2 changes: 1 addition & 1 deletion frames_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def read_video(name, frame_shape):
video_array = video_array.reshape((-1,) + frame_shape + (3, ))
video_array = np.moveaxis(video_array, 1, 2)
elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
video = mimread(name)
video = mimread(name, memtest=False)
if len(video[0].shape) == 2:
video = [gray2rgb(frame) for frame in video]
if frame_shape is not None:
Expand Down
4 changes: 2 additions & 2 deletions logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import imageio

import os
from skimage.draw import circle
from skimage.draw import disk

import matplotlib.pyplot as plt
import collections
Expand Down Expand Up @@ -146,7 +146,7 @@ def draw_image_with_kp(self, image, kp_array):
kp_array = spatial_size * (kp_array + 1) / 2
num_regions = kp_array.shape[0]
for kp_ind, kp in enumerate(kp_array):
rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
rr, cc = disk(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
image[rr, cc] = np.array(self.colormap(kp_ind / num_regions))[:3]
return image

Expand Down
24 changes: 12 additions & 12 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
imageio==2.3.0
matplotlib==2.2.2
numpy==1.15.0
pandas==0.23.4
Pillow==5.2.0
PyYAML==5.1
scikit-image==0.14.0
scikit-learn==0.19.2
scipy==1.1.0
torch==1.4.0
torchvision==0.2.1
tqdm==4.24.0
imageio
matplotlib
numpy
pandas
Pillow
PyYAML
scikit-image
scikit-learn
scipy
torch
torchvision
tqdm
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

opt = parser.parse_args()
with open(opt.config) as f:
config = yaml.load(f)
config = yaml.load(f, Loader=yaml.SafeLoader)

if opt.checkpoint is not None:
log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
Expand Down