Skip to content

Commit

Permalink
fixed redundant calls to ArcFace when video swapping
Browse files Browse the repository at this point in the history
  • Loading branch information
felixrosberg committed Dec 23, 2022
1 parent ca07df7 commit 534ed46
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions utils/swap_func.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# -*- coding: utf-8 -*-
# @Author: netrunner-exe
# @Date: 2022-11-23 09:52:13
# @Last Modified by: netrunner-exe
# @Last Modified time: 2022-12-22 14:45:19
import glob
import os
import shutil
Expand Down Expand Up @@ -37,25 +32,27 @@ def __exit__(self, *_):


def run_inference(opt, source, target, RetinaFace,
ArcFace, G, result_img_path):
ArcFace, G, result_img_path, source_z=None):
try:
source = cv2.imread(source)
source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
source = np.array(source)

if not isinstance(target, str):
target = target
target = target
else:
target = cv2.imread(target)
target = cv2.imread(target)

target = np.array(target)

source_h, source_w, _ = source.shape
source_a = RetinaFace(np.expand_dims(source, axis=0)).numpy()[0]
source_lm = get_lm(source_a, source_w, source_h)
source_aligned = norm_crop(source, source_lm, image_size=112, shrink_factor=1.0)
if source_z is None:
source = cv2.imread(source)
source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
source = np.array(source)

source_z = ArcFace.predict(np.expand_dims(source_aligned / 255.0, axis=0))
source_h, source_w, _ = source.shape
source_a = RetinaFace(np.expand_dims(source, axis=0)).numpy()[0]
source_lm = get_lm(source_a, source_w, source_h)
source_aligned = norm_crop(source, source_lm, image_size=112, shrink_factor=1.0)

source_z = ArcFace.predict(np.expand_dims(source_aligned / 255.0, axis=0))

blend_mask_base = np.zeros(shape=(256, 256, 1))
blend_mask_base[77:240, 32:224] = 1
Expand Down Expand Up @@ -104,6 +101,8 @@ def run_inference(opt, source, target, RetinaFace,

cv2.imwrite(result_img_path, cv2.cvtColor(total_img, cv2.COLOR_BGR2RGB))

return total_img, source_z

except Exception as e:
print('\n', e)
sys.exit(0)
Expand Down Expand Up @@ -134,11 +133,14 @@ def video_swap(opt, face, input_video, RetinaFace, ArcFace, G, out_video_filenam
shutil.rmtree(temp_results_dir)
os.makedirs(temp_results_dir, exist_ok=True)

source_z = None

for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
if ret:
run_inference(opt, face, frame, RetinaFace, ArcFace, G,
os.path.join('./tmp_frames', 'frame_{:0>7d}.png'.format(frame_index)))
_, source_z = run_inference(opt, face, frame, RetinaFace, ArcFace, G,
os.path.join('./tmp_frames', 'frame_{:0>7d}.png'.format(frame_index)),
source_z=source_z)
video.release()

path = os.path.join('./tmp_frames', '*.png')
Expand All @@ -151,7 +153,8 @@ def video_swap(opt, face, input_video, RetinaFace, ArcFace, G, out_video_filenam
try:
clips.write_videofile(out_video_filename, codec='libx264', audio_codec='aac', ffmpeg_params=[
'-pix_fmt:v', 'yuv420p', '-colorspace:v', 'bt709', '-color_primaries:v', 'bt709',
'-color_trc:v', 'bt709', '-color_range:v', 'tv', '-movflags', '+faststart'], logger=proglog.TqdmProgressBarLogger(print_messages=False))
'-color_trc:v', 'bt709', '-color_range:v', 'tv', '-movflags', '+faststart'],
logger=proglog.TqdmProgressBarLogger(print_messages=False))
except:
sys.exit(0)

Expand Down

0 comments on commit 534ed46

Please sign in to comment.