diff --git a/README.md b/README.md
index ead6e978..86c3a3eb 100644
--- a/README.md
+++ b/README.md
@@ -23,4 +23,4 @@ pip install xformers==0.0.22.post7
```
## Download Checkpoints
-All models will be downloaded automatically to ComfyUI's model folder, just no wrries.
+All models will be downloaded automatically to ComfyUI's model folder, so no worries.
\ No newline at end of file
diff --git a/README.md.bak b/README.md.bak
index 9720bc1b..74a2493f 100644
--- a/README.md.bak
+++ b/README.md.bak
@@ -24,7 +24,7 @@
-
+
@@ -65,6 +65,7 @@ Explore [more examples](https://fudan-generative-vision.github.io/hallo).
## 📰 News
+- **`2024/06/28`**: 🎉🎉🎉 We are proud to announce the release of our model training code. Try it on your own training data. Here is the [tutorial](#training).
- **`2024/06/21`**: 🎉🎉🎉 Cloned a Gradio demo on [🤗Huggingface space](https://huggingface.co/spaces/fudan-generative-ai/hallo).
- **`2024/06/20`**: 🎉🎉🎉 Received numerous contributions from the community, including a [Windows version](https://github.com/sdbds/hallo-for-windows), [ComfyUI](https://github.com/AIFSH/ComfyUI-Hallo), [WebUI](https://github.com/fudan-generative-vision/hallo/pull/51), and [Docker template](https://github.com/ashleykleynhans/hallo-docker).
- **`2024/06/15`**: ✨✨✨ Released some images and audios for inference testing on [🤗Huggingface](https://huggingface.co/datasets/fudan-generative-ai/hallo_inference_samples).
@@ -74,6 +75,7 @@ Explore [more examples](https://fudan-generative-vision.github.io/hallo).
Explore the resources developed by our community to enhance your experience with Hallo:
+- [TTS x Hallo Talking Portrait Generator](https://huggingface.co/spaces/fffiloni/tts-hallo-talking-portrait) - Check out this awesome Gradio demo by [@Sylvain Filoni](https://huggingface.co/fffiloni)! With this tool, you can conveniently prepare a portrait image and audio for Hallo.
- [Demo on Huggingface](https://huggingface.co/spaces/multimodalart/hallo) - Check out this easy-to-use Gradio demo by [@multimodalart](https://huggingface.co/multimodalart).
- [hallo-webui](https://github.com/daswer123/hallo-webui) - Explore the WebUI created by [@daswer123](https://github.com/daswer123).
- [hallo-for-windows](https://github.com/sdbds/hallo-for-windows) - Utilize Hallo on Windows with the guide by [@sdbds](https://github.com/sdbds).
@@ -233,15 +235,113 @@ options:
face region
```
+## Training
+
+### Prepare Data for Training
+
+The training data consists of talking-face videos, similar in style to the source images used for inference, and must meet the following requirements (a cropping sketch follows the list):
+
+1. It should be cropped into squares.
+2. The face should be the main focus, making up 50%-70% of the image.
+3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles).
+
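+If the raw clips are not already square, you can center-crop them before organizing the dataset. The ffmpeg command below is only a minimal sketch of one way to do this (the 512x512 output size matches the training resolution in `configs/train/stage1.yaml`, and the input/output file names are placeholders; this step is not part of the official preprocessing scripts):
+
+```bash
+# Center-crop to a square and resize to 512x512 (illustrative sketch, not an official step).
+ffmpeg -i raw_clip.mp4 \
+  -vf "crop='min(iw,ih)':'min(iw,ih)',scale=512:512" \
+  -c:a copy dataset_name/videos/0001.mp4
+```
+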
+Organize your raw videos into the following directory structure:
+
+
+```text
+dataset_name/
+|-- videos/
+| |-- 0001.mp4
+| |-- 0002.mp4
+| |-- 0003.mp4
+| `-- 0004.mp4
+```
+
+You can use any `dataset_name`, but ensure the `videos` directory is named as shown above.
+
+Next, process the videos with the following commands:
+
+```bash
+python -m scripts.data_preprocess --input_dir dataset_name/videos --step 1
+python -m scripts.data_preprocess --input_dir dataset_name/videos --step 2
+```
+
+**Note:** Execute steps 1 and 2 sequentially as they perform different tasks. Step 1 converts videos into frames, extracts audio from each video, and generates the necessary masks. Step 2 generates face embeddings using InsightFace and audio embeddings using Wav2Vec, and requires a GPU. For parallel processing, use the `-p` and `-r` arguments. The `-p` argument specifies the total number of instances to launch, dividing the data into `p` parts. The `-r` argument specifies which part the current process should handle. You need to manually launch multiple instances with different values for `-r`.
+
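+For example, a minimal sketch of running step 2 as four parallel instances might look like this (pinning each instance to its own GPU via `CUDA_VISIBLE_DEVICES` is an assumption for a 4-GPU machine, not something the script manages for you):
+
+```bash
+# Split the video list into 4 parts (-p 4); instance r processes part r on GPU r.
+for r in 0 1 2 3; do
+  CUDA_VISIBLE_DEVICES=$r python -m scripts.data_preprocess \
+    --input_dir dataset_name/videos --step 2 -p 4 -r $r &
+done
+wait
+```
+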
+Generate the metadata JSON files with the following commands:
+
+```bash
+python scripts/extract_meta_info_stage1.py -r path/to/dataset -n dataset_name
+python scripts/extract_meta_info_stage2.py -r path/to/dataset -n dataset_name
+```
+
+Replace `path/to/dataset` with the path to the parent directory of `videos`, such as `dataset_name` in the example above. This will generate `dataset_name_stage1.json` and `dataset_name_stage2.json` in the `./data` directory.
+
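+For reference, each entry of the stage 1 metadata produced by `scripts/extract_meta_info_stage1.py` in this release has the shape sketched below; the paths are shortened placeholders (the script itself writes absolute paths):
+
+```json
+[
+  {
+    "image_path": "dataset_name/images/0001",
+    "mask_path": "dataset_name/face_mask/0001.png",
+    "face_emb": "dataset_name/face_emb/0001.pt"
+  }
+]
+```
+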
+### Training
+
+Update the data meta path settings in the configuration YAML files, `configs/train/stage1.yaml` and `configs/train/stage2.yaml`:
+
+
+```yaml
+#stage1.yaml
+data:
+ meta_paths:
+ - ./data/dataset_name_stage1.json
+
+#stage2.yaml
+data:
+ meta_paths:
+ - ./data/dataset_name_stage2.json
+```
+
+Start training with the following command:
+
+```shell
+accelerate launch -m \
+ --config_file accelerate_config.yaml \
+ --machine_rank 0 \
+ --main_process_ip 0.0.0.0 \
+ --main_process_port 20055 \
+ --num_machines 1 \
+ --num_processes 8 \
+ scripts.train_stage1 --config ./configs/train/stage1.yaml
+```
+
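+Stage 1 writes its outputs under `./exp_output/` with `exp_name: stage1`, and `stage1_ckpt_dir` in `configs/train/stage2.yaml` already points at `./exp_output/stage1/`. Once stage 1 has finished, stage 2 is launched the same way, swapping only the module and the config file; a sketch with the same single-machine settings as above:
+
+```shell
+accelerate launch -m \
+  --config_file accelerate_config.yaml \
+  --machine_rank 0 \
+  --main_process_ip 0.0.0.0 \
+  --main_process_port 20055 \
+  --num_machines 1 \
+  --num_processes 8 \
+  scripts.train_stage2 --config ./configs/train/stage2.yaml
+```
+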
+#### Accelerate Usage Explanation
+
+The `accelerate launch` command is used to start the training process with distributed settings.
+
+```shell
+accelerate launch [arguments] {training_script} --{training_script-argument-1} --{training_script-argument-2} ...
+```
+
+**Arguments for Accelerate:**
+
+- `-m, --module`: Interpret the launch script as a Python module.
+- `--config_file`: Configuration file for Hugging Face Accelerate.
+- `--machine_rank`: Rank of the current machine in a multi-node setup.
+- `--main_process_ip`: IP address of the master node.
+- `--main_process_port`: Port of the master node.
+- `--num_machines`: Total number of nodes participating in the training.
+- `--num_processes`: Total number of processes for training, matching the total number of GPUs across all machines.
+
+**Arguments for Training:**
+
+- `{training_script}`: The training script, such as `scripts.train_stage1` or `scripts.train_stage2`.
+- `--{training_script-argument-1}`: Arguments specific to the training script. Our training scripts accept one argument, `--config`, to specify the training configuration file.
+
+For multi-node training, run the command manually on each node, keeping all arguments identical except `--machine_rank` (see the sketch below).
+
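+For instance, a two-node, 16-GPU run could be sketched as follows; `192.0.2.10` is a placeholder for the actual IP address of node 0, and the same command must be issued on every node with its own `--machine_rank`:
+
+```shell
+# On node 0 (replace 192.0.2.10 with the reachable IP address of node 0):
+accelerate launch -m \
+  --config_file accelerate_config.yaml \
+  --machine_rank 0 \
+  --main_process_ip 192.0.2.10 \
+  --main_process_port 20055 \
+  --num_machines 2 \
+  --num_processes 16 \
+  scripts.train_stage1 --config ./configs/train/stage1.yaml
+
+# On node 1: run the identical command, changing only --machine_rank to 1.
+```
+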
+For more settings, refer to the [Accelerate documentation](https://huggingface.co/docs/accelerate/en/index).
+
## 🗓️ Roadmap
| Status | Milestone | ETA |
| :----: | :---------------------------------------------------------------------------------------------------- | :--------: |
| ✅     | **[Inference source code meet everyone on GitHub](https://github.com/fudan-generative-vision/hallo)** | 2024-06-15 |
| ✅     | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/hallo)** | 2024-06-15 |
-| 🚧 | **[Optimizing Performance on images with a resolution of 256x256.]()** | 2024-06-23 |
-| 📅 | **[Improving the model's performance on Mandarin Chinese]()** | 2024-06-25 |
-| 📅 | **[Releasing data preparation and training scripts]()** | 2024-06-28 |
+| ✅     | **[Releasing data preparation and training scripts](#training)** | 2024-06-28 |
+| 📅 | **[Improving the model's performance on Mandarin Chinese]()** | TBD |
Other Enhancements
@@ -250,7 +350,6 @@ options:
- [x] Bug: Output video may lose several frames. [#41](https://github.com/fudan-generative-vision/hallo/issues/41)
- [ ] Bug: Sound volume affecting inference results (audio normalization).
- [ ] ~~Enhancement: Inference code logic optimization~~. This solution doesn't show significant performance improvements. Trying other approaches.
-- [ ] Enhancement: Enhancing performance on low resolutions(256x256) to support more efficient usage.
@@ -297,4 +396,4 @@ Thank you to all the contributors who have helped to make this project better!
-
+
\ No newline at end of file
diff --git a/accelerate_config.yaml b/accelerate_config.yaml
new file mode 100644
index 00000000..6fa766f1
--- /dev/null
+++ b/accelerate_config.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+debug: true
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: "no"
+main_training_function: main
+mixed_precision: "fp16"
+num_machines: 1
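+# num_processes should equal the total number of GPUs across all machines (see the Training section of the README).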
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/configs/train/stage1.yaml b/configs/train/stage1.yaml
new file mode 100644
index 00000000..28760ed2
--- /dev/null
+++ b/configs/train/stage1.yaml
@@ -0,0 +1,63 @@
+data:
+ train_bs: 8
+ train_width: 512
+ train_height: 512
+ meta_paths:
+ - "./data/HDTF_meta.json"
+ # Margin of frame indexes between ref and tgt images
+ sample_margin: 30
+
+solver:
+ gradient_accumulation_steps: 1
+ mixed_precision: "no"
+ enable_xformers_memory_efficient_attention: True
+ gradient_checkpointing: False
+ max_train_steps: 30000
+ max_grad_norm: 1.0
+ # lr
+ learning_rate: 1.0e-5
+ scale_lr: False
+ lr_warmup_steps: 1
+ lr_scheduler: "constant"
+
+ # optimizer
+ use_8bit_adam: False
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_weight_decay: 1.0e-2
+ adam_epsilon: 1.0e-8
+
+val:
+ validation_steps: 500
+
+noise_scheduler_kwargs:
+ num_train_timesteps: 1000
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "scaled_linear"
+ steps_offset: 1
+ clip_sample: false
+
+base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
+vae_model_path: "./pretrained_models/sd-vae-ft-mse"
+face_analysis_model_path: "./pretrained_models/face_analysis"
+
+weight_dtype: "fp16" # [fp16, fp32]
+uncond_ratio: 0.1
+noise_offset: 0.05
+snr_gamma: 5.0
+enable_zero_snr: True
+face_locator_pretrained: False
+
+seed: 42
+resume_from_checkpoint: "latest"
+checkpointing_steps: 500
+exp_name: "stage1"
+output_dir: "./exp_output"
+
+ref_image_paths:
+ - "examples/reference_images/1.jpg"
+
+mask_image_paths:
+ - "examples/masks/1.png"
+
diff --git a/configs/train/stage2.yaml b/configs/train/stage2.yaml
new file mode 100644
index 00000000..baa30729
--- /dev/null
+++ b/configs/train/stage2.yaml
@@ -0,0 +1,119 @@
+data:
+ train_bs: 4
+ val_bs: 1
+ train_width: 512
+ train_height: 512
+ fps: 25
+ sample_rate: 16000
+ n_motion_frames: 2
+ n_sample_frames: 14
+ audio_margin: 2
+ train_meta_paths:
+ - "./data/hdtf_split_stage2.json"
+
+wav2vec_config:
+ audio_type: "vocals" # audio vocals
+ model_scale: "base" # base large
+ features: "all" # last avg all
+ model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
+audio_separator:
+ model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
+face_expand_ratio: 1.2
+
+solver:
+ gradient_accumulation_steps: 1
+ mixed_precision: "no"
+ enable_xformers_memory_efficient_attention: True
+ gradient_checkpointing: True
+ max_train_steps: 30000
+ max_grad_norm: 1.0
+ # lr
+ learning_rate: 1e-5
+ scale_lr: False
+ lr_warmup_steps: 1
+ lr_scheduler: "constant"
+
+ # optimizer
+ use_8bit_adam: True
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_weight_decay: 1.0e-2
+ adam_epsilon: 1.0e-8
+
+val:
+ validation_steps: 1000
+
+noise_scheduler_kwargs:
+ num_train_timesteps: 1000
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "linear"
+ steps_offset: 1
+ clip_sample: false
+
+unet_additional_kwargs:
+ use_inflated_groupnorm: true
+ unet_use_cross_frame_attention: false
+ unet_use_temporal_attention: false
+ use_motion_module: true
+ use_audio_module: true
+ motion_module_resolutions:
+ - 1
+ - 2
+ - 4
+ - 8
+ motion_module_mid_block: true
+ motion_module_decoder_only: false
+ motion_module_type: Vanilla
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types:
+ - Temporal_Self
+ - Temporal_Self
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 32
+ temporal_attention_dim_div: 1
+ audio_attention_dim: 768
+ stack_enable_blocks_name:
+ - "up"
+ - "down"
+ - "mid"
+ stack_enable_blocks_depth: [0,1,2,3]
+
+trainable_para:
+ - audio_modules
+ - motion_modules
+
+base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
+vae_model_path: "./pretrained_models/sd-vae-ft-mse"
+face_analysis_model_path: "./pretrained_models/face_analysis"
+mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"
+
+weight_dtype: "fp16" # [fp16, fp32]
+uncond_img_ratio: 0.05
+uncond_audio_ratio: 0.05
+uncond_ia_ratio: 0.05
+start_ratio: 0.05
+noise_offset: 0.05
+snr_gamma: 5.0
+enable_zero_snr: True
+stage1_ckpt_dir: "./exp_output/stage1/"
+
+single_inference_times: 10
+inference_steps: 40
+cfg_scale: 3.5
+
+seed: 42
+resume_from_checkpoint: "latest"
+checkpointing_steps: 500
+exp_name: "stage2"
+output_dir: "./exp_output"
+
+ref_img_path:
+ - "examples/reference_images/1.jpg"
+
+audio_path:
+ - "examples/driving_audios/1.wav"
+
+
diff --git a/examples/masks/1.png b/examples/masks/1.png
new file mode 100644
index 00000000..c63e0757
Binary files /dev/null and b/examples/masks/1.png differ
diff --git a/hallo/datasets/audio_processor.py b/hallo/datasets/audio_processor.py
index 50738970..f340a52f 100644
--- a/hallo/datasets/audio_processor.py
+++ b/hallo/datasets/audio_processor.py
@@ -73,7 +73,7 @@ def __init__(
self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
- def preprocess(self, wav_file: str, clip_length: int):
+ def preprocess(self, wav_file: str, clip_length: int=-1):
"""
Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
The separated vocal track is then converted into wav2vec2 for further processing or analysis.
@@ -109,7 +109,8 @@ def preprocess(self, wav_file: str, clip_length: int):
audio_length = seq_len
audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
- if seq_len % clip_length != 0:
+
+ if clip_length>0 and seq_len % clip_length != 0:
audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
seq_len += clip_length - seq_len % clip_length
audio_feature = audio_feature.unsqueeze(0)
diff --git a/hallo/datasets/image_processor.py b/hallo/datasets/image_processor.py
index c1456fe0..1eaa0230 100644
--- a/hallo/datasets/image_processor.py
+++ b/hallo/datasets/image_processor.py
@@ -1,3 +1,4 @@
+# pylint: disable=W0718
"""
This module is responsible for processing images, particularly for face-related tasks.
It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
@@ -8,13 +9,15 @@
from typing import List
import cv2
+import mediapipe as mp
import numpy as np
import torch
from insightface.app import FaceAnalysis
from PIL import Image
from torchvision import transforms
-from ..utils.util import get_mask
+from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
+ get_union_face_mask, get_union_lip_mask)
MEAN = 0.5
STD = 0.5
@@ -207,3 +210,137 @@ def __enter__(self):
def __exit__(self, _exc_type, _exc_val, _exc_tb):
self.close()
+
+
+class ImageProcessorForDataProcessing():
+ """
+ ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
+ It takes in an image and performs various operations such as augmentation, face detection,
+ face embedding extraction, and rendering a face mask. The processed images are then used for
+ further analysis or recognition purposes.
+
+    Attributes:
+        face_analysis: The InsightFace FaceAnalysis instance (used only in step 2), otherwise None.
+        landmarker: The MediaPipe FaceLandmarker instance (used only in step 1), otherwise None.
+
+ Methods:
+ preprocess(source_image_path, cache_dir):
+ Preprocesses the input image by performing augmentation, face detection,
+ face embedding extraction, and rendering a face mask.
+
+ close():
+ Closes the ImageProcessor and releases any resources being used.
+
+ _augmentation(images, transform, state=None):
+ Applies image augmentation to the input images using the given transform and state.
+
+ __enter__():
+ Enters a runtime context and returns the ImageProcessor object.
+
+ __exit__(_exc_type, _exc_val, _exc_tb):
+ Exits a runtime context and handles any exceptions that occurred during the processing.
+ """
+ def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
+ if step == 2:
+ self.face_analysis = FaceAnalysis(
+ name="",
+ root=face_analysis_model_path,
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+ )
+ self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
+ self.landmarker = None
+ else:
+ BaseOptions = mp.tasks.BaseOptions
+ FaceLandmarker = mp.tasks.vision.FaceLandmarker
+ FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
+ VisionRunningMode = mp.tasks.vision.RunningMode
+ # Create a face landmarker instance with the video mode:
+ options = FaceLandmarkerOptions(
+ base_options=BaseOptions(model_asset_path=landmark_model_path),
+ running_mode=VisionRunningMode.IMAGE,
+ )
+ self.landmarker = FaceLandmarker.create_from_options(options)
+ self.face_analysis = None
+
+ def preprocess(self, source_image_path: str):
+ """
+ Apply preprocessing to the source image to prepare for face analysis.
+
+        Parameters:
+            source_image_path (str): The path to the directory of source frames.
+
+        Returns:
+            tuple: (face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask);
+            values that the current step does not produce are None.
+ """
+        # 1. get face embedding
+ face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
+ if self.face_analysis:
+ for frame in sorted(os.listdir(source_image_path)):
+ try:
+ source_image = Image.open(
+ os.path.join(source_image_path, frame))
+ ref_image_pil = source_image.convert("RGB")
+ # 2.1 detect face
+ faces = self.face_analysis.get(cv2.cvtColor(
+ np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
+ # use max size face
+ face = sorted(faces, key=lambda x: (
+ x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
+ # 2.2 face embedding
+ face_emb = face["embedding"]
+ if face_emb is not None:
+ break
+ except Exception as _:
+ continue
+
+ if self.landmarker:
+ # 3.1 get landmark
+ landmarks, height, width = get_landmark_overframes(
+ self.landmarker, source_image_path)
+ assert len(landmarks) == len(os.listdir(source_image_path))
+
+ # 3 render face and lip mask
+ face_mask = get_union_face_mask(landmarks, height, width)
+ lip_mask = get_union_lip_mask(landmarks, height, width)
+
+ # 4 gaussian blur
+ blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
+ blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
+
+            # 5 separate mask
+ sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
+ sep_pose_mask = 255.0 - blur_face_mask
+ sep_lip_mask = blur_lip_mask
+
+ return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
+
+ def close(self):
+ """
+ Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
+
+ Args:
+ self: The ImageProcessor instance.
+
+ Returns:
+ None.
+ """
+        # self.face_analysis is None in step-1 mode, so guard before disposing its models.
+        if self.face_analysis:
+            for _, model in self.face_analysis.models.items():
+                if hasattr(model, "Dispose"):
+                    model.Dispose()
+
+ def _augmentation(self, images, transform, state=None):
+ if state is not None:
+ torch.set_rng_state(state)
+ if isinstance(images, List):
+ transformed_images = [transform(img) for img in images]
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
+ else:
+ ret_tensor = transform(images) # (c, h, w)
+ return ret_tensor
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, _exc_type, _exc_val, _exc_tb):
+ self.close()
diff --git a/hallo/datasets/talk_video.py b/hallo/datasets/talk_video.py
index 4f9114ba..25c3ab81 100644
--- a/hallo/datasets/talk_video.py
+++ b/hallo/datasets/talk_video.py
@@ -145,25 +145,29 @@ def __init__(
)
self.attn_transform_64 = transforms.Compose(
[
- transforms.Resize((64,64)),
+ transforms.Resize(
+ (self.img_size[0] // 8, self.img_size[0] // 8)),
transforms.ToTensor(),
]
)
self.attn_transform_32 = transforms.Compose(
[
- transforms.Resize((32, 32)),
+ transforms.Resize(
+ (self.img_size[0] // 16, self.img_size[0] // 16)),
transforms.ToTensor(),
]
)
self.attn_transform_16 = transforms.Compose(
[
- transforms.Resize((16, 16)),
+ transforms.Resize(
+ (self.img_size[0] // 32, self.img_size[0] // 32)),
transforms.ToTensor(),
]
)
self.attn_transform_8 = transforms.Compose(
[
- transforms.Resize((8, 8)),
+ transforms.Resize(
+ (self.img_size[0] // 64, self.img_size[0] // 64)),
transforms.ToTensor(),
]
)
diff --git a/hallo/models/motion_module.py b/hallo/models/motion_module.py
index 07f98454..f62877d4 100644
--- a/hallo/models/motion_module.py
+++ b/hallo/models/motion_module.py
@@ -507,6 +507,7 @@ def extra_repr(self):
def set_use_memory_efficient_attention_xformers(
self,
use_memory_efficient_attention_xformers: bool,
+ attention_op = None,
):
"""
Sets the use of memory-efficient attention xformers for the VersatileAttention class.
diff --git a/hallo/utils/util.py b/hallo/utils/util.py
index 6ceca790..e29af026 100644
--- a/hallo/utils/util.py
+++ b/hallo/utils/util.py
@@ -1,6 +1,7 @@
# pylint: disable=C0116
# pylint: disable=W0718
# pylint: disable=R1732
+# pylint: disable=R0801
"""
utils.py
@@ -67,6 +68,7 @@
import subprocess
import sys
from pathlib import Path
+from typing import List
import av
import cv2
@@ -377,7 +379,32 @@ def get_landmark(file, model_path):
return np.array(face_landmark), height, width
-def get_lip_mask(landmarks, height, width, out_path):
+def get_landmark_overframes(landmark_model, frames_path):
+ """
+    This function iterates over the frames in a directory and returns the facial landmarks detected in each frame.
+
+ Args:
+ landmark_model: mediapipe landmark model instance
+ frames_path (str): The path to the video frames.
+
+    Returns:
+        Tuple[List, int, int]: The facial landmarks detected in each frame, together with the frame height and width.
+ """
+
+ face_landmarks = []
+
+ for file in sorted(os.listdir(frames_path)):
+ image = mp.Image.create_from_file(os.path.join(frames_path, file))
+ height, width = image.height, image.width
+ landmarker_result = landmark_model.detect(image)
+ frame_landmark = compute_face_landmarks(
+ landmarker_result, height, width)
+ face_landmarks.append(frame_landmark)
+
+ return face_landmarks, height, width
+
+
+def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0):
"""
Extracts the lip region from the given landmarks and saves it as an image.
@@ -386,19 +413,42 @@ def get_lip_mask(landmarks, height, width, out_path):
height (int): Height of the output lip mask image.
width (int): Width of the output lip mask image.
out_path (pathlib.Path): Path to save the lip mask image.
+ expand_ratio (float): Expand ratio of mask.
"""
lip_landmarks = np.take(landmarks, lip_ids, 0)
min_xy_lip = np.round(np.min(lip_landmarks, 0))
max_xy_lip = np.round(np.max(lip_landmarks, 0))
min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region(
- [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, 2.0)
+ [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, expand_ratio)
lip_mask = np.zeros((height, width), dtype=np.uint8)
lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]),
round(min_xy_lip[0]):round(max_xy_lip[0])] = 255
- cv2.imwrite(str(out_path), lip_mask)
+ if out_path:
+ cv2.imwrite(str(out_path), lip_mask)
+ return None
+
+ return lip_mask
+
+
+def get_union_lip_mask(landmarks, height, width, expand_ratio=1):
+    """
+    Compute the union of the per-frame lip masks for the given landmarks.
+
+    Parameters:
+        landmarks (list): Facial landmarks for each frame.
+        height (int): Height of the output lip mask image.
+        width (int): Width of the output lip mask image.
+        expand_ratio (float): Expand ratio of the mask.
+
+    Returns:
+        numpy.ndarray: The union of the per-frame lip masks.
+    """
+ lip_masks = []
+ for landmark in landmarks:
+ lip_masks.append(get_lip_mask(landmarks=landmark, height=height,
+ width=width, expand_ratio=expand_ratio))
+ union_mask = get_union_mask(lip_masks)
+ return union_mask
-def get_face_mask(landmarks, height, width, out_path, expand_ratio):
+def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2):
"""
Generate a face mask based on the given landmarks.
@@ -407,7 +457,7 @@ def get_face_mask(landmarks, height, width, out_path, expand_ratio):
height (int): The height of the output face mask image.
width (int): The width of the output face mask image.
out_path (pathlib.Path): The path to save the face mask image.
-
+ expand_ratio (float): Expand ratio of mask.
Returns:
None. The face mask image is saved at the specified path.
"""
@@ -419,8 +469,30 @@ def get_face_mask(landmarks, height, width, out_path, expand_ratio):
face_mask = np.zeros((height, width), dtype=np.uint8)
face_mask[round(min_xy_face[1]):round(max_xy_face[1]),
round(min_xy_face[0]):round(max_xy_face[0])] = 255
- cv2.imwrite(str(out_path), face_mask)
+ if out_path:
+ cv2.imwrite(str(out_path), face_mask)
+ return None
+ return face_mask
+
+
+def get_union_face_mask(landmarks, height, width, expand_ratio=1):
+    """
+    Compute the union of the per-frame face masks for the given landmarks.
+
+    Args:
+        landmarks (list): Facial landmarks for each frame.
+        height (int): The height of the output face mask image.
+        width (int): The width of the output face mask image.
+        expand_ratio (float): Expand ratio of the mask.
+    Returns:
+        numpy.ndarray: The union of the per-frame face masks.
+    """
+ face_masks = []
+ for landmark in landmarks:
+ face_masks.append(get_face_mask(landmarks=landmark,height=height,width=width,expand_ratio=expand_ratio))
+ union_mask = get_union_mask(face_masks)
+ return union_mask
def get_mask(file, cache_dir, face_expand_raio, landmark_model_path):
"""
@@ -506,6 +578,25 @@ def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=
mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
# Check if the image is loaded successfully
+ if mask is not None:
+ normalized_mask = blur_mask(mask,resize_dim=resize_dim,kernel_size=kernel_size)
+ # Save the normalized mask image
+ cv2.imwrite(output_file_path, normalized_mask)
+ return f"Processed, normalized, and saved: {output_file_path}"
+ return f"Failed to load image: {file_path}"
+
+
+def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)):
+    """
+    Resize, blur, and normalize a mask image.
+
+    Parameters:
+        mask (numpy.ndarray): The input mask image.
+        resize_dim (tuple): Dimensions to resize the mask to.
+        kernel_size (tuple): Size of the kernel to use for Gaussian blur.
+
+    Returns:
+        numpy.ndarray: The normalized, blurred mask, or None if the input mask is None.
+    """
+ # Check if the image is loaded successfully
+ normalized_mask = None
if mask is not None:
# Resize the mask image
resized_mask = cv2.resize(mask, resize_dim)
@@ -515,10 +606,7 @@ def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=
normalized_mask = cv2.normalize(
blurred_mask, None, 0, 255, cv2.NORM_MINMAX)
# Save the normalized mask image
- cv2.imwrite(output_file_path, normalized_mask)
- return f"Processed, normalized, and saved: {output_file_path}"
- return f"Failed to load image: {file_path}"
-
+ return normalized_mask
def get_background_mask(file_path, output_file_path):
"""
@@ -614,3 +702,279 @@ def get_face_region(image_path: str, detector):
except Exception as e:
print(f"Error processing image {image_path}: {e}")
return None, None
+
+
+def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
+ """
+ Save the model's state_dict to a checkpoint file.
+
+ If `total_limit` is provided, this function will remove the oldest checkpoints
+ until the total number of checkpoints is less than the specified limit.
+
+ Args:
+ model (nn.Module): The model whose state_dict is to be saved.
+ save_dir (str): The directory where the checkpoint will be saved.
+ prefix (str): The prefix for the checkpoint file name.
+ ckpt_num (int): The checkpoint number to be saved.
+        total_limit (int, optional): The maximum number of checkpoints to keep.
+            Defaults to -1, in which case no checkpoints will be removed.
+
+ Raises:
+ FileNotFoundError: If the save directory does not exist.
+ ValueError: If the checkpoint number is negative.
+ OSError: If there is an error saving the checkpoint.
+ """
+
+ if not osp.exists(save_dir):
+ raise FileNotFoundError(
+ f"The save directory {save_dir} does not exist.")
+
+ if ckpt_num < 0:
+ raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")
+
+ save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
+
+ if total_limit > 0:
+ checkpoints = os.listdir(save_dir)
+ checkpoints = [d for d in checkpoints if d.startswith(prefix)]
+ checkpoints = sorted(
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
+ )
+
+ if len(checkpoints) >= total_limit:
+ num_to_remove = len(checkpoints) - total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+ print(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ print(
+ f"Removing checkpoints: {', '.join(removing_checkpoints)}"
+ )
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint_path = osp.join(
+ save_dir, removing_checkpoint)
+ try:
+ os.remove(removing_checkpoint_path)
+ except OSError as e:
+ print(
+ f"Error removing checkpoint {removing_checkpoint_path}: {e}")
+
+ state_dict = model.state_dict()
+ try:
+ torch.save(state_dict, save_path)
+ print(f"Checkpoint saved at {save_path}")
+ except OSError as e:
+ raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
+
+
+def init_output_dir(dir_list: List[str]):
+ """
+ Initialize the output directories.
+
+ This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing.
+
+ Args:
+ dir_list (List[str]): List of directory paths to create.
+ """
+ for path in dir_list:
+ os.makedirs(path, exist_ok=True)
+
+
+def load_checkpoint(cfg, save_dir, accelerator):
+ """
+ Load the most recent checkpoint from the specified directory.
+
+ This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest".
+ If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found,
+ it starts training from scratch.
+
+ Args:
+ cfg: The configuration object containing training parameters.
+ save_dir (str): The directory where checkpoints are saved.
+ accelerator: The accelerator object for distributed training.
+
+ Returns:
+ int: The global step at which to resume training.
+ """
+ if cfg.resume_from_checkpoint != "latest":
+ resume_dir = cfg.resume_from_checkpoint
+ else:
+ resume_dir = save_dir
+ # Get the most recent checkpoint
+ dirs = os.listdir(resume_dir)
+
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ if len(dirs) > 0:
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1]
+ accelerator.load_state(os.path.join(resume_dir, path))
+ accelerator.print(f"Resuming from checkpoint {path}")
+ global_step = int(path.split("-")[1])
+ else:
+ accelerator.print(
+ f"Could not find checkpoint under {resume_dir}, start training from scratch")
+ global_step = 0
+
+ return global_step
+
+
+def compute_snr(noise_scheduler, timesteps):
+ """
+ Computes SNR as per
+ https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
+ 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ # Expand the tensors.
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
+ # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+ timesteps
+ ].float()
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+ device=timesteps.device
+ )[timesteps].float()
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
+
+
+def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path:
+ """
+ Extract audio from a video file and save it as a WAV file.
+
+ This function uses ffmpeg to extract the audio stream from a given video file and saves it as a WAV file
+ in the specified output directory.
+
+ Args:
+ video_path (Path): The path to the input video file.
+        audio_output_path (Path): The path where the extracted WAV file will be saved.
+
+ Returns:
+ Path: The path to the extracted audio file.
+
+ Raises:
+ subprocess.CalledProcessError: If the ffmpeg command fails to execute.
+ """
+ ffmpeg_command = [
+ 'ffmpeg', '-y',
+ '-i', str(video_path),
+ '-vn', '-acodec',
+ "pcm_s16le", '-ar', '16000', '-ac', '2',
+ str(audio_output_path)
+ ]
+
+ try:
+ print(f"Running command: {' '.join(ffmpeg_command)}")
+ subprocess.run(ffmpeg_command, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f"Error extracting audio from video: {e}")
+ raise
+
+ return audio_output_path
+
+
+def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
+ """
+ Convert a video file into a sequence of images.
+
+ This function uses ffmpeg to convert each frame of the given video file into an image. The images are saved
+ in a directory named after the video file stem under the specified output directory.
+
+ Args:
+ video_path (Path): The path to the input video file.
+ output_dir (Path): The directory where the extracted images will be saved.
+
+ Returns:
+ Path: The path to the directory containing the extracted images.
+
+ Raises:
+ subprocess.CalledProcessError: If the ffmpeg command fails to execute.
+ """
+ ffmpeg_command = [
+ 'ffmpeg',
+ '-i', str(video_path),
+ '-vf', 'fps=25',
+ str(output_dir / '%04d.png')
+ ]
+
+ try:
+ print(f"Running command: {' '.join(ffmpeg_command)}")
+ subprocess.run(ffmpeg_command, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f"Error converting video to images: {e}")
+ raise
+
+ return output_dir
+
+
+def get_union_mask(masks):
+ """
+ Compute the union of a list of masks.
+
+ This function takes a list of masks and computes their union by taking the maximum value at each pixel location.
+ Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white.
+
+ Args:
+ masks (list of np.ndarray): List of masks to be combined.
+
+ Returns:
+ np.ndarray: The union of the input masks.
+ """
+ union_mask = None
+ for mask in masks:
+ if union_mask is None:
+ union_mask = mask
+ else:
+ union_mask = np.maximum(union_mask, mask)
+
+ if union_mask is not None:
+ # Find the bounding box of the non-zero regions in the mask
+ rows = np.any(union_mask, axis=1)
+ cols = np.any(union_mask, axis=0)
+ try:
+ ymin, ymax = np.where(rows)[0][[0, -1]]
+ xmin, xmax = np.where(cols)[0][[0, -1]]
+ except Exception as e:
+ print(str(e))
+ return 0.0
+
+ # Set bounding box area to white
+ union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask)
+
+ return union_mask
+
+
+def move_final_checkpoint(save_dir, module_dir, prefix):
+ """
+ Move the final checkpoint file to the save directory.
+
+ This function identifies the latest checkpoint file based on the given prefix and moves it to the specified save directory.
+
+ Args:
+ save_dir (str): The directory where the final checkpoint file should be saved.
+ module_dir (str): The directory containing the checkpoint files.
+ prefix (str): The prefix used to identify checkpoint files.
+
+ Raises:
+ ValueError: If no checkpoint files are found with the specified prefix.
+ """
+ checkpoints = os.listdir(module_dir)
+ checkpoints = [d for d in checkpoints if d.startswith(prefix)]
+ checkpoints = sorted(
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
+ )
+ shutil.copy2(os.path.join(
+ module_dir, checkpoints[-1]), os.path.join(save_dir, prefix + '.pth'))
diff --git a/nodes.py b/nodes.py
index 03da7973..46b44662 100644
--- a/nodes.py
+++ b/nodes.py
@@ -149,6 +149,8 @@ def inference(self, source_image, driving_audio, pose_weight, face_weight, lip_w
# get src audio
src_audio_path = os.path.join(folder_paths.get_input_directory(), driving_audio)
+ if not os.path.exists(src_audio_path):
+ src_audio_path = driving_audio # absolute path
env = ':'.join([os.environ.get('PYTHONPATH', ''), cur_dir])
cmd = f"""PYTHONPATH={env} python {infer_py} --config "{tmp_yaml_path}" --source_image "{src_img_path}" --driving_audio "{src_audio_path}" --output {output_video_path} --pose_weight {pose_weight} --face_weight {face_weight} --lip_weight {lip_weight} --face_expand_ratio {face_expand_ratio}"""
diff --git a/scripts/data_preprocess.py b/scripts/data_preprocess.py
new file mode 100644
index 00000000..92efc2fc
--- /dev/null
+++ b/scripts/data_preprocess.py
@@ -0,0 +1,191 @@
+# pylint: disable=W1203,W0718
+"""
+This module is used to process videos to prepare data for training. It utilizes various libraries and models
+to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction.
+The script takes in command-line arguments to specify the input and output directories, the processing step,
+the level of parallelism, and the rank for distributed processing.
+
+Usage:
+    python -m scripts.data_preprocess --input_dir /path/to/video_dir --step 1
+
+Example:
+    python -m scripts.data_preprocess -i dataset_name/videos -o dataset_name -s 1 -p 4 -r 0
+"""
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import List
+
+import cv2
+import torch
+from tqdm import tqdm
+
+from hallo.datasets.audio_processor import AudioProcessor
+from hallo.datasets.image_processor import ImageProcessorForDataProcessing
+from hallo.utils.util import convert_video_to_images, extract_audio_from_videos
+
+# Configure logging
+logging.basicConfig(level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s')
+
+
+def setup_directories(video_path: Path) -> dict:
+ """
+ Setup directories for storing processed files.
+
+ Args:
+ video_path (Path): Path to the video file.
+
+ Returns:
+ dict: A dictionary containing paths for various directories.
+ """
+ base_dir = video_path.parent.parent
+ dirs = {
+ "face_mask": base_dir / "face_mask",
+ "sep_pose_mask": base_dir / "sep_pose_mask",
+ "sep_face_mask": base_dir / "sep_face_mask",
+ "sep_lip_mask": base_dir / "sep_lip_mask",
+ "face_emb": base_dir / "face_emb",
+ "audio_emb": base_dir / "audio_emb"
+ }
+
+ for path in dirs.values():
+ path.mkdir(parents=True, exist_ok=True)
+
+ return dirs
+
+
+def process_single_video(video_path: Path,
+ output_dir: Path,
+ image_processor: ImageProcessorForDataProcessing,
+ audio_processor: AudioProcessor,
+ step: int) -> None:
+ """
+ Process a single video file.
+
+ Args:
+ video_path (Path): Path to the video file.
+ output_dir (Path): Directory to save the output.
+ image_processor (ImageProcessorForDataProcessing): Image processor object.
+ audio_processor (AudioProcessor): Audio processor object.
+        step (int): The processing step to run (1: frames, audio, and masks; 2: face and audio embeddings).
+ """
+ assert video_path.exists(), f"Video path {video_path} does not exist"
+ dirs = setup_directories(video_path)
+ logging.info(f"Processing video: {video_path}")
+
+ try:
+ if step == 1:
+ images_output_dir = output_dir / 'images' / video_path.stem
+ images_output_dir.mkdir(parents=True, exist_ok=True)
+ images_output_dir = convert_video_to_images(
+ video_path, images_output_dir)
+ logging.info(f"Images saved to: {images_output_dir}")
+
+ audio_output_dir = output_dir / 'audios'
+ audio_output_dir.mkdir(parents=True, exist_ok=True)
+ audio_output_path = audio_output_dir / f'{video_path.stem}.wav'
+ audio_output_path = extract_audio_from_videos(
+ video_path, audio_output_path)
+ logging.info(f"Audio extracted to: {audio_output_path}")
+
+ face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess(
+ images_output_dir)
+ cv2.imwrite(
+ str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask)
+ cv2.imwrite(str(dirs["sep_pose_mask"] /
+ f"{video_path.stem}.png"), sep_pose_mask)
+ cv2.imwrite(str(dirs["sep_face_mask"] /
+ f"{video_path.stem}.png"), sep_face_mask)
+ cv2.imwrite(str(dirs["sep_lip_mask"] /
+ f"{video_path.stem}.png"), sep_lip_mask)
+ else:
+ images_dir = output_dir / "images" / video_path.stem
+ audio_path = output_dir / "audios" / f"{video_path.stem}.wav"
+ _, face_emb, _, _, _ = image_processor.preprocess(images_dir)
+ torch.save(face_emb, str(
+ dirs["face_emb"] / f"{video_path.stem}.pt"))
+ audio_emb, _ = audio_processor.preprocess(audio_path)
+ torch.save(audio_emb, str(
+ dirs["audio_emb"] / f"{video_path.stem}.pt"))
+ except Exception as e:
+ logging.error(f"Failed to process video {video_path}: {e}")
+
+
+def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None:
+ """
+ Process all videos in the input list.
+
+ Args:
+ input_video_list (List[Path]): List of video paths to process.
+ output_dir (Path): Directory to save the output.
+        step (int): The processing step to run (1 or 2).
+ """
+ face_analysis_model_path = "pretrained_models/face_analysis"
+ landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
+ audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx"
+ wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h'
+
+ audio_processor = AudioProcessor(
+ 16000,
+ 25,
+ wav2vec_model_path,
+ False,
+ os.path.dirname(audio_separator_model_file),
+ os.path.basename(audio_separator_model_file),
+ os.path.join(output_dir, "vocals"),
+ ) if step==2 else None
+
+ image_processor = ImageProcessorForDataProcessing(
+ face_analysis_model_path, landmark_model_path, step)
+
+ for video_path in tqdm(input_video_list, desc="Processing videos"):
+ process_single_video(video_path, output_dir,
+ image_processor, audio_processor, step)
+
+
+def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]:
+ """
+ Get paths of videos to process, partitioned for parallel processing.
+
+ Args:
+ source_dir (Path): Source directory containing videos.
+ parallelism (int): Level of parallelism.
+ rank (int): Rank for distributed processing.
+
+ Returns:
+ List[Path]: List of video paths to process.
+ """
+ video_paths = [item for item in sorted(
+ source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4']
+ return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank]
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+        description="Process videos to prepare data for training. Run this script twice, with --step 1 and then --step 2."
+ )
+ parser.add_argument("-i", "--input_dir", type=Path,
+ required=True, help="Directory containing videos")
+ parser.add_argument("-o", "--output_dir", type=Path,
+ help="Directory to save results, default is parent dir of input dir")
+ parser.add_argument("-s", "--step", type=int, default=1,
+                        help="Data processing step: 1 or 2; run step 1 first and then step 2")
+ parser.add_argument("-p", "--parallelism", default=1,
+ type=int, help="Level of parallelism")
+ parser.add_argument("-r", "--rank", default=0, type=int,
+ help="Rank for distributed processing")
+
+ args = parser.parse_args()
+
+ if args.output_dir is None:
+ args.output_dir = args.input_dir.parent
+
+ video_path_list = get_video_paths(
+ args.input_dir, args.parallelism, args.rank)
+
+ if not video_path_list:
+ logging.warning("No videos to process.")
+ else:
+ process_all_videos(video_path_list, args.output_dir, args.step)
diff --git a/scripts/extract_meta_info_stage1.py b/scripts/extract_meta_info_stage1.py
new file mode 100644
index 00000000..936cb06c
--- /dev/null
+++ b/scripts/extract_meta_info_stage1.py
@@ -0,0 +1,106 @@
+# pylint: disable=R0801
+"""
+This module is used to extract meta information from video directories.
+
+It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path`
+specifies the path to the video directory, while the `dataset_name` specifies the name
+of the dataset. The module then collects all the video folder paths, and for each video
+folder, it checks if a mask path and a face embedding path exist. If they do, it appends
+a dictionary containing the image path, mask path, and face embedding path to a list.
+
+Finally, the module writes the list of dictionaries to a JSON file with the filename
+constructed using the `dataset_name`.
+
+Usage:
+ python tools/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf
+
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+import torch
+
+
+def collect_video_folder_paths(root_path: Path) -> list:
+ """
+ Collect all video folder paths from the root path.
+
+ Args:
+ root_path (Path): The root directory containing video folders.
+
+ Returns:
+ list: List of video folder paths.
+ """
+ return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()]
+
+
+def construct_meta_info(frames_dir_path: Path) -> dict:
+ """
+ Construct meta information for a given frames directory.
+
+ Args:
+ frames_dir_path (Path): The path to the frames directory.
+
+ Returns:
+ dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist.
+ """
+ mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png"
+ face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt"
+
+ if not os.path.exists(mask_path):
+ print(f"Mask path not found: {mask_path}")
+ return None
+
+ if torch.load(face_emb_path) is None:
+ print(f"Face emb is None: {face_emb_path}")
+ return None
+
+ return {
+ "image_path": str(frames_dir_path),
+ "mask_path": mask_path,
+ "face_emb": face_emb_path,
+ }
+
+
+def main():
+ """
+ Main function to extract meta info for training.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-r", "--root_path", type=str,
+ required=True, help="Root path of the video directories")
+ parser.add_argument("-n", "--dataset_name", type=str,
+ required=True, help="Name of the dataset")
+ parser.add_argument("--meta_info_name", type=str,
+ help="Name of the meta information file")
+
+ args = parser.parse_args()
+
+ if args.meta_info_name is None:
+ args.meta_info_name = args.dataset_name
+
+ image_dir = Path(args.root_path) / "images"
+ output_dir = Path("./data")
+ output_dir.mkdir(exist_ok=True)
+
+ # Collect all video folder paths
+ frames_dir_paths = collect_video_folder_paths(image_dir)
+
+ meta_infos = []
+ for frames_dir_path in frames_dir_paths:
+ meta_info = construct_meta_info(frames_dir_path)
+ if meta_info:
+ meta_infos.append(meta_info)
+
+ output_file = output_dir / f"{args.meta_info_name}_stage1.json"
+ with output_file.open("w", encoding="utf-8") as f:
+ json.dump(meta_infos, f, indent=4)
+
+ print(f"Final data count: {len(meta_infos)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/extract_meta_info_stage2.py b/scripts/extract_meta_info_stage2.py
new file mode 100644
index 00000000..e2d9301c
--- /dev/null
+++ b/scripts/extract_meta_info_stage2.py
@@ -0,0 +1,192 @@
+# pylint: disable=R0801
+"""
+This module is used to extract meta information from video files and store them in a JSON file.
+
+The script takes in command line arguments to specify the root path of the video files,
+the dataset name, and the name of the meta information file. It then generates a list of
+dictionaries containing the meta information for each video file and writes it to a JSON
+file with the specified name.
+
+The meta information includes the path to the video file, the mask path, the face mask
+path, the face mask union path, the face mask gaussian path, the lip mask path, the lip
+mask union path, the lip mask gaussian path, the separate mask border, the separate mask
+face, the separate mask lip, the face embedding path, the audio path, the vocals embedding
+base last path, the vocals embedding base all path, the vocals embedding base average
+path, the vocals embedding large last path, the vocals embedding large all path, and the
+vocals embedding large average path.
+
+The script checks if the mask path exists before adding the information to the list.
+
+Usage:
+ python tools/extract_meta_info_stage2.py --root_path --dataset_name --meta_info_name
+
+Example:
+ python tools/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+import torch
+from decord import VideoReader, cpu
+from tqdm import tqdm
+
+
+def get_video_paths(root_path: Path, extensions: list) -> list:
+ """
+ Get a list of video paths from the root path with the specified extensions.
+
+ Args:
+ root_path (Path): The root directory containing video files.
+ extensions (list): List of file extensions to include.
+
+ Returns:
+ list: List of video file paths.
+ """
+ return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions]
+
+
+def file_exists(file_path: str) -> bool:
+ """
+ Check if a file exists.
+
+ Args:
+ file_path (str): The path to the file.
+
+ Returns:
+ bool: True if the file exists, False otherwise.
+ """
+ return os.path.exists(file_path)
+
+
+def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str:
+ """
+ Construct a new path by replacing the base directory and extension in the original path.
+
+ Args:
+ video_path (str): The original video path.
+ base_dir (str): The base directory to be replaced.
+ new_dir (str): The new directory to replace the base directory.
+ new_ext (str): The new file extension.
+
+ Returns:
+ str: The constructed path.
+ """
+ return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)
+
+
+def extract_meta_info(video_path: str) -> dict:
+ """
+ Extract meta information for a given video file.
+
+ Args:
+ video_path (str): The path to the video file.
+
+ Returns:
+ dict: A dictionary containing the meta information for the video.
+ """
+ mask_path = construct_paths(
+ video_path, "videos", "face_mask", ".png")
+ sep_mask_border = construct_paths(
+ video_path, "videos", "sep_pose_mask", ".png")
+ sep_mask_face = construct_paths(
+ video_path, "videos", "sep_face_mask", ".png")
+ sep_mask_lip = construct_paths(
+ video_path, "videos", "sep_lip_mask", ".png")
+ face_emb_path = construct_paths(
+ video_path, "videos", "face_emb", ".pt")
+ audio_path = construct_paths(video_path, "videos", "audios", ".wav")
+ vocal_emb_base_all = construct_paths(
+ video_path, "videos", "audio_emb", ".pt")
+
+ assert_flag = True
+
+ if not file_exists(mask_path):
+ print(f"Mask path not found: {mask_path}")
+ assert_flag = False
+ if not file_exists(sep_mask_border):
+ print(f"Separate mask border not found: {sep_mask_border}")
+ assert_flag = False
+ if not file_exists(sep_mask_face):
+ print(f"Separate mask face not found: {sep_mask_face}")
+ assert_flag = False
+ if not file_exists(sep_mask_lip):
+ print(f"Separate mask lip not found: {sep_mask_lip}")
+ assert_flag = False
+ if not file_exists(face_emb_path):
+ print(f"Face embedding path not found: {face_emb_path}")
+ assert_flag = False
+ if not file_exists(audio_path):
+ print(f"Audio path not found: {audio_path}")
+ assert_flag = False
+ if not file_exists(vocal_emb_base_all):
+ print(f"Vocal embedding base all not found: {vocal_emb_base_all}")
+ assert_flag = False
+
+ video_frames = VideoReader(video_path, ctx=cpu(0))
+ audio_emb = torch.load(vocal_emb_base_all)
+ if abs(len(video_frames) - audio_emb.shape[0]) > 3:
+ print(f"Frame count mismatch for video: {video_path}")
+ assert_flag = False
+
+ face_emb = torch.load(face_emb_path)
+ if face_emb is None:
+ print(f"Face embedding is None for video: {video_path}")
+ assert_flag = False
+
+ del video_frames, audio_emb
+
+ if assert_flag:
+ return {
+ "video_path": str(video_path),
+ "mask_path": mask_path,
+ "sep_mask_border": sep_mask_border,
+ "sep_mask_face": sep_mask_face,
+ "sep_mask_lip": sep_mask_lip,
+ "face_emb_path": face_emb_path,
+ "audio_path": audio_path,
+ "vocals_emb_base_all": vocal_emb_base_all,
+ }
+ return None
+
+
+def main():
+ """
+ Main function to extract meta info for training.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-r", "--root_path", type=str,
+ required=True, help="Root path of the video files")
+ parser.add_argument("-n", "--dataset_name", type=str,
+ required=True, help="Name of the dataset")
+ parser.add_argument("--meta_info_name", type=str,
+ help="Name of the meta information file")
+
+ args = parser.parse_args()
+
+ if args.meta_info_name is None:
+ args.meta_info_name = args.dataset_name
+
+ video_dir = Path(args.root_path) / "videos"
+ video_paths = get_video_paths(video_dir, [".mp4"])
+
+ meta_infos = []
+
+ for video_path in tqdm(video_paths, desc="Extracting meta info"):
+ meta_info = extract_meta_info(video_path)
+ if meta_info:
+ meta_infos.append(meta_info)
+
+ print(f"Final data count: {len(meta_infos)}")
+
+ output_file = Path(f"./data/{args.meta_info_name}_stage2.json")
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+
+ with output_file.open("w", encoding="utf-8") as f:
+ json.dump(meta_infos, f, indent=4)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py
new file mode 100644
index 00000000..e9e7e847
--- /dev/null
+++ b/scripts/train_stage1.py
@@ -0,0 +1,793 @@
+# pylint: disable=E1101,C0415,W0718,R0801
+# scripts/train_stage1.py
+"""
+This is the main training script for stage 1 of the project.
+It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.
+
+The script includes the following classes and functions:
+
+1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
+   and face masks as input and returns the denoised latents.
+2. log_validation: A function that logs the validation information using the given VAE, image encoder,
+   network, scheduler, accelerator, width, height, and configuration.
+3. train_stage1_process: A function that processes the training stage 1 using the given configuration.
+
+The script also includes the necessary imports and a brief description of the purpose of the file.
+"""
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import warnings
+from datetime import datetime
+
+import cv2
+import diffusers
+import mlflow
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+from insightface.app import FaceAnalysis
+from omegaconf import OmegaConf
+from PIL import Image
+from torch import nn
+from tqdm.auto import tqdm
+
+from hallo.animate.face_animate_static import StaticPipeline
+from hallo.datasets.mask_image import FaceMaskDataset
+from hallo.models.face_locator import FaceLocator
+from hallo.models.image_proj import ImageProjModel
+from hallo.models.mutual_self_attention import ReferenceAttentionControl
+from hallo.models.unet_2d_condition import UNet2DConditionModel
+from hallo.models.unet_3d import UNet3DConditionModel
+from hallo.utils.util import (compute_snr, delete_additional_ckpt,
+ import_filename, init_output_dir,
+ load_checkpoint, move_final_checkpoint,
+ save_checkpoint, seed_everything)
+
+warnings.filterwarnings("ignore")
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.10.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+class Net(nn.Module):
+ """
+ The Net class defines a neural network model that combines a reference UNet2DConditionModel,
+ a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
+
+ Args:
+ reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
+ denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
+ face_locator (FaceLocator): The face locator model used for face animation.
+ reference_control_writer: The reference control writer component.
+ reference_control_reader: The reference control reader component.
+ imageproj: The image projection model.
+
+ Forward method:
+ noisy_latents (torch.Tensor): The noisy latents tensor.
+ timesteps (torch.Tensor): The timesteps tensor.
+ ref_image_latents (torch.Tensor): The reference image latents tensor.
+ face_emb (torch.Tensor): The face embeddings tensor.
+ face_mask (torch.Tensor): The face mask tensor.
+ uncond_fwd (bool): A flag indicating whether to perform unconditional forward pass.
+
+ Returns:
+ torch.Tensor: The output tensor of the neural network model.
+ """
+
+ def __init__(
+ self,
+ reference_unet: UNet2DConditionModel,
+ denoising_unet: UNet3DConditionModel,
+ face_locator: FaceLocator,
+ reference_control_writer: ReferenceAttentionControl,
+ reference_control_reader: ReferenceAttentionControl,
+ imageproj: ImageProjModel,
+ ):
+ super().__init__()
+ self.reference_unet = reference_unet
+ self.denoising_unet = denoising_unet
+ self.face_locator = face_locator
+ self.reference_control_writer = reference_control_writer
+ self.reference_control_reader = reference_control_reader
+ self.imageproj = imageproj
+
+ def forward(
+ self,
+ noisy_latents,
+ timesteps,
+ ref_image_latents,
+ face_emb,
+ face_mask,
+ uncond_fwd: bool = False,
+ ):
+ """
+ Forward pass of the model.
+ Args:
+ self (Net): The model instance.
+ noisy_latents (torch.Tensor): Noisy latents.
+ timesteps (torch.Tensor): Timesteps.
+ ref_image_latents (torch.Tensor): Reference image latents.
+ face_emb (torch.Tensor): Face embedding.
+ face_mask (torch.Tensor): Face mask.
+ uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False.
+
+ Returns:
+ torch.Tensor: Model prediction.
+ """
+
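+ # Project the face embedding into cross-attention tokens and turn the face mask
+ # into the conditioning feature consumed by the denoising UNet.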
+ face_emb = self.imageproj(face_emb)
+ face_mask = face_mask.to(device="cuda")
+ face_mask_feature = self.face_locator(face_mask)
+
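+ # For conditional passes, run the reference UNet once at timestep 0 so its attention
+ # features are written by the control writer and read back into the denoising UNet.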
+ if not uncond_fwd:
+ ref_timesteps = torch.zeros_like(timesteps)
+ self.reference_unet(
+ ref_image_latents,
+ ref_timesteps,
+ encoder_hidden_states=face_emb,
+ return_dict=False,
+ )
+ self.reference_control_reader.update(self.reference_control_writer)
+ model_pred = self.denoising_unet(
+ noisy_latents,
+ timesteps,
+ mask_cond_fea=face_mask_feature,
+ encoder_hidden_states=face_emb,
+ ).sample
+
+ return model_pred
+
+
+def get_noise_scheduler(cfg: argparse.Namespace):
+ """
+ Create noise scheduler for training
+
+ Args:
+ cfg (omegaconf.dictconfig.DictConfig): Configuration object.
+
+ Returns:
+ train noise scheduler and val noise scheduler
+ """
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
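+ # With zero-terminal-SNR enabled, betas are rescaled, trailing timestep spacing is used,
+ # and the objective switches to v-prediction; the training scheduler is then built with a
+ # scaled_linear beta schedule.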
+ if cfg.enable_zero_snr:
+ sched_kwargs.update(
+ rescale_betas_zero_snr=True,
+ timestep_spacing="trailing",
+ prediction_type="v_prediction",
+ )
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
+
+ return train_noise_scheduler, val_noise_scheduler
+
+
+def log_validation(
+ vae,
+ net,
+ scheduler,
+ accelerator,
+ width,
+ height,
+ imageproj,
+ cfg,
+ save_dir,
+ global_step,
+ face_analysis_model_path,
+):
+ """
+ Log validation generation image.
+
+ Args:
+ vae (nn.Module): Variational Autoencoder model.
+ net (Net): Main model.
+ scheduler (diffusers.SchedulerMixin): Noise scheduler.
+ accelerator (accelerate.Accelerator): Accelerator for training.
+ width (int): Width of the input images.
+ height (int): Height of the input images.
+ imageproj (nn.Module): Image projection model.
+ cfg (omegaconf.dictconfig.DictConfig): Configuration object.
+ save_dir (str): Directory in which to save the validation images.
+ global_step (int): Current global training step.
+ face_analysis_model_path (str): Path to the InsightFace analysis models.
+
+ Returns:
+ list: Currently unused; the composed validation images are written to save_dir.
+ """
+ logger.info("Running validation... ")
+
+ ori_net = accelerator.unwrap_model(net)
+ ori_net = copy.deepcopy(ori_net)
+ reference_unet = ori_net.reference_unet
+ denoising_unet = ori_net.denoising_unet
+ face_locator = ori_net.face_locator
+
+ generator = torch.manual_seed(42)
+ image_enc = FaceAnalysis(
+ name="",
+ root=face_analysis_model_path,
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+ )
+ image_enc.prepare(ctx_id=0, det_size=(640, 640))
+
+ pipe = StaticPipeline(
+ vae=vae,
+ reference_unet=reference_unet,
+ denoising_unet=denoising_unet,
+ face_locator=face_locator,
+ scheduler=scheduler,
+ imageproj=imageproj,
+ )
+
+ pil_images = []
+ for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths):
+ mask_name = os.path.splitext(os.path.basename(mask_image_path))[0]
+ ref_name = os.path.splitext(os.path.basename(ref_image_path))[0]
+ ref_image_pil = Image.open(ref_image_path).convert("RGB")
+ mask_image_pil = Image.open(mask_image_path).convert("RGB")
+
+ # Prepare face embeds
+ face_info = image_enc.get(
+ cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR))
+ face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (
+ x['bbox'][3] - x['bbox'][1]))[-1] # only use the maximum face
+ face_emb = torch.tensor(face_info['embedding'])
+ face_emb = face_emb.to(
+ imageproj.device, imageproj.dtype)
+
+ image = pipe(
+ ref_image_pil,
+ mask_image_pil,
+ width,
+ height,
+ 20,
+ 3.5,
+ face_emb,
+ generator=generator,
+ ).images
+ image = image[0, :, 0].permute(1, 2, 0).cpu().numpy() # (h, w, 3) after permute
+ res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
+ # Save ref_image, src_image and the generated_image
+ w, h = res_image_pil.size
+ canvas = Image.new("RGB", (w * 3, h), "white")
+ ref_image_pil = ref_image_pil.resize((w, h))
+ mask_image_pil = mask_image_pil.resize((w, h))
+ canvas.paste(ref_image_pil, (0, 0))
+ canvas.paste(mask_image_pil, (w, 0))
+ canvas.paste(res_image_pil, (w * 2, 0))
+
+ out_file = os.path.join(
+ save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg"
+ )
+ canvas.save(out_file)
+
+ del pipe
+ del ori_net
+ torch.cuda.empty_cache()
+
+ return pil_images
+
+
+def train_stage1_process(cfg: argparse.Namespace) -> None:
+ """
+ Trains the model using the given configuration (cfg).
+
+ Args:
+ cfg (dict): The configuration dictionary containing the parameters for training.
+
+ Notes:
+ - This function trains the model using the given configuration.
+ - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
+ - The training progress is logged and tracked using the accelerator.
+ - The trained model is saved after the training is completed.
+ """
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
+ mixed_precision=cfg.solver.mixed_precision,
+ log_with="mlflow",
+ project_dir="./mlruns",
+ kwargs_handlers=[kwargs],
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if cfg.seed is not None:
+ seed_everything(cfg.seed)
+
+ # create output dir for training
+ exp_name = cfg.exp_name
+ save_dir = f"{cfg.output_dir}/{exp_name}"
+ checkpoint_dir = os.path.join(save_dir, "checkpoints")
+ module_dir = os.path.join(save_dir, "modules")
+ validation_dir = os.path.join(save_dir, "validation")
+
+ if accelerator.is_main_process:
+ init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])
+
+ accelerator.wait_for_everyone()
+
+ # create model
+ if cfg.weight_dtype == "fp16":
+ weight_dtype = torch.float16
+ elif cfg.weight_dtype == "bf16":
+ weight_dtype = torch.bfloat16
+ elif cfg.weight_dtype == "fp32":
+ weight_dtype = torch.float32
+ else:
+ raise ValueError(
+ f"Unsupported weight dtype for training: {cfg.weight_dtype}"
+ )
+
+ # create model
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
+ "cuda", dtype=weight_dtype
+ )
+ reference_unet = UNet2DConditionModel.from_pretrained(
+ cfg.base_model_path,
+ subfolder="unet",
+ ).to(device="cuda", dtype=weight_dtype)
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+ cfg.base_model_path,
+ "",
+ subfolder="unet",
+ unet_additional_kwargs={
+ "use_motion_module": False,
+ "unet_use_temporal_attention": False,
+ },
+ use_landmark=False
+ ).to(device="cuda", dtype=weight_dtype)
+ imageproj = ImageProjModel(
+ cross_attention_dim=denoising_unet.config.cross_attention_dim,
+ clip_embeddings_dim=512,
+ clip_extra_context_tokens=4,
+ ).to(device="cuda", dtype=weight_dtype)
+
+ if cfg.face_locator_pretrained:
+ face_locator = FaceLocator(
+ conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
+ ).to(device="cuda", dtype=weight_dtype)
+ miss, _ = face_locator.load_state_dict(
+ torch.load(cfg.face_state_dict_path, map_location="cpu"), strict=False)
+ logger.info(f"Missing key for face locator: {len(miss)}")
+ else:
+ face_locator = FaceLocator(
+ conditioning_embedding_channels=320,
+ ).to(device="cuda", dtype=weight_dtype)
+ # Freeze the VAE; all other modules are trained in stage 1.
+ vae.requires_grad_(False)
+ denoising_unet.requires_grad_(True)
+ reference_unet.requires_grad_(True)
+ imageproj.requires_grad_(True)
+ face_locator.requires_grad_(True)
+
+ reference_control_writer = ReferenceAttentionControl(
+ reference_unet,
+ do_classifier_free_guidance=False,
+ mode="write",
+ fusion_blocks="full",
+ )
+ reference_control_reader = ReferenceAttentionControl(
+ denoising_unet,
+ do_classifier_free_guidance=False,
+ mode="read",
+ fusion_blocks="full",
+ )
+
+ net = Net(
+ reference_unet,
+ denoising_unet,
+ face_locator,
+ reference_control_writer,
+ reference_control_reader,
+ imageproj,
+ ).to(dtype=weight_dtype)
+
+ # get noise scheduler
+ train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)
+
+ # init optimizer
+ if cfg.solver.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ reference_unet.enable_xformers_memory_efficient_attention()
+ denoising_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly"
+ )
+
+ if cfg.solver.gradient_checkpointing:
+ reference_unet.enable_gradient_checkpointing()
+ denoising_unet.enable_gradient_checkpointing()
+
+ if cfg.solver.scale_lr:
+ learning_rate = (
+ cfg.solver.learning_rate
+ * cfg.solver.gradient_accumulation_steps
+ * cfg.data.train_bs
+ * accelerator.num_processes
+ )
+ else:
+ learning_rate = cfg.solver.learning_rate
+
+ # Initialize the optimizer
+ if cfg.solver.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError as exc:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ ) from exc
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ trainable_params = list(
+ filter(lambda p: p.requires_grad, net.parameters()))
+ optimizer = optimizer_cls(
+ trainable_params,
+ lr=learning_rate,
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
+ weight_decay=cfg.solver.adam_weight_decay,
+ eps=cfg.solver.adam_epsilon,
+ )
+
+ # init scheduler
+ lr_scheduler = get_scheduler(
+ cfg.solver.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=cfg.solver.lr_warmup_steps
+ * cfg.solver.gradient_accumulation_steps,
+ num_training_steps=cfg.solver.max_train_steps
+ * cfg.solver.gradient_accumulation_steps,
+ )
+
+ # get data loader
+ train_dataset = FaceMaskDataset(
+ img_size=(cfg.data.train_width, cfg.data.train_height),
+ data_meta_paths=cfg.data.meta_paths,
+ sample_margin=cfg.data.sample_margin,
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
+ )
+
+ # Prepare everything with our `accelerator`.
+ (
+ net,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ net,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(
+ len(train_dataloader) / cfg.solver.gradient_accumulation_steps
+ )
+ # Afterwards we recalculate our number of training epochs
+ num_train_epochs = math.ceil(
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
+ )
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
+ accelerator.init_trackers(
+ cfg.exp_name,
+ init_kwargs={"mlflow": {"run_name": run_time}},
+ )
+ # dump config file
+ mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
+
+ logger.info(f"save config to {save_dir}")
+ OmegaConf.save(
+ cfg, os.path.join(save_dir, "config.yaml")
+ )
+ # Train!
+ total_batch_size = (
+ cfg.data.train_bs
+ * accelerator.num_processes
+ * cfg.solver.gradient_accumulation_steps
+ )
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
+ logger.info(
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+ )
+ logger.info(
+ f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
+ )
+ logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # load checkpoint
+ # Potentially load in the weights and states from a previous save
+ if cfg.resume_from_checkpoint:
+ logger.info(f"Loading checkpoint from {checkpoint_dir}")
+ global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(
+ range(global_step, cfg.solver.max_train_steps),
+ disable=not accelerator.is_main_process,
+ )
+ progress_bar.set_description("Steps")
+ net.train()
+ for _ in range(first_epoch, num_train_epochs):
+ train_loss = 0.0
+ for _, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(net):
+ # Convert target images to latent space
+ pixel_values = batch["img"].to(weight_dtype)
+ with torch.no_grad():
+ latents = vae.encode(pixel_values).latent_dist.sample()
+ latents = latents.unsqueeze(2) # (b, c, 1, h, w)
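+ # 0.18215 is the standard latent scaling factor of the Stable Diffusion VAE.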
+ latents = latents * 0.18215
+
+ noise = torch.randn_like(latents)
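+ # Optionally add a per-sample noise offset, a common trick that helps the model
+ # reproduce very bright or very dark images.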
+ if cfg.noise_offset > 0.0:
+ noise += cfg.noise_offset * torch.randn(
+ (noise.shape[0], noise.shape[1], 1, 1, 1),
+ device=noise.device,
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each video
+ timesteps = torch.randint(
+ 0,
+ train_noise_scheduler.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
+ )
+ timesteps = timesteps.long()
+
+ face_mask_img = batch["tgt_mask"]
+ face_mask_img = face_mask_img.unsqueeze(2)
+ face_mask_img = face_mask_img.to(weight_dtype)
+
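+ # With probability uncond_ratio, drop the face-embedding condition (classifier-free
+ # guidance training): zero embeddings are used and the reference UNet pass is skipped.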
+ uncond_fwd = random.random() < cfg.uncond_ratio
+ face_emb_list = []
+ ref_image_list = []
+ for _, (ref_img, face_emb) in enumerate(
+ zip(batch["ref_img"], batch["face_emb"])
+ ):
+ if uncond_fwd:
+ face_emb_list.append(torch.zeros_like(face_emb))
+ else:
+ face_emb_list.append(face_emb)
+ ref_image_list.append(ref_img)
+
+ with torch.no_grad():
+ ref_img = torch.stack(ref_image_list, dim=0).to(
+ dtype=vae.dtype, device=vae.device
+ )
+ ref_image_latents = vae.encode(
+ ref_img
+ ).latent_dist.sample()
+ ref_image_latents = ref_image_latents * 0.18215
+
+ face_emb = torch.stack(face_emb_list, dim=0).to(
+ dtype=imageproj.dtype, device=imageproj.device
+ )
+
+ # add noise
+ noisy_latents = train_noise_scheduler.add_noise(
+ latents, noise, timesteps
+ )
+
+ # Get the target for loss depending on the prediction type
+ if train_noise_scheduler.prediction_type == "epsilon":
+ target = noise
+ elif train_noise_scheduler.prediction_type == "v_prediction":
+ target = train_noise_scheduler.get_velocity(
+ latents, noise, timesteps
+ )
+ else:
+ raise ValueError(
+ f"Unknown prediction type {train_noise_scheduler.prediction_type}"
+ )
+ model_pred = net(
+ noisy_latents,
+ timesteps,
+ ref_image_latents,
+ face_emb,
+ face_mask_img,
+ uncond_fwd,
+ )
+
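+ # snr_gamma > 0 enables Min-SNR loss weighting: each sample's MSE is scaled by
+ # min(SNR, gamma) / SNR (with SNR + 1 for v-prediction).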
+ if cfg.snr_gamma == 0:
+ loss = F.mse_loss(
+ model_pred.float(), target.float(), reduction="mean"
+ )
+ else:
+ snr = compute_snr(train_noise_scheduler, timesteps)
+ if train_noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack(
+ [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
+ ).min(dim=1)[0]
+ / snr
+ )
+ loss = F.mse_loss(
+ model_pred.float(), target.float(), reduction="none"
+ )
+ loss = (
+ loss.mean(dim=list(range(1, len(loss.shape))))
+ * mse_loss_weights
+ )
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(
+ loss.repeat(cfg.data.train_bs)).mean()
+ train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(
+ trainable_params,
+ cfg.solver.max_grad_norm,
+ )
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ if accelerator.sync_gradients:
+ reference_control_reader.clear()
+ reference_control_writer.clear()
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+ if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps:
+ accelerator.wait_for_everyone()
+ save_path = os.path.join(
+ checkpoint_dir, f"checkpoint-{global_step}")
+ if accelerator.is_main_process:
+ delete_additional_ckpt(checkpoint_dir, 3)
+ accelerator.save_state(save_path)
+ accelerator.wait_for_everyone()
+ unwrap_net = accelerator.unwrap_model(net)
+ if accelerator.is_main_process:
+ save_checkpoint(
+ unwrap_net.reference_unet,
+ module_dir,
+ "reference_unet",
+ global_step,
+ total_limit=3,
+ )
+ save_checkpoint(
+ unwrap_net.imageproj,
+ module_dir,
+ "imageproj",
+ global_step,
+ total_limit=3,
+ )
+ save_checkpoint(
+ unwrap_net.denoising_unet,
+ module_dir,
+ "denoising_unet",
+ global_step,
+ total_limit=3,
+ )
+ save_checkpoint(
+ unwrap_net.face_locator,
+ module_dir,
+ "face_locator",
+ global_step,
+ total_limit=3,
+ )
+
+ if global_step % cfg.val.validation_steps == 0 or global_step == 1:
+ if accelerator.is_main_process:
+ generator = torch.Generator(device=accelerator.device)
+ generator.manual_seed(cfg.seed)
+ log_validation(
+ vae=vae,
+ net=net,
+ scheduler=val_noise_scheduler,
+ accelerator=accelerator,
+ width=cfg.data.train_width,
+ height=cfg.data.train_height,
+ imageproj=imageproj,
+ cfg=cfg,
+ save_dir=validation_dir,
+ global_step=global_step,
+ face_analysis_model_path=cfg.face_analysis_model_path
+ )
+
+ logs = {
+ "step_loss": loss.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ }
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= cfg.solver.max_train_steps:
+ # process final module weight for stage2
+ if accelerator.is_main_process:
+ move_final_checkpoint(save_dir, module_dir, "reference_unet")
+ move_final_checkpoint(save_dir, module_dir, "imageproj")
+ move_final_checkpoint(save_dir, module_dir, "denoising_unet")
+ move_final_checkpoint(save_dir, module_dir, "face_locator")
+ break
+
+ accelerator.wait_for_everyone()
+ accelerator.end_training()
+
+
+def load_config(config_path: str) -> dict:
+ """
+ Loads the configuration file.
+
+ Args:
+ config_path (str): Path to the configuration file.
+
+ Returns:
+ dict: The configuration dictionary.
+ """
+
+ if config_path.endswith(".yaml"):
+ return OmegaConf.load(config_path)
+ if config_path.endswith(".py"):
+ return import_filename(config_path).cfg
+ raise ValueError("Unsupported format for config file")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str,
+ default="./configs/train/stage1.yaml")
+ args = parser.parse_args()
+
+ try:
+ config = load_config(args.config)
+ train_stage1_process(config)
+ except Exception as e:
+ logging.error("Failed to execute the training process: %s", e)
diff --git a/scripts/train_stage2.py b/scripts/train_stage2.py
new file mode 100644
index 00000000..8cff266d
--- /dev/null
+++ b/scripts/train_stage2.py
@@ -0,0 +1,991 @@
+# pylint: disable=E1101,C0415,W0718,R0801
+# scripts/train_stage2.py
+"""
+This is the main training script for stage 2 of the project.
+It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.
+
+The script includes the following classes and functions:
+
+1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
+ and face masks as input and returns the denoised latents.
+2. get_attention_mask: A function that rearranges the mask tensors to the required format.
+3. get_noise_scheduler: A function that creates and returns the noise schedulers for training and validation.
+4. process_audio_emb: A function that processes the audio embeddings to concatenate with other tensors.
+5. log_validation: A function that logs the validation information using the given VAE, image encoder,
+ network, scheduler, accelerator, width, height, and configuration.
+6. train_stage2_process: A function that processes the training stage 2 using the given configuration.
+7. load_config: A function that loads the configuration file from the given path.
+
+The script also includes the necessary imports and a brief description of the purpose of the file.
+"""
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import time
+import warnings
+from datetime import datetime
+from typing import List, Tuple
+
+import diffusers
+import mlflow
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+from einops import rearrange, repeat
+from omegaconf import OmegaConf
+from torch import nn
+from tqdm.auto import tqdm
+
+from hallo.animate.face_animate import FaceAnimatePipeline
+from hallo.datasets.audio_processor import AudioProcessor
+from hallo.datasets.image_processor import ImageProcessor
+from hallo.datasets.talk_video import TalkingVideoDataset
+from hallo.models.audio_proj import AudioProjModel
+from hallo.models.face_locator import FaceLocator
+from hallo.models.image_proj import ImageProjModel
+from hallo.models.mutual_self_attention import ReferenceAttentionControl
+from hallo.models.unet_2d_condition import UNet2DConditionModel
+from hallo.models.unet_3d import UNet3DConditionModel
+from hallo.utils.util import (compute_snr, delete_additional_ckpt,
+ import_filename, init_output_dir,
+ load_checkpoint, save_checkpoint,
+ seed_everything, tensor_to_video)
+
+warnings.filterwarnings("ignore")
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.10.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+class Net(nn.Module):
+ """
+ The Net class defines a neural network model that combines a reference UNet2DConditionModel,
+ a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
+
+ Args:
+ reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
+ denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
+ face_locator (FaceLocator): The face locator model used for face animation.
+ reference_control_writer: The reference control writer component.
+ reference_control_reader: The reference control reader component.
+ imageproj: The image projection model.
+ audioproj: The audio projection model.
+
+ Forward method:
+ noisy_latents (torch.Tensor): The noisy latents tensor.
+ timesteps (torch.Tensor): The timesteps tensor.
+ ref_image_latents (torch.Tensor): The reference image latents tensor.
+ face_emb (torch.Tensor): The face embeddings tensor.
+ audio_emb (torch.Tensor): The audio embeddings tensor.
+ mask (torch.Tensor): Hard face mask for face locator.
+ full_mask (torch.Tensor): Pose Mask.
+ face_mask (torch.Tensor): Face Mask
+ lip_mask (torch.Tensor): Lip Mask
+ uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass.
+ uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass.
+
+ Returns:
+ torch.Tensor: The output tensor of the neural network model.
+ """
+ def __init__(
+ self,
+ reference_unet: UNet2DConditionModel,
+ denoising_unet: UNet3DConditionModel,
+ face_locator: FaceLocator,
+ reference_control_writer,
+ reference_control_reader,
+ imageproj,
+ audioproj,
+ ):
+ super().__init__()
+ self.reference_unet = reference_unet
+ self.denoising_unet = denoising_unet
+ self.face_locator = face_locator
+ self.reference_control_writer = reference_control_writer
+ self.reference_control_reader = reference_control_reader
+ self.imageproj = imageproj
+ self.audioproj = audioproj
+
+ def forward(
+ self,
+ noisy_latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ ref_image_latents: torch.Tensor,
+ face_emb: torch.Tensor,
+ audio_emb: torch.Tensor,
+ mask: torch.Tensor,
+ full_mask: torch.Tensor,
+ face_mask: torch.Tensor,
+ lip_mask: torch.Tensor,
+ uncond_img_fwd: bool = False,
+ uncond_audio_fwd: bool = False,
+ ):
+ """
+ simple docstring to prevent pylint error
+ """
+ face_emb = self.imageproj(face_emb)
+ mask = mask.to(device="cuda")
+ mask_feature = self.face_locator(mask)
+ audio_emb = audio_emb.to(
+ device=self.audioproj.device, dtype=self.audioproj.dtype)
+ audio_emb = self.audioproj(audio_emb)
+
+ # condition forward
+ if not uncond_img_fwd:
+ ref_timesteps = torch.zeros_like(timesteps)
+ ref_timesteps = repeat(
+ ref_timesteps,
+ "b -> (repeat b)",
+ repeat=ref_image_latents.size(0) // ref_timesteps.size(0),
+ )
+ self.reference_unet(
+ ref_image_latents,
+ ref_timesteps,
+ encoder_hidden_states=face_emb,
+ return_dict=False,
+ )
+ self.reference_control_reader.update(self.reference_control_writer)
+
+ if uncond_audio_fwd:
+ audio_emb = torch.zeros_like(audio_emb).to(
+ device=audio_emb.device, dtype=audio_emb.dtype
+ )
+
+ model_pred = self.denoising_unet(
+ noisy_latents,
+ timesteps,
+ mask_cond_fea=mask_feature,
+ encoder_hidden_states=face_emb,
+ audio_embedding=audio_emb,
+ full_mask=full_mask,
+ face_mask=face_mask,
+ lip_mask=lip_mask
+ ).sample
+
+ return model_pred
+
+
+def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
+ """
+ Rearrange the mask tensors to the required format.
+
+ Args:
+ mask (torch.Tensor): The input mask tensor.
+ weight_dtype (torch.dtype): The data type for the mask tensor.
+
+ Returns:
+ torch.Tensor: The rearranged mask tensor.
+ """
+ if isinstance(mask, List):
+ _mask = []
+ for m in mask:
+ _mask.append(
+ rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
+ return _mask
+ mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
+ return mask
+
+
+def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
+ """
+ Create noise scheduler for training.
+
+ Args:
+ cfg (argparse.Namespace): Configuration object.
+
+ Returns:
+ Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
+ """
+
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
+ if cfg.enable_zero_snr:
+ sched_kwargs.update(
+ rescale_betas_zero_snr=True,
+ timestep_spacing="trailing",
+ prediction_type="v_prediction",
+ )
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
+
+ return train_noise_scheduler, val_noise_scheduler
+
+
+def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
+ """
+ Process the audio embedding to concatenate with other tensors.
+
+ Parameters:
+ audio_emb (torch.Tensor): The audio embedding tensor to process.
+
+ Returns:
+ concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
+ """
+ concatenated_tensors = []
+
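+ # For each frame i, stack the embeddings of frames i-2 .. i+2 (clamped at the sequence
+ # boundaries) so every frame carries a five-frame temporal context window.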
+ for i in range(audio_emb.shape[0]):
+ vectors_to_concat = [
+ audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
+ concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
+
+ audio_emb = torch.stack(concatenated_tensors, dim=0)
+
+ return audio_emb
+
+
+def log_validation(
+ accelerator: Accelerator,
+ vae: AutoencoderKL,
+ net: Net,
+ scheduler: DDIMScheduler,
+ width: int,
+ height: int,
+ clip_length: int = 24,
+ generator: torch.Generator = None,
+ cfg: dict = None,
+ save_dir: str = None,
+ global_step: int = 0,
+ times: int = None,
+ face_analysis_model_path: str = "",
+) -> torch.Tensor:
+ """
+ Log validation video during the training process.
+
+ Args:
+ accelerator (Accelerator): The accelerator for distributed training.
+ vae (AutoencoderKL): The autoencoder model.
+ net (Net): The main neural network model.
+ scheduler (DDIMScheduler): The scheduler for noise.
+ width (int): The width of the input images.
+ height (int): The height of the input images.
+ clip_length (int): The length of the video clips. Defaults to 24.
+ generator (torch.Generator): The random number generator. Defaults to None.
+ cfg (dict): The configuration dictionary. Defaults to None.
+ save_dir (str): The directory to save validation results. Defaults to None.
+ global_step (int): The current global step in training. Defaults to 0.
+ times (int): The number of inference times. Defaults to None.
+ face_analysis_model_path (str): The path to the face analysis model. Defaults to "".
+
+ Returns:
+ torch.Tensor: The tensor result of the validation.
+ """
+ ori_net = accelerator.unwrap_model(net)
+ reference_unet = ori_net.reference_unet
+ denoising_unet = ori_net.denoising_unet
+ face_locator = ori_net.face_locator
+ imageproj = ori_net.imageproj
+ audioproj = ori_net.audioproj
+
+ generator = torch.manual_seed(42)
+ tmp_denoising_unet = copy.deepcopy(denoising_unet)
+
+ pipeline = FaceAnimatePipeline(
+ vae=vae,
+ reference_unet=reference_unet,
+ denoising_unet=tmp_denoising_unet,
+ face_locator=face_locator,
+ image_proj=imageproj,
+ scheduler=scheduler,
+ )
+ pipeline = pipeline.to("cuda")
+
+ image_processor = ImageProcessor((width, height), face_analysis_model_path)
+ audio_processor = AudioProcessor(
+ cfg.data.sample_rate,
+ cfg.data.fps,
+ cfg.wav2vec_config.model_path,
+ cfg.wav2vec_config.features == "last",
+ os.path.dirname(cfg.audio_separator.model_path),
+ os.path.basename(cfg.audio_separator.model_path),
+ os.path.join(save_dir, '.cache', "audio_preprocess")
+ )
+
+ for idx, ref_img_path in enumerate(cfg.ref_img_path):
+ audio_path = cfg.audio_path[idx]
+ source_image_pixels, \
+ source_image_face_region, \
+ source_image_face_emb, \
+ source_image_full_mask, \
+ source_image_face_mask, \
+ source_image_lip_mask = image_processor.preprocess(
+ ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
+ audio_emb, audio_length = audio_processor.preprocess(
+ audio_path, clip_length)
+
+ audio_emb = process_audio_emb(audio_emb)
+
+ source_image_pixels = source_image_pixels.unsqueeze(0)
+ source_image_face_region = source_image_face_region.unsqueeze(0)
+ source_image_face_emb = source_image_face_emb.reshape(1, -1)
+ source_image_face_emb = torch.tensor(source_image_face_emb)
+
+ source_image_full_mask = [
+ (mask.repeat(clip_length, 1))
+ for mask in source_image_full_mask
+ ]
+ source_image_face_mask = [
+ (mask.repeat(clip_length, 1))
+ for mask in source_image_face_mask
+ ]
+ source_image_lip_mask = [
+ (mask.repeat(clip_length, 1))
+ for mask in source_image_lip_mask
+ ]
+
+ times = audio_emb.shape[0] // clip_length
+ tensor_result = []
+ generator = torch.manual_seed(42)
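+ # Generate the video clip by clip; after the first clip, the last n_motion_frames frames
+ # of the previous clip seed the motion frames of the next one.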
+ for t in range(times):
+ print(f"[{t+1}/{times}]")
+
+ if len(tensor_result) == 0:
+ # The first iteration
+ motion_zeros = source_image_pixels.repeat(
+ cfg.data.n_motion_frames, 1, 1, 1)
+ motion_zeros = motion_zeros.to(
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
+ pixel_values_ref_img = torch.cat(
+ [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
+ else:
+ motion_frames = tensor_result[-1][0]
+ motion_frames = motion_frames.permute(1, 0, 2, 3)
+ motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
+ motion_frames = motion_frames * 2.0 - 1.0
+ motion_frames = motion_frames.to(
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
+ pixel_values_ref_img = torch.cat(
+ [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
+
+ pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
+
+ audio_tensor = audio_emb[
+ t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
+ ]
+ audio_tensor = audio_tensor.unsqueeze(0)
+ audio_tensor = audio_tensor.to(
+ device=audioproj.device, dtype=audioproj.dtype)
+ audio_tensor = audioproj(audio_tensor)
+
+ pipeline_output = pipeline(
+ ref_image=pixel_values_ref_img,
+ audio_tensor=audio_tensor,
+ face_emb=source_image_face_emb,
+ face_mask=source_image_face_region,
+ pixel_values_full_mask=source_image_full_mask,
+ pixel_values_face_mask=source_image_face_mask,
+ pixel_values_lip_mask=source_image_lip_mask,
+ width=cfg.data.train_width,
+ height=cfg.data.train_height,
+ video_length=clip_length,
+ num_inference_steps=cfg.inference_steps,
+ guidance_scale=cfg.cfg_scale,
+ generator=generator,
+ )
+
+ tensor_result.append(pipeline_output.videos)
+
+ tensor_result = torch.cat(tensor_result, dim=2)
+ tensor_result = tensor_result.squeeze(0)
+ tensor_result = tensor_result[:, :audio_length]
+ audio_name = os.path.basename(audio_path).split('.')[0]
+ ref_name = os.path.basename(ref_img_path).split('.')[0]
+ output_file = os.path.join(save_dir, f"{global_step}_{ref_name}_{audio_name}.mp4")
+ # save the result after all iterations
+ tensor_to_video(tensor_result, output_file, audio_path)
+
+
+ # clean up
+ del tmp_denoising_unet
+ del pipeline
+ del image_processor
+ del audio_processor
+ torch.cuda.empty_cache()
+
+ return tensor_result
+
+
+def train_stage2_process(cfg: argparse.Namespace) -> None:
+ """
+ Trains the model using the given configuration (cfg).
+
+ Args:
+ cfg (dict): The configuration dictionary containing the parameters for training.
+
+ Notes:
+ - This function trains the model using the given configuration.
+ - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
+ - The training progress is logged and tracked using the accelerator.
+ - The trained model is saved after the training is completed.
+ """
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
+ mixed_precision=cfg.solver.mixed_precision,
+ log_with="mlflow",
+ project_dir="./mlruns",
+ kwargs_handlers=[kwargs],
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if cfg.seed is not None:
+ seed_everything(cfg.seed)
+
+ # create output dir for training
+ exp_name = cfg.exp_name
+ save_dir = f"{cfg.output_dir}/{exp_name}"
+ checkpoint_dir = os.path.join(save_dir, "checkpoints")
+ module_dir = os.path.join(save_dir, "modules")
+ validation_dir = os.path.join(save_dir, "validation")
+ if accelerator.is_main_process:
+ init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])
+
+ accelerator.wait_for_everyone()
+
+ if cfg.weight_dtype == "fp16":
+ weight_dtype = torch.float16
+ elif cfg.weight_dtype == "bf16":
+ weight_dtype = torch.bfloat16
+ elif cfg.weight_dtype == "fp32":
+ weight_dtype = torch.float32
+ else:
+ raise ValueError(
+ f"Unsupported weight dtype for training: {cfg.weight_dtype}"
+ )
+
+ # Create Models
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
+ "cuda", dtype=weight_dtype
+ )
+ reference_unet = UNet2DConditionModel.from_pretrained(
+ cfg.base_model_path,
+ subfolder="unet",
+ ).to(device="cuda", dtype=weight_dtype)
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+ cfg.base_model_path,
+ cfg.mm_path,
+ subfolder="unet",
+ unet_additional_kwargs=OmegaConf.to_container(
+ cfg.unet_additional_kwargs),
+ use_landmark=False
+ ).to(device="cuda", dtype=weight_dtype)
+ imageproj = ImageProjModel(
+ cross_attention_dim=denoising_unet.config.cross_attention_dim,
+ clip_embeddings_dim=512,
+ clip_extra_context_tokens=4,
+ ).to(device="cuda", dtype=weight_dtype)
+ face_locator = FaceLocator(
+ conditioning_embedding_channels=320,
+ ).to(device="cuda", dtype=weight_dtype)
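+ # The audio projection maps Wav2Vec features to audio tokens consumed by the denoising UNet;
+ # seq_len=5 matches the five-frame window assembled by process_audio_emb.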
+ audioproj = AudioProjModel(
+ seq_len=5,
+ blocks=12,
+ channels=768,
+ intermediate_dim=512,
+ output_dim=768,
+ context_tokens=32,
+ ).to(device="cuda", dtype=weight_dtype)
+
+ # load module weight from stage 1
+ stage1_ckpt_dir = cfg.stage1_ckpt_dir
+ denoising_unet.load_state_dict(
+ torch.load(
+ os.path.join(stage1_ckpt_dir, "denoising_unet.pth"),
+ map_location="cpu",
+ ),
+ strict=False,
+ )
+ reference_unet.load_state_dict(
+ torch.load(
+ os.path.join(stage1_ckpt_dir, "reference_unet.pth"),
+ map_location="cpu",
+ ),
+ strict=False,
+ )
+ face_locator.load_state_dict(
+ torch.load(
+ os.path.join(stage1_ckpt_dir, "face_locator.pth"),
+ map_location="cpu",
+ ),
+ strict=False,
+ )
+ imageproj.load_state_dict(
+ torch.load(
+ os.path.join(stage1_ckpt_dir, "imageproj.pth"),
+ map_location="cpu",
+ ),
+ strict=False,
+ )
+
+ # Freeze the modules trained in stage 1; only the audio projection and the motion modules selected below are trained.
+ vae.requires_grad_(False)
+ imageproj.requires_grad_(False)
+ reference_unet.requires_grad_(False)
+ denoising_unet.requires_grad_(False)
+ face_locator.requires_grad_(False)
+ audioproj.requires_grad_(True)
+
+ # Set motion module learnable
+ trainable_modules = cfg.trainable_para
+ for name, module in denoising_unet.named_modules():
+ if any(trainable_mod in name for trainable_mod in trainable_modules):
+ for params in module.parameters():
+ params.requires_grad_(True)
+
+ reference_control_writer = ReferenceAttentionControl(
+ reference_unet,
+ do_classifier_free_guidance=False,
+ mode="write",
+ fusion_blocks="full",
+ )
+ reference_control_reader = ReferenceAttentionControl(
+ denoising_unet,
+ do_classifier_free_guidance=False,
+ mode="read",
+ fusion_blocks="full",
+ )
+
+ net = Net(
+ reference_unet,
+ denoising_unet,
+ face_locator,
+ reference_control_writer,
+ reference_control_reader,
+ imageproj,
+ audioproj,
+ ).to(dtype=weight_dtype)
+
+ # get noise scheduler
+ train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)
+
+ if cfg.solver.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ reference_unet.enable_xformers_memory_efficient_attention()
+ denoising_unet.enable_xformers_memory_efficient_attention()
+
+ else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly"
+ )
+
+ if cfg.solver.gradient_checkpointing:
+ reference_unet.enable_gradient_checkpointing()
+ denoising_unet.enable_gradient_checkpointing()
+
+ if cfg.solver.scale_lr:
+ learning_rate = (
+ cfg.solver.learning_rate
+ * cfg.solver.gradient_accumulation_steps
+ * cfg.data.train_bs
+ * accelerator.num_processes
+ )
+ else:
+ learning_rate = cfg.solver.learning_rate
+
+ # Initialize the optimizer
+ if cfg.solver.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError as exc:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ ) from exc
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ trainable_params = list(
+ filter(lambda p: p.requires_grad, net.parameters()))
+ logger.info(f"Total trainable params {len(trainable_params)}")
+ optimizer = optimizer_cls(
+ trainable_params,
+ lr=learning_rate,
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
+ weight_decay=cfg.solver.adam_weight_decay,
+ eps=cfg.solver.adam_epsilon,
+ )
+
+ # Scheduler
+ lr_scheduler = get_scheduler(
+ cfg.solver.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=cfg.solver.lr_warmup_steps
+ * cfg.solver.gradient_accumulation_steps,
+ num_training_steps=cfg.solver.max_train_steps
+ * cfg.solver.gradient_accumulation_steps,
+ )
+
+ # get data loader
+ train_dataset = TalkingVideoDataset(
+ img_size=(cfg.data.train_width, cfg.data.train_height),
+ sample_rate=cfg.data.sample_rate,
+ n_sample_frames=cfg.data.n_sample_frames,
+ n_motion_frames=cfg.data.n_motion_frames,
+ audio_margin=cfg.data.audio_margin,
+ data_meta_paths=cfg.data.train_meta_paths,
+ wav2vec_cfg=cfg.wav2vec_config,
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
+ )
+
+ # Prepare everything with our `accelerator`.
+ (
+ net,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ net,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(
+ len(train_dataloader) / cfg.solver.gradient_accumulation_steps
+ )
+ # Afterwards we recalculate our number of training epochs
+ num_train_epochs = math.ceil(
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
+ )
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
+ accelerator.init_trackers(
+ exp_name,
+ init_kwargs={"mlflow": {"run_name": run_time}},
+ )
+ # dump config file
+ mlflow.log_dict(
+ OmegaConf.to_container(
+ cfg), "config.yaml"
+ )
+ logger.info(f"save config to {save_dir}")
+ OmegaConf.save(
+ cfg, os.path.join(save_dir, "config.yaml")
+ )
+
+ # Train!
+ total_batch_size = (
+ cfg.data.train_bs
+ * accelerator.num_processes
+ * cfg.solver.gradient_accumulation_steps
+ )
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
+ logger.info(
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+ )
+ logger.info(
+ f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
+ )
+ logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if cfg.resume_from_checkpoint:
+ logger.info(f"Loading checkpoint from {checkpoint_dir}")
+ global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(
+ range(global_step, cfg.solver.max_train_steps),
+ disable=not accelerator.is_local_main_process,
+ )
+ progress_bar.set_description("Steps")
+
+ for _ in range(first_epoch, num_train_epochs):
+ train_loss = 0.0
+ t_data_start = time.time()
+ for _, batch in enumerate(train_dataloader):
+ t_data = time.time() - t_data_start
+ with accelerator.accumulate(net):
+ # Convert videos to latent space
+ pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
+
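+ # Flatten the full/face/lip masks to (b*f, h*w) so they can be used as attention masks
+ # inside the denoising UNet.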
+ pixel_values_face_mask = batch["pixel_values_face_mask"]
+ pixel_values_face_mask = get_attention_mask(
+ pixel_values_face_mask, weight_dtype
+ )
+ pixel_values_lip_mask = batch["pixel_values_lip_mask"]
+ pixel_values_lip_mask = get_attention_mask(
+ pixel_values_lip_mask, weight_dtype
+ )
+ pixel_values_full_mask = batch["pixel_values_full_mask"]
+ pixel_values_full_mask = get_attention_mask(
+ pixel_values_full_mask, weight_dtype
+ )
+
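+ # Fold the frame dimension into the batch so the 2D VAE can encode every frame,
+ # then unfold the result back into (b, c, f, h, w) video latents.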
+ with torch.no_grad():
+ video_length = pixel_values_vid.shape[1]
+ pixel_values_vid = rearrange(
+ pixel_values_vid, "b f c h w -> (b f) c h w"
+ )
+ latents = vae.encode(pixel_values_vid).latent_dist.sample()
+ latents = rearrange(
+ latents, "(b f) c h w -> b c f h w", f=video_length
+ )
+ latents = latents * 0.18215
+
+ noise = torch.randn_like(latents)
+ if cfg.noise_offset > 0:
+ noise += cfg.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1, 1),
+ device=latents.device,
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each video
+ timesteps = torch.randint(
+ 0,
+ train_noise_scheduler.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
+ )
+ timesteps = timesteps.long()
+
+ # mask for face locator
+ pixel_values_mask = (
+ batch["pixel_values_mask"].unsqueeze(
+ 1).to(dtype=weight_dtype)
+ )
+ pixel_values_mask = repeat(
+ pixel_values_mask,
+ "b f c h w -> b (repeat f) c h w",
+ repeat=video_length,
+ )
+ pixel_values_mask = pixel_values_mask.transpose(1, 2)
+
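+ # Independently drop the reference-image and audio conditions with their configured
+ # probabilities (classifier-free guidance training for both modalities).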
+ uncond_img_fwd = random.random() < cfg.uncond_img_ratio
+ uncond_audio_fwd = random.random() < cfg.uncond_audio_ratio
+
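+ # With probability start_ratio, treat the clip as the start of a video and zero out
+ # the motion frames (everything after the reference frame).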
+ start_frame = random.random() < cfg.start_ratio
+ pixel_values_ref_img = batch["pixel_values_ref_img"].to(
+ dtype=weight_dtype
+ )
+ # initialize the motion frames as zero maps
+ if start_frame:
+ pixel_values_ref_img[:, 1:] = 0.0
+
+ ref_img_and_motion = rearrange(
+ pixel_values_ref_img, "b f c h w -> (b f) c h w"
+ )
+
+ with torch.no_grad():
+ ref_image_latents = vae.encode(
+ ref_img_and_motion
+ ).latent_dist.sample()
+ ref_image_latents = ref_image_latents * 0.18215
+ image_prompt_embeds = batch["face_emb"].to(
+ dtype=imageproj.dtype, device=imageproj.device
+ )
+
+ # add noise
+ noisy_latents = train_noise_scheduler.add_noise(
+ latents, noise, timesteps
+ )
+
+ # Get the target for loss depending on the prediction type
+ if train_noise_scheduler.prediction_type == "epsilon":
+ target = noise
+ elif train_noise_scheduler.prediction_type == "v_prediction":
+ target = train_noise_scheduler.get_velocity(
+ latents, noise, timesteps
+ )
+ else:
+ raise ValueError(
+ f"Unknown prediction type {train_noise_scheduler.prediction_type}"
+ )
+
+ # ---- Forward!!! -----
+ model_pred = net(
+ noisy_latents=noisy_latents,
+ timesteps=timesteps,
+ ref_image_latents=ref_image_latents,
+ face_emb=image_prompt_embeds,
+ mask=pixel_values_mask,
+ full_mask=pixel_values_full_mask,
+ face_mask=pixel_values_face_mask,
+ lip_mask=pixel_values_lip_mask,
+ audio_emb=batch["audio_tensor"].to(
+ dtype=weight_dtype),
+ uncond_img_fwd=uncond_img_fwd,
+ uncond_audio_fwd=uncond_audio_fwd,
+ )
+
+ if cfg.snr_gamma == 0:
+ loss = F.mse_loss(
+ model_pred.float(),
+ target.float(),
+ reduction="mean",
+ )
+ else:
+ snr = compute_snr(train_noise_scheduler, timesteps)
+ if train_noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack(
+ [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
+ ).min(dim=1)[0]
+ / snr
+ )
+ # use reduction="none" so the per-sample Min-SNR weights actually apply
+ loss = F.mse_loss(
+ model_pred.float(),
+ target.float(),
+ reduction="none",
+ )
+ loss = (
+ loss.mean(dim=list(range(1, len(loss.shape))))
+ * mse_loss_weights
+ ).mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(
+ loss.repeat(cfg.data.train_bs)).mean()
+ train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(
+ trainable_params,
+ cfg.solver.max_grad_norm,
+ )
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ if accelerator.sync_gradients:
+ reference_control_reader.clear()
+ reference_control_writer.clear()
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % cfg.val.validation_steps == 0 or global_step == 1:
+ if accelerator.is_main_process:
+ generator = torch.Generator(device=accelerator.device)
+ generator.manual_seed(cfg.seed)
+
+ log_validation(
+ accelerator=accelerator,
+ vae=vae,
+ net=net,
+ scheduler=val_noise_scheduler,
+ width=cfg.data.train_width,
+ height=cfg.data.train_height,
+ clip_length=cfg.data.n_sample_frames,
+ cfg=cfg,
+ save_dir=validation_dir,
+ global_step=global_step,
+ times=cfg.single_inference_times,
+ face_analysis_model_path=cfg.face_analysis_model_path
+ )
+
+ logs = {
+ "step_loss": loss.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ "td": f"{t_data:.2f}s",
+ }
+ t_data_start = time.time()
+ progress_bar.set_postfix(**logs)
+
+ if (
+ global_step % cfg.checkpointing_steps == 0
+ or global_step == cfg.solver.max_train_steps
+ ):
+ # save model
+ save_path = os.path.join(
+ checkpoint_dir, f"checkpoint-{global_step}")
+ if accelerator.is_main_process:
+ delete_additional_ckpt(checkpoint_dir, 30)
+ accelerator.wait_for_everyone()
+ accelerator.save_state(save_path)
+
+ # save model weight
+ unwrap_net = accelerator.unwrap_model(net)
+ if accelerator.is_main_process:
+ save_checkpoint(
+ unwrap_net,
+ module_dir,
+ "net",
+ global_step,
+ total_limit=30,
+ )
+ if global_step >= cfg.solver.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ accelerator.end_training()
+
+
+def load_config(config_path: str) -> dict:
+ """
+ Loads the configuration file.
+
+ Args:
+ config_path (str): Path to the configuration file.
+
+ Returns:
+ dict: The configuration dictionary.
+ """
+
+ if config_path.endswith(".yaml"):
+ return OmegaConf.load(config_path)
+ if config_path.endswith(".py"):
+ return import_filename(config_path).cfg
+ raise ValueError("Unsupported format for config file")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config", type=str, default="./configs/train/stage2.yaml"
+ )
+ args = parser.parse_args()
+
+ try:
+ config = load_config(args.config)
+ train_stage2_process(config)
+ except Exception as e:
+ logging.error("Failed to execute the training process: %s", e)