Skip to content

Commit

Permalink
Minor code readability
Browse files Browse the repository at this point in the history
  • Loading branch information
felixrosberg committed Dec 24, 2022
1 parent 534ed46 commit a3527fd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions utils/swap_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __exit__(self, *_):


def run_inference(opt, source, target, RetinaFace,
ArcFace, G, result_img_path, source_z=None):
ArcFace, FaceDancer, result_img_path, source_z=None):
try:

if not isinstance(target, str):
Expand Down Expand Up @@ -81,15 +81,15 @@ def run_inference(opt, source, target, RetinaFace,
im_aligned = cv2.warpAffine(im, M, (256, 256), borderValue=0.0)

# face swap
changed_face_cage = G.predict([np.expand_dims((im_aligned - 127.5) / 127.5, axis=0), source_z])
changed_face = (changed_face_cage[0] + 1) / 2
face_swap = FaceDancer.predict([np.expand_dims((im_aligned - 127.5) / 127.5, axis=0), source_z])
face_swap = (face_swap[0] + 1) / 2

# get inverse transformation landmarks
transformed_lmk = transform_landmark_points(M, lm_align)

# warp image back
iM, _ = inverse_estimate_norm(lm_align, transformed_lmk, 256, "arcface", shrink_factor=1.0)
iim_aligned = cv2.warpAffine(changed_face, iM, im_shape, borderValue=0.0)
iim_aligned = cv2.warpAffine(face_swap, iM, im_shape, borderValue=0.0)

# blend swapped face with target image
blend_mask = cv2.warpAffine(blend_mask_base, iM, im_shape, borderValue=0.0)
Expand All @@ -108,7 +108,7 @@ def run_inference(opt, source, target, RetinaFace,
sys.exit(0)


def video_swap(opt, face, input_video, RetinaFace, ArcFace, G, out_video_filename):
def video_swap(opt, face, input_video, RetinaFace, ArcFace, FaceDancer, out_video_filename):
video_forcheck = VideoFileClip(input_video)

if video_forcheck.audio is None:
Expand Down Expand Up @@ -138,7 +138,7 @@ def video_swap(opt, face, input_video, RetinaFace, ArcFace, G, out_video_filenam
for frame_index in tqdm(range(frame_count)):
ret, frame = video.read()
if ret:
_, source_z = run_inference(opt, face, frame, RetinaFace, ArcFace, G,
_, source_z = run_inference(opt, face, frame, RetinaFace, ArcFace, FaceDancer,
os.path.join('./tmp_frames', 'frame_{:0>7d}.png'.format(frame_index)),
source_z=source_z)
video.release()
Expand Down

0 comments on commit a3527fd

Please sign in to comment.