diff --git a/utils/swap_func.py b/utils/swap_func.py index 928ef81..7d4e566 100644 --- a/utils/swap_func.py +++ b/utils/swap_func.py @@ -1,8 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: netrunner-exe -# @Date: 2022-11-23 09:52:13 -# @Last Modified by: netrunner-exe -# @Last Modified time: 2022-12-22 14:45:19 import glob import os import shutil @@ -37,25 +32,27 @@ def __exit__(self, *_): def run_inference(opt, source, target, RetinaFace, - ArcFace, G, result_img_path): + ArcFace, G, result_img_path, source_z=None): try: - source = cv2.imread(source) - source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR) - source = np.array(source) if not isinstance(target, str): - target = target + target = target else: - target = cv2.imread(target) + target = cv2.imread(target) target = np.array(target) - source_h, source_w, _ = source.shape - source_a = RetinaFace(np.expand_dims(source, axis=0)).numpy()[0] - source_lm = get_lm(source_a, source_w, source_h) - source_aligned = norm_crop(source, source_lm, image_size=112, shrink_factor=1.0) + if source_z is None: + source = cv2.imread(source) + source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR) + source = np.array(source) - source_z = ArcFace.predict(np.expand_dims(source_aligned / 255.0, axis=0)) + source_h, source_w, _ = source.shape + source_a = RetinaFace(np.expand_dims(source, axis=0)).numpy()[0] + source_lm = get_lm(source_a, source_w, source_h) + source_aligned = norm_crop(source, source_lm, image_size=112, shrink_factor=1.0) + + source_z = ArcFace.predict(np.expand_dims(source_aligned / 255.0, axis=0)) blend_mask_base = np.zeros(shape=(256, 256, 1)) blend_mask_base[77:240, 32:224] = 1 @@ -104,6 +101,8 @@ def run_inference(opt, source, target, RetinaFace, cv2.imwrite(result_img_path, cv2.cvtColor(total_img, cv2.COLOR_BGR2RGB)) + return total_img, source_z + except Exception as e: print('\n', e) sys.exit(0) @@ -134,11 +133,14 @@ def video_swap(opt, face, input_video, RetinaFace, ArcFace, G, out_video_filenam shutil.rmtree(temp_results_dir) os.makedirs(temp_results_dir, exist_ok=True) + source_z = None + for frame_index in tqdm(range(frame_count)): ret, frame = video.read() if ret: - run_inference(opt, face, frame, RetinaFace, ArcFace, G, - os.path.join('./tmp_frames', 'frame_{:0>7d}.png'.format(frame_index))) + _, source_z = run_inference(opt, face, frame, RetinaFace, ArcFace, G, + os.path.join('./tmp_frames', 'frame_{:0>7d}.png'.format(frame_index)), + source_z=source_z) video.release() path = os.path.join('./tmp_frames', '*.png') @@ -151,7 +153,8 @@ def video_swap(opt, face, input_video, RetinaFace, ArcFace, G, out_video_filenam try: clips.write_videofile(out_video_filename, codec='libx264', audio_codec='aac', ffmpeg_params=[ '-pix_fmt:v', 'yuv420p', '-colorspace:v', 'bt709', '-color_primaries:v', 'bt709', - '-color_trc:v', 'bt709', '-color_range:v', 'tv', '-movflags', '+faststart'], logger=proglog.TqdmProgressBarLogger(print_messages=False)) + '-color_trc:v', 'bt709', '-color_range:v', 'tv', '-movflags', '+faststart'], + logger=proglog.TqdmProgressBarLogger(print_messages=False)) except: sys.exit(0)