# interpolate.py (forked from tarun005/FLAVR)
import os
import sys
import argparse

import cv2
import numpy as np
import torch
import torchvision
import tqdm
from PIL import Image
from torchvision.io import read_video

from dataset.transforms import ToTensorVideo, Resize

parser = argparse.ArgumentParser()
parser.add_argument("--input_video", type=str, required=True, help="Path or web URL to input video")
parser.add_argument("--youtube-dl", type=str, help="Path to youtube-dl binary", default=".local/bin/youtube-dl")
parser.add_argument("--factor", type=int, required=True, choices=[2, 4, 8], help="Interpolation factor: 2x/4x/8x.")
parser.add_argument("--codec", type=str, help="Video codec", default="mpeg4")
parser.add_argument("--load_model", type=str, required=True, help="Path to stored model checkpoint")
parser.add_argument("--up_mode", type=str, help="Upsample mode", default="transpose")
parser.add_argument("--output_ext", type=str, help="Output video format", default=".avi")
parser.add_argument("--input_ext", type=str, help="Input video format", default=".mp4")
parser.add_argument("--downscale", type=float, help="Downscale input resolution to save memory", default=1)
parser.add_argument("--output_fps", type=int, help="Target FPS", default=30)
parser.add_argument("--is_folder", action="store_true", help="Treat --input_video as a folder of frame images")
args = parser.parse_args()
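
# Example invocation (file and checkpoint names are illustrative, not from the repo):
#   python interpolate.py --input_video clip.mp4 --factor 2 --load_model FLAVR_2x.pth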
input_video = args.input_video
input_ext = args.input_ext

# Validate local inputs; web URLs are downloaded further below.
if not input_video.startswith("http") and not os.path.exists(input_video):
    kind = "directory" if args.is_folder else "file"
    print(f"Invalid input {kind} path!")
    sys.exit(1)

if args.output_ext != ".avi":
    print("Currently supporting only writing to avi. Try using ffmpeg for conversion to mp4 etc.")

# Derive the output name from the input file or folder name.
if input_video.endswith("/"):
    video_name = input_video.split("/")[-2].split(input_ext)[0]
else:
    video_name = input_video.split("/")[-1].split(input_ext)[0]
output_video = video_name + f"_{args.factor}x" + args.output_ext

n_outputs = args.factor - 1  # frames to synthesize between each input pair
model_name = "unet_18"
nbr_frame = 4  # FLAVR consumes a window of 4 input frames
joinType = "concat"

# Download web inputs with youtube-dl before processing.
if input_video.startswith("http"):
    assert args.youtube_dl is not None, "Provide --youtube-dl for web URLs"
    cmd = f"{args.youtube_dl} -i -o video.mp4 {input_video}"
    os.system(cmd)
    input_video = "video.mp4"
    output_video = "video" + args.output_ext
def loadModel(model, checkpoint):
    # Checkpoints saved with nn.DataParallel prefix every key with
    # "module."; partition() strips that prefix if present.
    saved_state_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]
    saved_state_dict = {k.partition("module.")[-1]: v for k, v in saved_state_dict.items()}
    model.load_state_dict(saved_state_dict)

checkpoint = args.load_model
from model.FLAVR_arch import UNet_3D_3D
model = UNet_3D_3D(model_name.lower(), n_inputs=4, n_outputs=n_outputs, joinType=joinType, upmode=args.up_mode)
loadModel(model, checkpoint)
model = model.cuda()
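
# The script assumes a CUDA GPU. A minimal device-agnostic variant (a sketch,
# not the author's code) would be:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = model.to(device)
# with the .cuda() calls in the inference loop replaced by .to(device).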
def write_video_cv2(frames, video_name, fps, sizes):
    # sizes is (width, height), as cv2.VideoWriter expects.
    out = cv2.VideoWriter(video_name, cv2.CAP_OPENCV_MJPEG,
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, sizes)
    for frame in frames:
        out.write(frame)
    out.release()  # finalize the container; without this the file may be truncated
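
# Note: the --codec flag is currently unused; the writer hardcodes MJPG,
# a fourcc that pairs reliably with the .avi container across OpenCV builds.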
def make_image(img):
    # Float CHW tensor in [0, 1] -> uint8 HWC BGR image for OpenCV.
    q_im = img.data.mul(255.).clamp(0, 255).round()
    im = q_im.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    return im
def files_to_videoTensor(path, downscale=1.):
    # Note: downscale is unused here; resizing happens later in video_transform.
    files = sorted(os.listdir(path))
    print(len(files), "frames found.")
    images = [torch.Tensor(np.asarray(Image.open(os.path.join(path, f)))).type(torch.uint8) for f in files]
    print(images[0].shape)
    videoTensor = torch.stack(images)
    return videoTensor
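
# Frames are read in lexicographic order, so zero-padded file names
# (e.g. 0001.png, 0002.png, ...) keep the sequence correct.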
def video_to_tensor(video):
    videoTensor, _, md = read_video(video)
    print("Source FPS:", md["video_fps"])
    return videoTensor
def video_transform(videoTensor, downscale=1):
    H, W = videoTensor.size(1), videoTensor.size(2)
    # Round the (optionally downscaled) resolution down to a multiple of 8,
    # presumably because the U-Net's spatial strides require dimensions
    # divisible by 8.
    downscale = int(downscale * 8)
    resizes = 8 * (H // downscale), 8 * (W // downscale)
    transforms = torchvision.transforms.Compose([ToTensorVideo(), Resize(resizes)])
    videoTensor = transforms(videoTensor)
    print("Resizing to %dx%d" % (resizes[0], resizes[1]))
    return videoTensor, resizes
if args.is_folder:
    videoTensor = files_to_videoTensor(input_video, args.downscale)
else:
    videoTensor = video_to_tensor(input_video)

# Sliding windows of nbr_frame consecutive frame indices, stride 1.
idxs = torch.arange(len(videoTensor)).view(1, -1).unfold(1, size=nbr_frame, step=1).squeeze(0)
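# For a 6-frame video, for example, unfold(1, size=4, step=1) yields the
# windows [0,1,2,3], [1,2,3,4], [2,3,4,5]; each window's two middle frames
# are the pair the model interpolates between.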
videoTensor, resizes = video_transform(videoTensor, args.downscale)
print("Video tensor shape:", videoTensor.shape)
frames = torch.unbind(videoTensor, 1)
n_inputs = len(frames)
width = n_outputs + 1

outputs = []  # holds the input and interpolated frames in display order
outputs.append(frames[idxs[0][1]])  # the first window's second frame opens the video
model = model.eval()
for i in tqdm.tqdm(range(len(idxs))):
    idxSet = idxs[i]
    # The four window frames are the model inputs; it predicts the
    # n_outputs intermediate frames between the two middle inputs.
    inputs = [frames[idx_].cuda().unsqueeze(0) for idx_ in idxSet]
    with torch.no_grad():
        outputFrame = model(inputs)
    outputFrame = [of.squeeze(0).cpu().data for of in outputFrame]
    outputs.extend(outputFrame)
    outputs.append(inputs[2].squeeze(0).cpu().data)  # carry over the next original frame
new_video = [make_image(im_) for im_ in outputs]
write_video_cv2(new_video, output_video, args.output_fps, (resizes[1], resizes[0]))  # (width, height)

# Convert the intermediate .avi to .mp4 with ffmpeg, then clean up.
mp4_name = output_video.split(".")[0] + ".mp4"
print("Writing to", mp4_name)
os.system("ffmpeg -hide_banner -loglevel warning -i %s %s" % (output_video, mp4_name))
os.remove(output_video)
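
# Note: the output frame rate comes from --output_fps rather than from the
# source FPS, so to preserve playback speed one would pass source_fps * factor
# (e.g. --output_fps 60 for a 30 fps source at --factor 2; values illustrative).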