diff --git a/README.md b/README.md
index ead6e978..86c3a3eb 100644
--- a/README.md
+++ b/README.md
@@ -23,4 +23,4 @@ pip install xformers==0.0.22.post7
 ```
 ## Download Checkpoints
-All models will be downloaded automatically to ComfyUI's model folder, just no wrries.
+All models will be downloaded automatically to ComfyUI's model folder, so no worries.
\ No newline at end of file
diff --git a/README.md.bak b/README.md.bak
index 9720bc1b..74a2493f 100644
--- a/README.md.bak
+++ b/README.md.bak
@@ -24,7 +24,7 @@
-
+
@@ -65,6 +65,7 @@ Explore [more examples](https://fudan-generative-vision.github.io/hallo).
 ## 📰 News
+- **`2024/06/28`**: 🎉🎉🎉 We are proud to announce the release of our model training code. Try it with your own training data. Here is the [tutorial](#training).
 - **`2024/06/21`**: 🚀🚀🚀 Cloned a Gradio demo on [🤗Huggingface space](https://huggingface.co/spaces/fudan-generative-ai/hallo).
 - **`2024/06/20`**: 🌟🌟🌟 Received numerous contributions from the community, including a [Windows version](https://github.com/sdbds/hallo-for-windows), [ComfyUI](https://github.com/AIFSH/ComfyUI-Hallo), [WebUI](https://github.com/fudan-generative-vision/hallo/pull/51), and [Docker template](https://github.com/ashleykleynhans/hallo-docker).
 - **`2024/06/15`**: ✨✨✨ Released some images and audios for inference testing on [🤗Huggingface](https://huggingface.co/datasets/fudan-generative-ai/hallo_inference_samples).
@@ -74,6 +75,7 @@ Explore [more examples](https://fudan-generative-vision.github.io/hallo).
 Explore the resources developed by our community to enhance your experience with Hallo:
+- [TTS x Hallo Talking Portrait Generator](https://huggingface.co/spaces/fffiloni/tts-hallo-talking-portrait) - Check out this awesome Gradio demo by [@Sylvain Filoni](https://huggingface.co/fffiloni)! With this tool, you can conveniently prepare a portrait image and audio for Hallo.
 - [Demo on Huggingface](https://huggingface.co/spaces/multimodalart/hallo) - Check out this easy-to-use Gradio demo by [@multimodalart](https://huggingface.co/multimodalart).
 - [hallo-webui](https://github.com/daswer123/hallo-webui) - Explore the WebUI created by [@daswer123](https://github.com/daswer123).
 - [hallo-for-windows](https://github.com/sdbds/hallo-for-windows) - Utilize Hallo on Windows with the guide by [@sdbds](https://github.com/sdbds).
@@ -233,15 +235,113 @@ options:
                         face region
 ```
+## Training
+
+### Prepare Data for Training
+
+The training data consists of talking-face videos, similar in style to the source images used for inference, and it must meet the following requirements:
+
+1. The videos should be cropped into squares.
+2. The face should be the main focus, making up 50%-70% of the image.
+3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles).
+
+Organize your raw videos into the following directory structure:
+
+
+```text
+dataset_name/
+|-- videos/
+|   |-- 0001.mp4
+|   |-- 0002.mp4
+|   |-- 0003.mp4
+|   `-- 0004.mp4
+```
+
+You can use any `dataset_name`, but ensure the `videos` directory is named as shown above.
+
+Next, process the videos with the following commands:
+
+```bash
+python -m scripts.data_preprocess --input_dir dataset_name/videos --step 1
+python -m scripts.data_preprocess --input_dir dataset_name/videos --step 2
+```
+
+**Note:** Execute steps 1 and 2 sequentially, as they perform different tasks. Step 1 converts videos into frames, extracts audio from each video, and generates the necessary masks.
+Step 2 generates face embeddings using InsightFace and audio embeddings using Wav2Vec, and it requires a GPU. For parallel processing, use the `-p` and `-r` arguments: `-p` specifies the total number of instances to launch, dividing the data into `p` parts, and `-r` specifies which part the current process should handle. You need to manually launch multiple instances with different values for `-r`.
+
+Generate the metadata JSON files with the following commands:
+
+```bash
+python scripts/extract_meta_info_stage1.py -r path/to/dataset -n dataset_name
+python scripts/extract_meta_info_stage2.py -r path/to/dataset -n dataset_name
+```
+
+Replace `path/to/dataset` with the path to the parent directory of `videos`, such as `dataset_name` in the example above. This will generate `dataset_name_stage1.json` and `dataset_name_stage2.json` in the `./data` directory.
+
+### Training
+
+Update the data meta path settings in the configuration YAML files, `configs/train/stage1.yaml` and `configs/train/stage2.yaml`:
+
+
+```yaml
+#stage1.yaml
+data:
+  meta_paths:
+    - ./data/dataset_name_stage1.json
+
+#stage2.yaml
+data:
+  meta_paths:
+    - ./data/dataset_name_stage2.json
+```
+
+Start training with the following command:
+
+```shell
+accelerate launch -m \
+  --config_file accelerate_config.yaml \
+  --machine_rank 0 \
+  --main_process_ip 0.0.0.0 \
+  --main_process_port 20055 \
+  --num_machines 1 \
+  --num_processes 8 \
+  scripts.train_stage1 --config ./configs/train/stage1.yaml
+```
+
+#### Accelerate Usage Explanation
+
+The `accelerate launch` command starts the training process with distributed settings.
+
+```shell
+accelerate launch [arguments] {training_script} --{training_script-argument-1} --{training_script-argument-2} ...
+```
+
+**Arguments for Accelerate:**
+
+- `-m, --module`: Interpret the launch script as a Python module.
+- `--config_file`: Configuration file for Hugging Face Accelerate.
+- `--machine_rank`: Rank of the current machine in a multi-node setup.
+- `--main_process_ip`: IP address of the master node.
+- `--main_process_port`: Port of the master node.
+- `--num_machines`: Total number of nodes participating in the training.
+- `--num_processes`: Total number of processes for training, matching the total number of GPUs across all machines.
+
+**Arguments for Training:**
+
+- `{training_script}`: The training script, such as `scripts.train_stage1` or `scripts.train_stage2`.
+- `--{training_script-argument-1}`: Arguments specific to the training script. Our training scripts accept one argument, `--config`, to specify the training configuration file.
+
+For multi-node training, run the command manually on each node, giving each node a different `machine_rank`.
+
+For more settings, refer to the [Accelerate documentation](https://huggingface.co/docs/accelerate/en/index).
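Editor's note: the manual multi-instance launch for step 2 can also be scripted. The sketch below is illustrative only and is not part of this PR; it relies solely on the `--input_dir`, `--step`, `-p`, and `-r` arguments documented above, and the worker count and per-worker GPU pinning are assumptions to adapt to your machine.

```python
# launch_preprocess_step2.py -- hypothetical helper, a sketch only (not in this PR).
# Starts several instances of scripts.data_preprocess step 2, one data partition per worker.
import os
import subprocess

NUM_WORKERS = 4  # assumption: one worker per available GPU

procs = [
    subprocess.Popen(
        [
            "python", "-m", "scripts.data_preprocess",
            "--input_dir", "dataset_name/videos",
            "--step", "2",
            "-p", str(NUM_WORKERS),  # total number of partitions
            "-r", str(rank),         # partition handled by this worker
        ],
        # Assumption: pin each worker to its own GPU.
        env={**os.environ, "CUDA_VISIBLE_DEVICES": str(rank)},
    )
    for rank in range(NUM_WORKERS)
]

# Wait for every worker to finish before running the extract_meta_info scripts.
for proc in procs:
    proc.wait()
```

The same `-p`/`-r` partitioning applies to step 1 as well, although step 1 does not need a GPU.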
+
 ## 📅️ Roadmap
 
 | Status | Milestone                                                                                              |    ETA     |
 | :----: | :----------------------------------------------------------------------------------------------------- | :--------: |
 |   ✅    | **[Inference source code meet everyone on GitHub](https://github.com/fudan-generative-vision/hallo)**  | 2024-06-15 |
 |   ✅    | **[Pretrained models on Huggingface](https://huggingface.co/fudan-generative-ai/hallo)**               | 2024-06-15 |
-|   🚧    | **[Optimizing Performance on images with a resolution of 256x256.]()**                                 | 2024-06-23 |
-|   🚀    | **[Improving the model's performance on Mandarin Chinese]()**                                          | 2024-06-25 |
-|   🚀    | **[Releasing data preparation and training scripts]()**                                                | 2024-06-28 |
+|   ✅    | **[Releasing data preparation and training scripts](#training)**                                       | 2024-06-28 |
+|   🚀    | **[Improving the model's performance on Mandarin Chinese]()**                                          | TBD        |
 Other Enhancements
@@ -250,7 +350,6 @@ options:
 - [x] Bug: Output video may lose several frames. [#41](https://github.com/fudan-generative-vision/hallo/issues/41)
 - [ ] Bug: Sound volume affecting inference results (audio normalization).
 - [ ] ~~Enhancement: Inference code logic optimization~~. This solution doesn't show significant performance improvements. Trying other approaches.
-- [ ] Enhancement: Enhancing performance on low resolutions(256x256) to support more efficient usage.
@@ -297,4 +396,4 @@ Thank you to all the contributors who have helped to make this project better! - + \ No newline at end of file diff --git a/accelerate_config.yaml b/accelerate_config.yaml new file mode 100644 index 00000000..6fa766f1 --- /dev/null +++ b/accelerate_config.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: true +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: "no" +main_training_function: main +mixed_precision: "fp16" +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/configs/train/stage1.yaml b/configs/train/stage1.yaml new file mode 100644 index 00000000..28760ed2 --- /dev/null +++ b/configs/train/stage1.yaml @@ -0,0 +1,63 @@ +data: + train_bs: 8 + train_width: 512 + train_height: 512 + meta_paths: + - "./data/HDTF_meta.json" + # Margin of frame indexes between ref and tgt images + sample_margin: 30 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: "no" + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: False + max_train_steps: 30000 + max_grad_norm: 1.0 + # lr + learning_rate: 1.0e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: "constant" + + # optimizer + use_8bit_adam: False + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 500 + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + steps_offset: 1 + clip_sample: false + +base_model_path: "./pretrained_models/stable-diffusion-v1-5/" +vae_model_path: "./pretrained_models/sd-vae-ft-mse" +face_analysis_model_path: "./pretrained_models/face_analysis" + +weight_dtype: "fp16" # [fp16, fp32] +uncond_ratio: 0.1 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True +face_locator_pretrained: False + +seed: 42 +resume_from_checkpoint: "latest" +checkpointing_steps: 500 +exp_name: "stage1" +output_dir: "./exp_output" + +ref_image_paths: + - "examples/reference_images/1.jpg" + +mask_image_paths: + - "examples/masks/1.png" + diff --git a/configs/train/stage2.yaml b/configs/train/stage2.yaml new file mode 100644 index 00000000..baa30729 --- /dev/null +++ b/configs/train/stage2.yaml @@ -0,0 +1,119 @@ +data: + train_bs: 4 + val_bs: 1 + train_width: 512 + train_height: 512 + fps: 25 + sample_rate: 16000 + n_motion_frames: 2 + n_sample_frames: 14 + audio_margin: 2 + train_meta_paths: + - "./data/hdtf_split_stage2.json" + +wav2vec_config: + audio_type: "vocals" # audio vocals + model_scale: "base" # base large + features: "all" # last avg all + model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h +audio_separator: + model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx +face_expand_ratio: 1.2 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: "no" + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: True + max_train_steps: 30000 + max_grad_norm: 1.0 + # lr + learning_rate: 1e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: "constant" + + # optimizer + use_8bit_adam: True + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 1000 + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + 
beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + +unet_additional_kwargs: + use_inflated_groupnorm: true + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + use_audio_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: true + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + audio_attention_dim: 768 + stack_enable_blocks_name: + - "up" + - "down" + - "mid" + stack_enable_blocks_depth: [0,1,2,3] + +trainable_para: + - audio_modules + - motion_modules + +base_model_path: "./pretrained_models/stable-diffusion-v1-5/" +vae_model_path: "./pretrained_models/sd-vae-ft-mse" +face_analysis_model_path: "./pretrained_models/face_analysis" +mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt" + +weight_dtype: "fp16" # [fp16, fp32] +uncond_img_ratio: 0.05 +uncond_audio_ratio: 0.05 +uncond_ia_ratio: 0.05 +start_ratio: 0.05 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True +stage1_ckpt_dir: "./exp_output/stage1/" + +single_inference_times: 10 +inference_steps: 40 +cfg_scale: 3.5 + +seed: 42 +resume_from_checkpoint: "latest" +checkpointing_steps: 500 +exp_name: "stage2" +output_dir: "./exp_output" + +ref_img_path: + - "examples/reference_images/1.jpg" + +audio_path: + - "examples/driving_audios/1.wav" + + diff --git a/examples/masks/1.png b/examples/masks/1.png new file mode 100644 index 00000000..c63e0757 Binary files /dev/null and b/examples/masks/1.png differ diff --git a/hallo/datasets/audio_processor.py b/hallo/datasets/audio_processor.py index 50738970..f340a52f 100644 --- a/hallo/datasets/audio_processor.py +++ b/hallo/datasets/audio_processor.py @@ -73,7 +73,7 @@ def __init__( self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) - def preprocess(self, wav_file: str, clip_length: int): + def preprocess(self, wav_file: str, clip_length: int=-1): """ Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. The separated vocal track is then converted into wav2vec2 for further processing or analysis. @@ -109,7 +109,8 @@ def preprocess(self, wav_file: str, clip_length: int): audio_length = seq_len audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) - if seq_len % clip_length != 0: + + if clip_length>0 and seq_len % clip_length != 0: audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0) seq_len += clip_length - seq_len % clip_length audio_feature = audio_feature.unsqueeze(0) diff --git a/hallo/datasets/image_processor.py b/hallo/datasets/image_processor.py index c1456fe0..1eaa0230 100644 --- a/hallo/datasets/image_processor.py +++ b/hallo/datasets/image_processor.py @@ -1,3 +1,4 @@ +# pylint: disable=W0718 """ This module is responsible for processing images, particularly for face-related tasks. 
It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like @@ -8,13 +9,15 @@ from typing import List import cv2 +import mediapipe as mp import numpy as np import torch from insightface.app import FaceAnalysis from PIL import Image from torchvision import transforms -from ..utils.util import get_mask +from ..utils.util import (blur_mask, get_landmark_overframes, get_mask, + get_union_face_mask, get_union_lip_mask) MEAN = 0.5 STD = 0.5 @@ -207,3 +210,137 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_val, _exc_tb): self.close() + + +class ImageProcessorForDataProcessing(): + """ + ImageProcessor is a class responsible for processing images, particularly for face-related tasks. + It takes in an image and performs various operations such as augmentation, face detection, + face embedding extraction, and rendering a face mask. The processed images are then used for + further analysis or recognition purposes. + + Attributes: + img_size (int): The size of the image to be processed. + face_analysis_model_path (str): The path to the face analysis model. + + Methods: + preprocess(source_image_path, cache_dir): + Preprocesses the input image by performing augmentation, face detection, + face embedding extraction, and rendering a face mask. + + close(): + Closes the ImageProcessor and releases any resources being used. + + _augmentation(images, transform, state=None): + Applies image augmentation to the input images using the given transform and state. + + __enter__(): + Enters a runtime context and returns the ImageProcessor object. + + __exit__(_exc_type, _exc_val, _exc_tb): + Exits a runtime context and handles any exceptions that occurred during the processing. + """ + def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None: + if step == 2: + self.face_analysis = FaceAnalysis( + name="", + root=face_analysis_model_path, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + self.face_analysis.prepare(ctx_id=0, det_size=(640, 640)) + self.landmarker = None + else: + BaseOptions = mp.tasks.BaseOptions + FaceLandmarker = mp.tasks.vision.FaceLandmarker + FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions + VisionRunningMode = mp.tasks.vision.RunningMode + # Create a face landmarker instance with the video mode: + options = FaceLandmarkerOptions( + base_options=BaseOptions(model_asset_path=landmark_model_path), + running_mode=VisionRunningMode.IMAGE, + ) + self.landmarker = FaceLandmarker.create_from_options(options) + self.face_analysis = None + + def preprocess(self, source_image_path: str): + """ + Apply preprocessing to the source image to prepare for face analysis. + + Parameters: + source_image_path (str): The path to the source image. + cache_dir (str): The directory to cache intermediate results. + + Returns: + None + """ + # 1. 
get face embdeding + face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None + if self.face_analysis: + for frame in sorted(os.listdir(source_image_path)): + try: + source_image = Image.open( + os.path.join(source_image_path, frame)) + ref_image_pil = source_image.convert("RGB") + # 2.1 detect face + faces = self.face_analysis.get(cv2.cvtColor( + np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR)) + # use max size face + face = sorted(faces, key=lambda x: ( + x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1] + # 2.2 face embedding + face_emb = face["embedding"] + if face_emb is not None: + break + except Exception as _: + continue + + if self.landmarker: + # 3.1 get landmark + landmarks, height, width = get_landmark_overframes( + self.landmarker, source_image_path) + assert len(landmarks) == len(os.listdir(source_image_path)) + + # 3 render face and lip mask + face_mask = get_union_face_mask(landmarks, height, width) + lip_mask = get_union_lip_mask(landmarks, height, width) + + # 4 gaussian blur + blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51)) + blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31)) + + # 5 seperate mask + sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask) + sep_pose_mask = 255.0 - blur_face_mask + sep_lip_mask = blur_lip_mask + + return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask + + def close(self): + """ + Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance. + + Args: + self: The ImageProcessor instance. + + Returns: + None. + """ + for _, model in self.face_analysis.models.items(): + if hasattr(model, "Dispose"): + model.Dispose() + + def _augmentation(self, images, transform, state=None): + if state is not None: + torch.set_rng_state(state) + if isinstance(images, List): + transformed_images = [transform(img) for img in images] + ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) + else: + ret_tensor = transform(images) # (c, h, w) + return ret_tensor + + def __enter__(self): + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.close() diff --git a/hallo/datasets/talk_video.py b/hallo/datasets/talk_video.py index 4f9114ba..25c3ab81 100644 --- a/hallo/datasets/talk_video.py +++ b/hallo/datasets/talk_video.py @@ -145,25 +145,29 @@ def __init__( ) self.attn_transform_64 = transforms.Compose( [ - transforms.Resize((64,64)), + transforms.Resize( + (self.img_size[0] // 8, self.img_size[0] // 8)), transforms.ToTensor(), ] ) self.attn_transform_32 = transforms.Compose( [ - transforms.Resize((32, 32)), + transforms.Resize( + (self.img_size[0] // 16, self.img_size[0] // 16)), transforms.ToTensor(), ] ) self.attn_transform_16 = transforms.Compose( [ - transforms.Resize((16, 16)), + transforms.Resize( + (self.img_size[0] // 32, self.img_size[0] // 32)), transforms.ToTensor(), ] ) self.attn_transform_8 = transforms.Compose( [ - transforms.Resize((8, 8)), + transforms.Resize( + (self.img_size[0] // 64, self.img_size[0] // 64)), transforms.ToTensor(), ] ) diff --git a/hallo/models/motion_module.py b/hallo/models/motion_module.py index 07f98454..f62877d4 100644 --- a/hallo/models/motion_module.py +++ b/hallo/models/motion_module.py @@ -507,6 +507,7 @@ def extra_repr(self): def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, + attention_op = None, ): """ Sets the use of memory-efficient attention xformers for the VersatileAttention class. 
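Editor's note: a short usage sketch of the `ImageProcessorForDataProcessing` class added above may help reviewers. It only uses the constructor and `preprocess()` shown in this diff; the model paths are the defaults that `scripts/data_preprocess.py` (later in this diff) passes in, and the frames directory is a hypothetical example.

```python
# Sketch only: how the two data-preprocessing passes drive ImageProcessorForDataProcessing.
from hallo.datasets.image_processor import ImageProcessorForDataProcessing

FACE_ANALYSIS_DIR = "pretrained_models/face_analysis"
LANDMARK_MODEL = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
FRAMES_DIR = "dataset_name/images/0001"  # example: frames extracted by step 1

# step != 2 -> only the MediaPipe landmarker is created, so preprocess() returns the
# face/pose/lip masks and leaves the face embedding as None.
step1_processor = ImageProcessorForDataProcessing(FACE_ANALYSIS_DIR, LANDMARK_MODEL, step=1)
face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = step1_processor.preprocess(FRAMES_DIR)

# step == 2 -> only InsightFace is created, so preprocess() returns the face embedding
# and leaves every mask as None.
step2_processor = ImageProcessorForDataProcessing(FACE_ANALYSIS_DIR, LANDMARK_MODEL, step=2)
_, face_emb, _, _, _ = step2_processor.preprocess(FRAMES_DIR)
```

This mirrors how `process_single_video` in `scripts/data_preprocess.py` calls the class for steps 1 and 2.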
diff --git a/hallo/utils/util.py b/hallo/utils/util.py index 6ceca790..e29af026 100644 --- a/hallo/utils/util.py +++ b/hallo/utils/util.py @@ -1,6 +1,7 @@ # pylint: disable=C0116 # pylint: disable=W0718 # pylint: disable=R1732 +# pylint: disable=R0801 """ utils.py @@ -67,6 +68,7 @@ import subprocess import sys from pathlib import Path +from typing import List import av import cv2 @@ -377,7 +379,32 @@ def get_landmark(file, model_path): return np.array(face_landmark), height, width -def get_lip_mask(landmarks, height, width, out_path): +def get_landmark_overframes(landmark_model, frames_path): + """ + This function iterate frames and returns the facial landmarks detected in each frame. + + Args: + landmark_model: mediapipe landmark model instance + frames_path (str): The path to the video frames. + + Returns: + List[List[float], float, float]: A List containing two lists of floats representing the x and y coordinates of the facial landmarks. + """ + + face_landmarks = [] + + for file in sorted(os.listdir(frames_path)): + image = mp.Image.create_from_file(os.path.join(frames_path, file)) + height, width = image.height, image.width + landmarker_result = landmark_model.detect(image) + frame_landmark = compute_face_landmarks( + landmarker_result, height, width) + face_landmarks.append(frame_landmark) + + return face_landmarks, height, width + + +def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0): """ Extracts the lip region from the given landmarks and saves it as an image. @@ -386,19 +413,42 @@ def get_lip_mask(landmarks, height, width, out_path): height (int): Height of the output lip mask image. width (int): Width of the output lip mask image. out_path (pathlib.Path): Path to save the lip mask image. + expand_ratio (float): Expand ratio of mask. """ lip_landmarks = np.take(landmarks, lip_ids, 0) min_xy_lip = np.round(np.min(lip_landmarks, 0)) max_xy_lip = np.round(np.max(lip_landmarks, 0)) min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region( - [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, 2.0) + [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, expand_ratio) lip_mask = np.zeros((height, width), dtype=np.uint8) lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]), round(min_xy_lip[0]):round(max_xy_lip[0])] = 255 - cv2.imwrite(str(out_path), lip_mask) + if out_path: + cv2.imwrite(str(out_path), lip_mask) + return None + + return lip_mask + + +def get_union_lip_mask(landmarks, height, width, expand_ratio=1): + """ + Extracts the lip region from the given landmarks and saves it as an image. + + Parameters: + landmarks (numpy.ndarray): Array of facial landmarks. + height (int): Height of the output lip mask image. + width (int): Width of the output lip mask image. + expand_ratio (float): Expand ratio of mask. + """ + lip_masks = [] + for landmark in landmarks: + lip_masks.append(get_lip_mask(landmarks=landmark, height=height, + width=width, expand_ratio=expand_ratio)) + union_mask = get_union_mask(lip_masks) + return union_mask -def get_face_mask(landmarks, height, width, out_path, expand_ratio): +def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2): """ Generate a face mask based on the given landmarks. @@ -407,7 +457,7 @@ def get_face_mask(landmarks, height, width, out_path, expand_ratio): height (int): The height of the output face mask image. width (int): The width of the output face mask image. out_path (pathlib.Path): The path to save the face mask image. 
- + expand_ratio (float): Expand ratio of mask. Returns: None. The face mask image is saved at the specified path. """ @@ -419,8 +469,30 @@ def get_face_mask(landmarks, height, width, out_path, expand_ratio): face_mask = np.zeros((height, width), dtype=np.uint8) face_mask[round(min_xy_face[1]):round(max_xy_face[1]), round(min_xy_face[0]):round(max_xy_face[0])] = 255 - cv2.imwrite(str(out_path), face_mask) + if out_path: + cv2.imwrite(str(out_path), face_mask) + return None + return face_mask + + +def get_union_face_mask(landmarks, height, width, expand_ratio=1): + """ + Generate a face mask based on the given landmarks. + + Args: + landmarks (numpy.ndarray): The landmarks of the face. + height (int): The height of the output face mask image. + width (int): The width of the output face mask image. + expand_ratio (float): Expand ratio of mask. + Returns: + None. The face mask image is saved at the specified path. + """ + face_masks = [] + for landmark in landmarks: + face_masks.append(get_face_mask(landmarks=landmark,height=height,width=width,expand_ratio=expand_ratio)) + union_mask = get_union_mask(face_masks) + return union_mask def get_mask(file, cache_dir, face_expand_raio, landmark_model_path): """ @@ -506,6 +578,25 @@ def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size= mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE) # Check if the image is loaded successfully + if mask is not None: + normalized_mask = blur_mask(mask,resize_dim=resize_dim,kernel_size=kernel_size) + # Save the normalized mask image + cv2.imwrite(output_file_path, normalized_mask) + return f"Processed, normalized, and saved: {output_file_path}" + return f"Failed to load image: {file_path}" + + +def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)): + """ + Read, resize, blur, normalize, and save an image. + + Parameters: + file_path (str): Path to the input image file. + resize_dim (tuple): Dimensions to resize the images to. + kernel_size (tuple): Size of the kernel to use for Gaussian blur. + """ + # Check if the image is loaded successfully + normalized_mask = None if mask is not None: # Resize the mask image resized_mask = cv2.resize(mask, resize_dim) @@ -515,10 +606,7 @@ def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size= normalized_mask = cv2.normalize( blurred_mask, None, 0, 255, cv2.NORM_MINMAX) # Save the normalized mask image - cv2.imwrite(output_file_path, normalized_mask) - return f"Processed, normalized, and saved: {output_file_path}" - return f"Failed to load image: {file_path}" - + return normalized_mask def get_background_mask(file_path, output_file_path): """ @@ -614,3 +702,279 @@ def get_face_region(image_path: str, detector): except Exception as e: print(f"Error processing image {image_path}: {e}") return None, None + + +def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None: + """ + Save the model's state_dict to a checkpoint file. + + If `total_limit` is provided, this function will remove the oldest checkpoints + until the total number of checkpoints is less than the specified limit. + + Args: + model (nn.Module): The model whose state_dict is to be saved. + save_dir (str): The directory where the checkpoint will be saved. + prefix (str): The prefix for the checkpoint file name. + ckpt_num (int): The checkpoint number to be saved. + total_limit (int, optional): The maximum number of checkpoints to keep. + Defaults to None, in which case no checkpoints will be removed. 
+ + Raises: + FileNotFoundError: If the save directory does not exist. + ValueError: If the checkpoint number is negative. + OSError: If there is an error saving the checkpoint. + """ + + if not osp.exists(save_dir): + raise FileNotFoundError( + f"The save directory {save_dir} does not exist.") + + if ckpt_num < 0: + raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.") + + save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth") + + if total_limit > 0: + checkpoints = os.listdir(save_dir) + checkpoints = [d for d in checkpoints if d.startswith(prefix)] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + + if len(checkpoints) >= total_limit: + num_to_remove = len(checkpoints) - total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + print( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + print( + f"Removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint_path = osp.join( + save_dir, removing_checkpoint) + try: + os.remove(removing_checkpoint_path) + except OSError as e: + print( + f"Error removing checkpoint {removing_checkpoint_path}: {e}") + + state_dict = model.state_dict() + try: + torch.save(state_dict, save_path) + print(f"Checkpoint saved at {save_path}") + except OSError as e: + raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e + + +def init_output_dir(dir_list: List[str]): + """ + Initialize the output directories. + + This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing. + + Args: + dir_list (List[str]): List of directory paths to create. + """ + for path in dir_list: + os.makedirs(path, exist_ok=True) + + +def load_checkpoint(cfg, save_dir, accelerator): + """ + Load the most recent checkpoint from the specified directory. + + This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest". + If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found, + it starts training from scratch. + + Args: + cfg: The configuration object containing training parameters. + save_dir (str): The directory where checkpoints are saved. + accelerator: The accelerator object for distributed training. + + Returns: + int: The global step at which to resume training. 
+ """ + if cfg.resume_from_checkpoint != "latest": + resume_dir = cfg.resume_from_checkpoint + else: + resume_dir = save_dir + # Get the most recent checkpoint + dirs = os.listdir(resume_dir) + + dirs = [d for d in dirs if d.startswith("checkpoint")] + if len(dirs) > 0: + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.load_state(os.path.join(resume_dir, path)) + accelerator.print(f"Resuming from checkpoint {path}") + global_step = int(path.split("-")[1]) + else: + accelerator.print( + f"Could not find checkpoint under {resume_dir}, start training from scratch") + global_step = 0 + + return global_step + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/ + 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/ + # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path: + """ + Extract audio from a video file and save it as a WAV file. + + This function uses ffmpeg to extract the audio stream from a given video file and saves it as a WAV file + in the specified output directory. + + Args: + video_path (Path): The path to the input video file. + output_dir (Path): The directory where the extracted audio file will be saved. + + Returns: + Path: The path to the extracted audio file. + + Raises: + subprocess.CalledProcessError: If the ffmpeg command fails to execute. + """ + ffmpeg_command = [ + 'ffmpeg', '-y', + '-i', str(video_path), + '-vn', '-acodec', + "pcm_s16le", '-ar', '16000', '-ac', '2', + str(audio_output_path) + ] + + try: + print(f"Running command: {' '.join(ffmpeg_command)}") + subprocess.run(ffmpeg_command, check=True) + except subprocess.CalledProcessError as e: + print(f"Error extracting audio from video: {e}") + raise + + return audio_output_path + + +def convert_video_to_images(video_path: Path, output_dir: Path) -> Path: + """ + Convert a video file into a sequence of images. + + This function uses ffmpeg to convert each frame of the given video file into an image. The images are saved + in a directory named after the video file stem under the specified output directory. + + Args: + video_path (Path): The path to the input video file. + output_dir (Path): The directory where the extracted images will be saved. + + Returns: + Path: The path to the directory containing the extracted images. + + Raises: + subprocess.CalledProcessError: If the ffmpeg command fails to execute. 
+ """ + ffmpeg_command = [ + 'ffmpeg', + '-i', str(video_path), + '-vf', 'fps=25', + str(output_dir / '%04d.png') + ] + + try: + print(f"Running command: {' '.join(ffmpeg_command)}") + subprocess.run(ffmpeg_command, check=True) + except subprocess.CalledProcessError as e: + print(f"Error converting video to images: {e}") + raise + + return output_dir + + +def get_union_mask(masks): + """ + Compute the union of a list of masks. + + This function takes a list of masks and computes their union by taking the maximum value at each pixel location. + Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white. + + Args: + masks (list of np.ndarray): List of masks to be combined. + + Returns: + np.ndarray: The union of the input masks. + """ + union_mask = None + for mask in masks: + if union_mask is None: + union_mask = mask + else: + union_mask = np.maximum(union_mask, mask) + + if union_mask is not None: + # Find the bounding box of the non-zero regions in the mask + rows = np.any(union_mask, axis=1) + cols = np.any(union_mask, axis=0) + try: + ymin, ymax = np.where(rows)[0][[0, -1]] + xmin, xmax = np.where(cols)[0][[0, -1]] + except Exception as e: + print(str(e)) + return 0.0 + + # Set bounding box area to white + union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask) + + return union_mask + + +def move_final_checkpoint(save_dir, module_dir, prefix): + """ + Move the final checkpoint file to the save directory. + + This function identifies the latest checkpoint file based on the given prefix and moves it to the specified save directory. + + Args: + save_dir (str): The directory where the final checkpoint file should be saved. + module_dir (str): The directory containing the checkpoint files. + prefix (str): The prefix used to identify checkpoint files. + + Raises: + ValueError: If no checkpoint files are found with the specified prefix. + """ + checkpoints = os.listdir(module_dir) + checkpoints = [d for d in checkpoints if d.startswith(prefix)] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + shutil.copy2(os.path.join( + module_dir, checkpoints[-1]), os.path.join(save_dir, prefix + '.pth')) diff --git a/nodes.py b/nodes.py index 03da7973..46b44662 100644 --- a/nodes.py +++ b/nodes.py @@ -149,6 +149,8 @@ def inference(self, source_image, driving_audio, pose_weight, face_weight, lip_w # get src audio src_audio_path = os.path.join(folder_paths.get_input_directory(), driving_audio) + if not os.path.exists(src_audio_path): + src_audio_path = driving_audio # absolute path env = ':'.join([os.environ.get('PYTHONPATH', ''), cur_dir]) cmd = f"""PYTHONPATH={env} python {infer_py} --config "{tmp_yaml_path}" --source_image "{src_img_path}" --driving_audio "{src_audio_path}" --output {output_video_path} --pose_weight {pose_weight} --face_weight {face_weight} --lip_weight {lip_weight} --face_expand_ratio {face_expand_ratio}""" diff --git a/scripts/data_preprocess.py b/scripts/data_preprocess.py new file mode 100644 index 00000000..92efc2fc --- /dev/null +++ b/scripts/data_preprocess.py @@ -0,0 +1,191 @@ +# pylint: disable=W1203,W0718 +""" +This module is used to process videos to prepare data for training. It utilizes various libraries and models +to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction. 
+The script takes in command-line arguments to specify the input and output directories, GPU status, level of parallelism, +and rank for distributed processing. + +Usage: + python -m scripts.data_preprocess --input_dir /path/to/video_dir --dataset_name dataset_name --gpu_status --parallelism 4 --rank 0 + +Example: + python -m scripts.data_preprocess -i data/videos -o data/output -g -p 4 -r 0 +""" +import argparse +import logging +import os +from pathlib import Path +from typing import List + +import cv2 +import torch +from tqdm import tqdm + +from hallo.datasets.audio_processor import AudioProcessor +from hallo.datasets.image_processor import ImageProcessorForDataProcessing +from hallo.utils.util import convert_video_to_images, extract_audio_from_videos + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') + + +def setup_directories(video_path: Path) -> dict: + """ + Setup directories for storing processed files. + + Args: + video_path (Path): Path to the video file. + + Returns: + dict: A dictionary containing paths for various directories. + """ + base_dir = video_path.parent.parent + dirs = { + "face_mask": base_dir / "face_mask", + "sep_pose_mask": base_dir / "sep_pose_mask", + "sep_face_mask": base_dir / "sep_face_mask", + "sep_lip_mask": base_dir / "sep_lip_mask", + "face_emb": base_dir / "face_emb", + "audio_emb": base_dir / "audio_emb" + } + + for path in dirs.values(): + path.mkdir(parents=True, exist_ok=True) + + return dirs + + +def process_single_video(video_path: Path, + output_dir: Path, + image_processor: ImageProcessorForDataProcessing, + audio_processor: AudioProcessor, + step: int) -> None: + """ + Process a single video file. + + Args: + video_path (Path): Path to the video file. + output_dir (Path): Directory to save the output. + image_processor (ImageProcessorForDataProcessing): Image processor object. + audio_processor (AudioProcessor): Audio processor object. + gpu_status (bool): Whether to use GPU for processing. 
+ """ + assert video_path.exists(), f"Video path {video_path} does not exist" + dirs = setup_directories(video_path) + logging.info(f"Processing video: {video_path}") + + try: + if step == 1: + images_output_dir = output_dir / 'images' / video_path.stem + images_output_dir.mkdir(parents=True, exist_ok=True) + images_output_dir = convert_video_to_images( + video_path, images_output_dir) + logging.info(f"Images saved to: {images_output_dir}") + + audio_output_dir = output_dir / 'audios' + audio_output_dir.mkdir(parents=True, exist_ok=True) + audio_output_path = audio_output_dir / f'{video_path.stem}.wav' + audio_output_path = extract_audio_from_videos( + video_path, audio_output_path) + logging.info(f"Audio extracted to: {audio_output_path}") + + face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess( + images_output_dir) + cv2.imwrite( + str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask) + cv2.imwrite(str(dirs["sep_pose_mask"] / + f"{video_path.stem}.png"), sep_pose_mask) + cv2.imwrite(str(dirs["sep_face_mask"] / + f"{video_path.stem}.png"), sep_face_mask) + cv2.imwrite(str(dirs["sep_lip_mask"] / + f"{video_path.stem}.png"), sep_lip_mask) + else: + images_dir = output_dir / "images" / video_path.stem + audio_path = output_dir / "audios" / f"{video_path.stem}.wav" + _, face_emb, _, _, _ = image_processor.preprocess(images_dir) + torch.save(face_emb, str( + dirs["face_emb"] / f"{video_path.stem}.pt")) + audio_emb, _ = audio_processor.preprocess(audio_path) + torch.save(audio_emb, str( + dirs["audio_emb"] / f"{video_path.stem}.pt")) + except Exception as e: + logging.error(f"Failed to process video {video_path}: {e}") + + +def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None: + """ + Process all videos in the input list. + + Args: + input_video_list (List[Path]): List of video paths to process. + output_dir (Path): Directory to save the output. + gpu_status (bool): Whether to use GPU for processing. + """ + face_analysis_model_path = "pretrained_models/face_analysis" + landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task" + audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx" + wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h' + + audio_processor = AudioProcessor( + 16000, + 25, + wav2vec_model_path, + False, + os.path.dirname(audio_separator_model_file), + os.path.basename(audio_separator_model_file), + os.path.join(output_dir, "vocals"), + ) if step==2 else None + + image_processor = ImageProcessorForDataProcessing( + face_analysis_model_path, landmark_model_path, step) + + for video_path in tqdm(input_video_list, desc="Processing videos"): + process_single_video(video_path, output_dir, + image_processor, audio_processor, step) + + +def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]: + """ + Get paths of videos to process, partitioned for parallel processing. + + Args: + source_dir (Path): Source directory containing videos. + parallelism (int): Level of parallelism. + rank (int): Rank for distributed processing. + + Returns: + List[Path]: List of video paths to process. 
+ """ + video_paths = [item for item in sorted( + source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4'] + return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Process videos to prepare data for training. Run this script twice with different GPU status parameters." + ) + parser.add_argument("-i", "--input_dir", type=Path, + required=True, help="Directory containing videos") + parser.add_argument("-o", "--output_dir", type=Path, + help="Directory to save results, default is parent dir of input dir") + parser.add_argument("-s", "--step", type=int, default=1, + help="Specify data processing step 1 or 2, you should run 1 and 2 sequently") + parser.add_argument("-p", "--parallelism", default=1, + type=int, help="Level of parallelism") + parser.add_argument("-r", "--rank", default=0, type=int, + help="Rank for distributed processing") + + args = parser.parse_args() + + if args.output_dir is None: + args.output_dir = args.input_dir.parent + + video_path_list = get_video_paths( + args.input_dir, args.parallelism, args.rank) + + if not video_path_list: + logging.warning("No videos to process.") + else: + process_all_videos(video_path_list, args.output_dir, args.step) diff --git a/scripts/extract_meta_info_stage1.py b/scripts/extract_meta_info_stage1.py new file mode 100644 index 00000000..936cb06c --- /dev/null +++ b/scripts/extract_meta_info_stage1.py @@ -0,0 +1,106 @@ +# pylint: disable=R0801 +""" +This module is used to extract meta information from video directories. + +It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path` +specifies the path to the video directory, while the `dataset_name` specifies the name +of the dataset. The module then collects all the video folder paths, and for each video +folder, it checks if a mask path and a face embedding path exist. If they do, it appends +a dictionary containing the image path, mask path, and face embedding path to a list. + +Finally, the module writes the list of dictionaries to a JSON file with the filename +constructed using the `dataset_name`. + +Usage: + python tools/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf + +""" + +import argparse +import json +import os +from pathlib import Path + +import torch + + +def collect_video_folder_paths(root_path: Path) -> list: + """ + Collect all video folder paths from the root path. + + Args: + root_path (Path): The root directory containing video folders. + + Returns: + list: List of video folder paths. + """ + return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()] + + +def construct_meta_info(frames_dir_path: Path) -> dict: + """ + Construct meta information for a given frames directory. + + Args: + frames_dir_path (Path): The path to the frames directory. + + Returns: + dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist. 
+ """ + mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png" + face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt" + + if not os.path.exists(mask_path): + print(f"Mask path not found: {mask_path}") + return None + + if torch.load(face_emb_path) is None: + print(f"Face emb is None: {face_emb_path}") + return None + + return { + "image_path": str(frames_dir_path), + "mask_path": mask_path, + "face_emb": face_emb_path, + } + + +def main(): + """ + Main function to extract meta info for training. + """ + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--root_path", type=str, + required=True, help="Root path of the video directories") + parser.add_argument("-n", "--dataset_name", type=str, + required=True, help="Name of the dataset") + parser.add_argument("--meta_info_name", type=str, + help="Name of the meta information file") + + args = parser.parse_args() + + if args.meta_info_name is None: + args.meta_info_name = args.dataset_name + + image_dir = Path(args.root_path) / "images" + output_dir = Path("./data") + output_dir.mkdir(exist_ok=True) + + # Collect all video folder paths + frames_dir_paths = collect_video_folder_paths(image_dir) + + meta_infos = [] + for frames_dir_path in frames_dir_paths: + meta_info = construct_meta_info(frames_dir_path) + if meta_info: + meta_infos.append(meta_info) + + output_file = output_dir / f"{args.meta_info_name}_stage1.json" + with output_file.open("w", encoding="utf-8") as f: + json.dump(meta_infos, f, indent=4) + + print(f"Final data count: {len(meta_infos)}") + + +if __name__ == "__main__": + main() diff --git a/scripts/extract_meta_info_stage2.py b/scripts/extract_meta_info_stage2.py new file mode 100644 index 00000000..e2d9301c --- /dev/null +++ b/scripts/extract_meta_info_stage2.py @@ -0,0 +1,192 @@ +# pylint: disable=R0801 +""" +This module is used to extract meta information from video files and store them in a JSON file. + +The script takes in command line arguments to specify the root path of the video files, +the dataset name, and the name of the meta information file. It then generates a list of +dictionaries containing the meta information for each video file and writes it to a JSON +file with the specified name. + +The meta information includes the path to the video file, the mask path, the face mask +path, the face mask union path, the face mask gaussian path, the lip mask path, the lip +mask union path, the lip mask gaussian path, the separate mask border, the separate mask +face, the separate mask lip, the face embedding path, the audio path, the vocals embedding +base last path, the vocals embedding base all path, the vocals embedding base average +path, the vocals embedding large last path, the vocals embedding large all path, and the +vocals embedding large average path. + +The script checks if the mask path exists before adding the information to the list. + +Usage: + python tools/extract_meta_info_stage2.py --root_path --dataset_name --meta_info_name + +Example: + python tools/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info +""" + +import argparse +import json +import os +from pathlib import Path + +import torch +from decord import VideoReader, cpu +from tqdm import tqdm + + +def get_video_paths(root_path: Path, extensions: list) -> list: + """ + Get a list of video paths from the root path with the specified extensions. + + Args: + root_path (Path): The root directory containing video files. 
+ extensions (list): List of file extensions to include. + + Returns: + list: List of video file paths. + """ + return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions] + + +def file_exists(file_path: str) -> bool: + """ + Check if a file exists. + + Args: + file_path (str): The path to the file. + + Returns: + bool: True if the file exists, False otherwise. + """ + return os.path.exists(file_path) + + +def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str: + """ + Construct a new path by replacing the base directory and extension in the original path. + + Args: + video_path (str): The original video path. + base_dir (str): The base directory to be replaced. + new_dir (str): The new directory to replace the base directory. + new_ext (str): The new file extension. + + Returns: + str: The constructed path. + """ + return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext) + + +def extract_meta_info(video_path: str) -> dict: + """ + Extract meta information for a given video file. + + Args: + video_path (str): The path to the video file. + + Returns: + dict: A dictionary containing the meta information for the video. + """ + mask_path = construct_paths( + video_path, "videos", "face_mask", ".png") + sep_mask_border = construct_paths( + video_path, "videos", "sep_pose_mask", ".png") + sep_mask_face = construct_paths( + video_path, "videos", "sep_face_mask", ".png") + sep_mask_lip = construct_paths( + video_path, "videos", "sep_lip_mask", ".png") + face_emb_path = construct_paths( + video_path, "videos", "face_emb", ".pt") + audio_path = construct_paths(video_path, "videos", "audios", ".wav") + vocal_emb_base_all = construct_paths( + video_path, "videos", "audio_emb", ".pt") + + assert_flag = True + + if not file_exists(mask_path): + print(f"Mask path not found: {mask_path}") + assert_flag = False + if not file_exists(sep_mask_border): + print(f"Separate mask border not found: {sep_mask_border}") + assert_flag = False + if not file_exists(sep_mask_face): + print(f"Separate mask face not found: {sep_mask_face}") + assert_flag = False + if not file_exists(sep_mask_lip): + print(f"Separate mask lip not found: {sep_mask_lip}") + assert_flag = False + if not file_exists(face_emb_path): + print(f"Face embedding path not found: {face_emb_path}") + assert_flag = False + if not file_exists(audio_path): + print(f"Audio path not found: {audio_path}") + assert_flag = False + if not file_exists(vocal_emb_base_all): + print(f"Vocal embedding base all not found: {vocal_emb_base_all}") + assert_flag = False + + video_frames = VideoReader(video_path, ctx=cpu(0)) + audio_emb = torch.load(vocal_emb_base_all) + if abs(len(video_frames) - audio_emb.shape[0]) > 3: + print(f"Frame count mismatch for video: {video_path}") + assert_flag = False + + face_emb = torch.load(face_emb_path) + if face_emb is None: + print(f"Face embedding is None for video: {video_path}") + assert_flag = False + + del video_frames, audio_emb + + if assert_flag: + return { + "video_path": str(video_path), + "mask_path": mask_path, + "sep_mask_border": sep_mask_border, + "sep_mask_face": sep_mask_face, + "sep_mask_lip": sep_mask_lip, + "face_emb_path": face_emb_path, + "audio_path": audio_path, + "vocals_emb_base_all": vocal_emb_base_all, + } + return None + + +def main(): + """ + Main function to extract meta info for training. 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--root_path", type=str, + required=True, help="Root path of the video files") + parser.add_argument("-n", "--dataset_name", type=str, + required=True, help="Name of the dataset") + parser.add_argument("--meta_info_name", type=str, + help="Name of the meta information file") + + args = parser.parse_args() + + if args.meta_info_name is None: + args.meta_info_name = args.dataset_name + + video_dir = Path(args.root_path) / "videos" + video_paths = get_video_paths(video_dir, [".mp4"]) + + meta_infos = [] + + for video_path in tqdm(video_paths, desc="Extracting meta info"): + meta_info = extract_meta_info(video_path) + if meta_info: + meta_infos.append(meta_info) + + print(f"Final data count: {len(meta_infos)}") + + output_file = Path(f"./data/{args.meta_info_name}_stage2.json") + output_file.parent.mkdir(parents=True, exist_ok=True) + + with output_file.open("w", encoding="utf-8") as f: + json.dump(meta_infos, f, indent=4) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py new file mode 100644 index 00000000..e9e7e847 --- /dev/null +++ b/scripts/train_stage1.py @@ -0,0 +1,793 @@ +# pylint: disable=E1101,C0415,W0718,R0801 +# scripts/train_stage1.py +""" +This is the main training script for stage 1 of the project. +It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration. + +The script includes the following classes and functions: + +1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings, + and face masks as input and returns the denoised latents. +3. log_validation: A function that logs the validation information using the given VAE, image encoder, + network, scheduler, accelerator, width, height, and configuration. +4. train_stage1_process: A function that processes the training stage 1 using the given configuration. + +The script also includes the necessary imports and a brief description of the purpose of the file. 
+""" + +import argparse +import copy +import logging +import math +import os +import random +import warnings +from datetime import datetime + +import cv2 +import diffusers +import mlflow +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from insightface.app import FaceAnalysis +from omegaconf import OmegaConf +from PIL import Image +from torch import nn +from tqdm.auto import tqdm + +from hallo.animate.face_animate_static import StaticPipeline +from hallo.datasets.mask_image import FaceMaskDataset +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.mutual_self_attention import ReferenceAttentionControl +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.util import (compute_snr, delete_additional_ckpt, + import_filename, init_output_dir, + load_checkpoint, move_final_checkpoint, + save_checkpoint, seed_everything) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class Net(nn.Module): + """ + The Net class defines a neural network model that combines a reference UNet2DConditionModel, + a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image. + + Args: + reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation. + denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation. + face_locator (FaceLocator): The face locator model used for face animation. + reference_control_writer: The reference control writer component. + reference_control_reader: The reference control reader component. + imageproj: The image projection model. + + Forward method: + noisy_latents (torch.Tensor): The noisy latents tensor. + timesteps (torch.Tensor): The timesteps tensor. + ref_image_latents (torch.Tensor): The reference image latents tensor. + face_emb (torch.Tensor): The face embeddings tensor. + face_mask (torch.Tensor): The face mask tensor. + uncond_fwd (bool): A flag indicating whether to perform unconditional forward pass. + + Returns: + torch.Tensor: The output tensor of the neural network model. + """ + + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + reference_control_writer: ReferenceAttentionControl, + reference_control_reader: ReferenceAttentionControl, + imageproj: ImageProjModel, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + self.imageproj = imageproj + + def forward( + self, + noisy_latents, + timesteps, + ref_image_latents, + face_emb, + face_mask, + uncond_fwd: bool = False, + ): + """ + Forward pass of the model. 
+ Args: + self (Net): The model instance. + noisy_latents (torch.Tensor): Noisy latents. + timesteps (torch.Tensor): Timesteps. + ref_image_latents (torch.Tensor): Reference image latents. + face_emb (torch.Tensor): Face embedding. + face_mask (torch.Tensor): Face mask. + uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False. + + Returns: + torch.Tensor: Model prediction. + """ + + face_emb = self.imageproj(face_emb) + face_mask = face_mask.to(device="cuda") + face_mask_feature = self.face_locator(face_mask) + + if not uncond_fwd: + ref_timesteps = torch.zeros_like(timesteps) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=face_emb, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + mask_cond_fea=face_mask_feature, + encoder_hidden_states=face_emb, + ).sample + + return model_pred + + +def get_noise_scheduler(cfg: argparse.Namespace): + """ + Create noise scheduler for training + + Args: + cfg (omegaconf.dictconfig.DictConfig): Configuration object. + + Returns: + train noise scheduler and val noise scheduler + """ + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + return train_noise_scheduler, val_noise_scheduler + + +def log_validation( + vae, + net, + scheduler, + accelerator, + width, + height, + imageproj, + cfg, + save_dir, + global_step, + face_analysis_model_path, +): + """ + Log validation generation image. + + Args: + vae (nn.Module): Variational Autoencoder model. + net (Net): Main model. + scheduler (diffusers.SchedulerMixin): Noise scheduler. + accelerator (accelerate.Accelerator): Accelerator for training. + width (int): Width of the input images. + height (int): Height of the input images. + imageproj (nn.Module): Image projection model. + cfg (omegaconf.dictconfig.DictConfig): Configuration object. + save_dir (str): directory path to save log result. + global_step (int): Global step number. + + Returns: + None + """ + logger.info("Running validation... 
") + + ori_net = accelerator.unwrap_model(net) + ori_net = copy.deepcopy(ori_net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + face_locator = ori_net.face_locator + + generator = torch.manual_seed(42) + image_enc = FaceAnalysis( + name="", + root=face_analysis_model_path, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + image_enc.prepare(ctx_id=0, det_size=(640, 640)) + + pipe = StaticPipeline( + vae=vae, + reference_unet=reference_unet, + denoising_unet=denoising_unet, + face_locator=face_locator, + scheduler=scheduler, + imageproj=imageproj, + ) + + pil_images = [] + for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths): + # for mask_image_path in mask_image_paths: + mask_name = os.path.splitext( + os.path.basename(mask_image_path))[0] + ref_name = os.path.splitext( + os.path.basename(ref_image_path))[0] + ref_image_pil = Image.open(ref_image_path).convert("RGB") + mask_image_pil = Image.open(mask_image_path).convert("RGB") + + # Prepare face embeds + face_info = image_enc.get( + cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * ( + x['bbox'][3] - x['bbox'][1]))[-1] # only use the maximum face + face_emb = torch.tensor(face_info['embedding']) + face_emb = face_emb.to( + imageproj.device, imageproj.dtype) + + image = pipe( + ref_image_pil, + mask_image_pil, + width, + height, + 20, + 3.5, + face_emb, + generator=generator, + ).images + image = image[0, :, 0].permute(1, 2, 0).cpu().numpy() # (3, 512, 512) + res_image_pil = Image.fromarray((image * 255).astype(np.uint8)) + # Save ref_image, src_image and the generated_image + w, h = res_image_pil.size + canvas = Image.new("RGB", (w * 3, h), "white") + ref_image_pil = ref_image_pil.resize((w, h)) + mask_image_pil = mask_image_pil.resize((w, h)) + canvas.paste(ref_image_pil, (0, 0)) + canvas.paste(mask_image_pil, (w, 0)) + canvas.paste(res_image_pil, (w * 2, 0)) + + out_file = os.path.join( + save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg" + ) + canvas.save(out_file) + + del pipe + del ori_net + torch.cuda.empty_cache() + + return pil_images + + +def train_stage1_process(cfg: argparse.Namespace) -> None: + """ + Trains the model using the given configuration (cfg). + + Args: + cfg (dict): The configuration dictionary containing the parameters for training. + + Notes: + - This function trains the model using the given configuration. + - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler. + - The training progress is logged and tracked using the accelerator. + - The trained model is saved after the training is completed. + """ + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="mlflow", + project_dir="./mlruns", + kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + # create output dir for training + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + checkpoint_dir = os.path.join(save_dir, "checkpoints") + module_dir = os.path.join(save_dir, "modules") + validation_dir = os.path.join(save_dir, "validation") + + if accelerator.is_main_process: + init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir]) + + accelerator.wait_for_everyone() + + # create model + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + # create model + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + "", + subfolder="unet", + unet_additional_kwargs={ + "use_motion_module": False, + "unet_use_temporal_attention": False, + }, + use_landmark=False + ).to(device="cuda", dtype=weight_dtype) + imageproj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ).to(device="cuda", dtype=weight_dtype) + + if cfg.face_locator_pretrained: + face_locator = FaceLocator( + conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256) + ).to(device="cuda", dtype=weight_dtype) + miss, _ = face_locator.load_state_dict( + cfg.face_state_dict_path, strict=False) + logger.info(f"Missing key for face locator: {len(miss)}") + else: + face_locator = FaceLocator( + conditioning_embedding_channels=320, + ).to(device="cuda", dtype=weight_dtype) + # Freeze + vae.requires_grad_(False) + denoising_unet.requires_grad_(True) + reference_unet.requires_grad_(True) + imageproj.requires_grad_(True) + face_locator.requires_grad_(True) + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + face_locator, + reference_control_writer, + reference_control_reader, + imageproj, + ).to(dtype=weight_dtype) + + # get noise scheduler + train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg) + + # init optimizer + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. 
Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError as exc: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) from exc + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list( + filter(lambda p: p.requires_grad, net.parameters())) + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # init scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + # get data loader + train_dataset = FaceMaskDataset( + img_size=(cfg.data.train_width, cfg.data.train_height), + data_meta_paths=cfg.data.meta_paths, + sample_margin=cfg.data.sample_margin, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + cfg.exp_name, + init_kwargs={"mlflow": {"run_name": run_time}}, + ) + # dump config file + mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml") + + logger.info(f"save config to {save_dir}") + OmegaConf.save( + cfg, os.path.join(save_dir, "config.yaml") + ) + # Train! + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # load checkpoint + # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + logger.info(f"Loading checkpoint from {checkpoint_dir}") + global_step = load_checkpoint(cfg, checkpoint_dir, accelerator) + first_epoch = global_step // num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_main_process, + ) + progress_bar.set_description("Steps") + net.train() + for _ in range(first_epoch, num_train_epochs): + train_loss = 0.0 + for _, batch in enumerate(train_dataloader): + with accelerator.accumulate(net): + # Convert videos to latent space + pixel_values = batch["img"].to(weight_dtype) + with torch.no_grad(): + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents.unsqueeze(2) # (b, c, 1, h, w) + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + if cfg.noise_offset > 0.0: + noise += cfg.noise_offset * torch.randn( + (noise.shape[0], noise.shape[1], 1, 1, 1), + device=noise.device, + ) + + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint( + 0, + train_noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + face_mask_img = batch["tgt_mask"] + face_mask_img = face_mask_img.unsqueeze( + 2) + face_mask_img = face_mask_img.to(weight_dtype) + + uncond_fwd = random.random() < cfg.uncond_ratio + face_emb_list = [] + ref_image_list = [] + for _, (ref_img, face_emb) in enumerate( + zip(batch["ref_img"], batch["face_emb"]) + ): + if uncond_fwd: + face_emb_list.append(torch.zeros_like(face_emb)) + else: + face_emb_list.append(face_emb) + ref_image_list.append(ref_img) + + with torch.no_grad(): + ref_img = torch.stack(ref_image_list, dim=0).to( + dtype=vae.dtype, device=vae.device + ) + ref_image_latents = vae.encode( + ref_img + ).latent_dist.sample() + ref_image_latents = ref_image_latents * 0.18215 + + face_emb = torch.stack(face_emb_list, dim=0).to( + dtype=imageproj.dtype, device=imageproj.device + ) + + # add noise + noisy_latents = train_noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Get the target for loss depending on the prediction type + if train_noise_scheduler.prediction_type == "epsilon": + target = noise + elif train_noise_scheduler.prediction_type == "v_prediction": + target = train_noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {train_noise_scheduler.prediction_type}" + ) + model_pred = net( + noisy_latents, + timesteps, + ref_image_latents, + face_emb, + face_mask_img, + uncond_fwd, + ) + + if cfg.snr_gamma == 0: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) + else: + snr = compute_snr(train_noise_scheduler, timesteps) + if train_noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. 
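+                            # The weights computed below follow the Min-SNR strategy,
+                            # min(snr, snr_gamma) / snr, which down-weights easy, high-SNR timesteps.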
+ snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather( + loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + trainable_params, + cfg.solver.max_grad_norm, + ) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + reference_control_reader.clear() + reference_control_writer.clear() + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps: + accelerator.wait_for_everyone() + save_path = os.path.join( + checkpoint_dir, f"checkpoint-{global_step}") + if accelerator.is_main_process: + delete_additional_ckpt(checkpoint_dir, 3) + accelerator.save_state(save_path) + accelerator.wait_for_everyone() + unwrap_net = accelerator.unwrap_model(net) + if accelerator.is_main_process: + save_checkpoint( + unwrap_net.reference_unet, + module_dir, + "reference_unet", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.imageproj, + module_dir, + "imageproj", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.denoising_unet, + module_dir, + "denoising_unet", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.face_locator, + module_dir, + "face_locator", + global_step, + total_limit=3, + ) + + if global_step % cfg.val.validation_steps == 0 or global_step == 1: + if accelerator.is_main_process: + generator = torch.Generator(device=accelerator.device) + generator.manual_seed(cfg.seed) + log_validation( + vae=vae, + net=net, + scheduler=val_noise_scheduler, + accelerator=accelerator, + width=cfg.data.train_width, + height=cfg.data.train_height, + imageproj=imageproj, + cfg=cfg, + save_dir=validation_dir, + global_step=global_step, + face_analysis_model_path=cfg.face_analysis_model_path + ) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + if global_step >= cfg.solver.max_train_steps: + # process final module weight for stage2 + if accelerator.is_main_process: + move_final_checkpoint(save_dir, module_dir, "reference_unet") + move_final_checkpoint(save_dir, module_dir, "imageproj") + move_final_checkpoint(save_dir, module_dir, "denoising_unet") + move_final_checkpoint(save_dir, module_dir, "face_locator") + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +def load_config(config_path: str) -> dict: + """ + Loads the configuration file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: The configuration dictionary. 
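+
+    Example (illustrative):
+        cfg = load_config("./configs/train/stage1.yaml")  # an OmegaConf DictConfig for .yaml files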
+ """ + + if config_path.endswith(".yaml"): + return OmegaConf.load(config_path) + if config_path.endswith(".py"): + return import_filename(config_path).cfg + raise ValueError("Unsupported format for config file") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, + default="./configs/train/stage1.yaml") + args = parser.parse_args() + + try: + config = load_config(args.config) + train_stage1_process(config) + except Exception as e: + logging.error("Failed to execute the training process: %s", e) diff --git a/scripts/train_stage2.py b/scripts/train_stage2.py new file mode 100644 index 00000000..8cff266d --- /dev/null +++ b/scripts/train_stage2.py @@ -0,0 +1,991 @@ +# pylint: disable=E1101,C0415,W0718,R0801 +# scripts/train_stage2.py +""" +This is the main training script for stage 2 of the project. +It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration. + +The script includes the following classes and functions: + +1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings, + and face masks as input and returns the denoised latents. +2. get_attention_mask: A function that rearranges the mask tensors to the required format. +3. get_noise_scheduler: A function that creates and returns the noise schedulers for training and validation. +4. process_audio_emb: A function that processes the audio embeddings to concatenate with other tensors. +5. log_validation: A function that logs the validation information using the given VAE, image encoder, + network, scheduler, accelerator, width, height, and configuration. +6. train_stage2_process: A function that processes the training stage 2 using the given configuration. +7. load_config: A function that loads the configuration file from the given path. + +The script also includes the necessary imports and a brief description of the purpose of the file. 
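+
+Example (illustrative): a single-process run that relies on this script's own argparse
+default for the configuration file:
+
+    python -m scripts.train_stage2 --config ./configs/train/stage2.yaml
+
+Stage 2 loads the stage-1 module weights (denoising_unet.pth, reference_unet.pth,
+face_locator.pth, imageproj.pth) from `stage1_ckpt_dir` in the configuration.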
+""" + +import argparse +import copy +import logging +import math +import os +import random +import time +import warnings +from datetime import datetime +from typing import List, Tuple + +import diffusers +import mlflow +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat +from omegaconf import OmegaConf +from torch import nn +from tqdm.auto import tqdm + +from hallo.animate.face_animate import FaceAnimatePipeline +from hallo.datasets.audio_processor import AudioProcessor +from hallo.datasets.image_processor import ImageProcessor +from hallo.datasets.talk_video import TalkingVideoDataset +from hallo.models.audio_proj import AudioProjModel +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.mutual_self_attention import ReferenceAttentionControl +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.util import (compute_snr, delete_additional_ckpt, + import_filename, init_output_dir, + load_checkpoint, save_checkpoint, + seed_everything, tensor_to_video) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class Net(nn.Module): + """ + The Net class defines a neural network model that combines a reference UNet2DConditionModel, + a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image. + + Args: + reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation. + denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation. + face_locator (FaceLocator): The face locator model used for face animation. + reference_control_writer: The reference control writer component. + reference_control_reader: The reference control reader component. + imageproj: The image projection model. + audioproj: The audio projection model. + + Forward method: + noisy_latents (torch.Tensor): The noisy latents tensor. + timesteps (torch.Tensor): The timesteps tensor. + ref_image_latents (torch.Tensor): The reference image latents tensor. + face_emb (torch.Tensor): The face embeddings tensor. + audio_emb (torch.Tensor): The audio embeddings tensor. + mask (torch.Tensor): Hard face mask for face locator. + full_mask (torch.Tensor): Pose Mask. + face_mask (torch.Tensor): Face Mask + lip_mask (torch.Tensor): Lip Mask + uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass. + uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass. + + Returns: + torch.Tensor: The output tensor of the neural network model. 
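+
+    Compared with the stage-1 Net, this variant additionally projects the audio embeddings
+    through `audioproj` and forwards the pose/face/lip attention masks to the denoising UNet.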
+ """ + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + reference_control_writer, + reference_control_reader, + imageproj, + audioproj, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + self.imageproj = imageproj + self.audioproj = audioproj + + def forward( + self, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + ref_image_latents: torch.Tensor, + face_emb: torch.Tensor, + audio_emb: torch.Tensor, + mask: torch.Tensor, + full_mask: torch.Tensor, + face_mask: torch.Tensor, + lip_mask: torch.Tensor, + uncond_img_fwd: bool = False, + uncond_audio_fwd: bool = False, + ): + """ + simple docstring to prevent pylint error + """ + face_emb = self.imageproj(face_emb) + mask = mask.to(device="cuda") + mask_feature = self.face_locator(mask) + audio_emb = audio_emb.to( + device=self.audioproj.device, dtype=self.audioproj.dtype) + audio_emb = self.audioproj(audio_emb) + + # condition forward + if not uncond_img_fwd: + ref_timesteps = torch.zeros_like(timesteps) + ref_timesteps = repeat( + ref_timesteps, + "b -> (repeat b)", + repeat=ref_image_latents.size(0) // ref_timesteps.size(0), + ) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=face_emb, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + + if uncond_audio_fwd: + audio_emb = torch.zeros_like(audio_emb).to( + device=audio_emb.device, dtype=audio_emb.dtype + ) + + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + mask_cond_fea=mask_feature, + encoder_hidden_states=face_emb, + audio_embedding=audio_emb, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask + ).sample + + return model_pred + + +def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor: + """ + Rearrange the mask tensors to the required format. + + Args: + mask (torch.Tensor): The input mask tensor. + weight_dtype (torch.dtype): The data type for the mask tensor. + + Returns: + torch.Tensor: The rearranged mask tensor. + """ + if isinstance(mask, List): + _mask = [] + for m in mask: + _mask.append( + rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype)) + return _mask + mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype) + return mask + + +def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]: + """ + Create noise scheduler for training. + + Args: + cfg (argparse.Namespace): Configuration object. + + Returns: + Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler. + """ + + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + return train_noise_scheduler, val_noise_scheduler + + +def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor: + """ + Process the audio embedding to concatenate with other tensors. + + Parameters: + audio_emb (torch.Tensor): The audio embedding tensor to process. 
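+            Each output frame stacks the 5 neighbouring input frames (offsets -2..+2,
+            clamped at the clip boundaries), matching AudioProjModel(seq_len=5).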
+ + Returns: + concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. + """ + concatenated_tensors = [] + + for i in range(audio_emb.shape[0]): + vectors_to_concat = [ + audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)]for j in range(-2, 3)] + concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) + + audio_emb = torch.stack(concatenated_tensors, dim=0) + + return audio_emb + + +def log_validation( + accelerator: Accelerator, + vae: AutoencoderKL, + net: Net, + scheduler: DDIMScheduler, + width: int, + height: int, + clip_length: int = 24, + generator: torch.Generator = None, + cfg: dict = None, + save_dir: str = None, + global_step: int = 0, + times: int = None, + face_analysis_model_path: str = "", +) -> None: + """ + Log validation video during the training process. + + Args: + accelerator (Accelerator): The accelerator for distributed training. + vae (AutoencoderKL): The autoencoder model. + net (Net): The main neural network model. + scheduler (DDIMScheduler): The scheduler for noise. + width (int): The width of the input images. + height (int): The height of the input images. + clip_length (int): The length of the video clips. Defaults to 24. + generator (torch.Generator): The random number generator. Defaults to None. + cfg (dict): The configuration dictionary. Defaults to None. + save_dir (str): The directory to save validation results. Defaults to None. + global_step (int): The current global step in training. Defaults to 0. + times (int): The number of inference times. Defaults to None. + face_analysis_model_path (str): The path to the face analysis model. Defaults to "". + + Returns: + torch.Tensor: The tensor result of the validation. + """ + ori_net = accelerator.unwrap_model(net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + face_locator = ori_net.face_locator + imageproj = ori_net.imageproj + audioproj = ori_net.audioproj + + generator = torch.manual_seed(42) + tmp_denoising_unet = copy.deepcopy(denoising_unet) + + pipeline = FaceAnimatePipeline( + vae=vae, + reference_unet=reference_unet, + denoising_unet=tmp_denoising_unet, + face_locator=face_locator, + image_proj=imageproj, + scheduler=scheduler, + ) + pipeline = pipeline.to("cuda") + + image_processor = ImageProcessor((width, height), face_analysis_model_path) + audio_processor = AudioProcessor( + cfg.data.sample_rate, + cfg.data.fps, + cfg.wav2vec_config.model_path, + cfg.wav2vec_config.features == "last", + os.path.dirname(cfg.audio_separator.model_path), + os.path.basename(cfg.audio_separator.model_path), + os.path.join(save_dir, '.cache', "audio_preprocess") + ) + + for idx, ref_img_path in enumerate(cfg.ref_img_path): + audio_path = cfg.audio_path[idx] + source_image_pixels, \ + source_image_face_region, \ + source_image_face_emb, \ + source_image_full_mask, \ + source_image_face_mask, \ + source_image_lip_mask = image_processor.preprocess( + ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio) + audio_emb, audio_length = audio_processor.preprocess( + audio_path, clip_length) + + audio_emb = process_audio_emb(audio_emb) + + source_image_pixels = source_image_pixels.unsqueeze(0) + source_image_face_region = source_image_face_region.unsqueeze(0) + source_image_face_emb = source_image_face_emb.reshape(1, -1) + source_image_face_emb = torch.tensor(source_image_face_emb) + + source_image_full_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_full_mask + ] + source_image_face_mask = [ + (mask.repeat(clip_length, 
1)) + for mask in source_image_face_mask + ] + source_image_lip_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_lip_mask + ] + + times = audio_emb.shape[0] // clip_length + tensor_result = [] + generator = torch.manual_seed(42) + for t in range(times): + print(f"[{t+1}/{times}]") + + if len(tensor_result) == 0: + # The first iteration + motion_zeros = source_image_pixels.repeat( + cfg.data.n_motion_frames, 1, 1, 1) + motion_zeros = motion_zeros.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames + else: + motion_frames = tensor_result[-1][0] + motion_frames = motion_frames.permute(1, 0, 2, 3) + motion_frames = motion_frames[0 - cfg.data.n_motion_frames:] + motion_frames = motion_frames * 2.0 - 1.0 + motion_frames = motion_frames.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames + + pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) + + audio_tensor = audio_emb[ + t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) + ] + audio_tensor = audio_tensor.unsqueeze(0) + audio_tensor = audio_tensor.to( + device=audioproj.device, dtype=audioproj.dtype) + audio_tensor = audioproj(audio_tensor) + + pipeline_output = pipeline( + ref_image=pixel_values_ref_img, + audio_tensor=audio_tensor, + face_emb=source_image_face_emb, + face_mask=source_image_face_region, + pixel_values_full_mask=source_image_full_mask, + pixel_values_face_mask=source_image_face_mask, + pixel_values_lip_mask=source_image_lip_mask, + width=cfg.data.train_width, + height=cfg.data.train_height, + video_length=clip_length, + num_inference_steps=cfg.inference_steps, + guidance_scale=cfg.cfg_scale, + generator=generator, + ) + + tensor_result.append(pipeline_output.videos) + + tensor_result = torch.cat(tensor_result, dim=2) + tensor_result = tensor_result.squeeze(0) + tensor_result = tensor_result[:, :audio_length] + audio_name = os.path.basename(audio_path).split('.')[0] + ref_name = os.path.basename(ref_img_path).split('.')[0] + output_file = os.path.join(save_dir,f"{global_step}_{ref_name}_{audio_name}.mp4") + # save the result after all iteration + tensor_to_video(tensor_result, output_file, audio_path) + + + # clean up + del tmp_denoising_unet + del pipeline + del image_processor + del audio_processor + torch.cuda.empty_cache() + + return tensor_result + + +def train_stage2_process(cfg: argparse.Namespace) -> None: + """ + Trains the model using the given configuration (cfg). + + Args: + cfg (dict): The configuration dictionary containing the parameters for training. + + Notes: + - This function trains the model using the given configuration. + - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler. + - The training progress is logged and tracked using the accelerator. + - The trained model is saved after the training is completed. + """ + kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="mlflow", + project_dir="./mlruns", + kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + # create output dir for training + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + checkpoint_dir = os.path.join(save_dir, "checkpoints") + module_dir = os.path.join(save_dir, "modules") + validation_dir = os.path.join(save_dir, "validation") + if accelerator.is_main_process: + init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir]) + + accelerator.wait_for_everyone() + + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + # Create Models + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + cfg.mm_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + cfg.unet_additional_kwargs), + use_landmark=False + ).to(device="cuda", dtype=weight_dtype) + imageproj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ).to(device="cuda", dtype=weight_dtype) + face_locator = FaceLocator( + conditioning_embedding_channels=320, + ).to(device="cuda", dtype=weight_dtype) + audioproj = AudioProjModel( + seq_len=5, + blocks=12, + channels=768, + intermediate_dim=512, + output_dim=768, + context_tokens=32, + ).to(device="cuda", dtype=weight_dtype) + + # load module weight from stage 1 + stage1_ckpt_dir = cfg.stage1_ckpt_dir + denoising_unet.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "denoising_unet.pth"), + map_location="cpu", + ), + strict=False, + ) + reference_unet.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "reference_unet.pth"), + map_location="cpu", + ), + strict=False, + ) + face_locator.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "face_locator.pth"), + map_location="cpu", + ), + strict=False, + ) + imageproj.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, "imageproj.pth"), + map_location="cpu", + ), + strict=False, + ) + + # Freeze + vae.requires_grad_(False) + imageproj.requires_grad_(False) + reference_unet.requires_grad_(False) + denoising_unet.requires_grad_(False) + face_locator.requires_grad_(False) + audioproj.requires_grad_(True) + + # Set motion module learnable + trainable_modules = cfg.trainable_para + for name, module in denoising_unet.named_modules(): + if any(trainable_mod in name for trainable_mod in trainable_modules): + for params in module.parameters(): + params.requires_grad_(True) + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + 
fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + face_locator, + reference_control_writer, + reference_control_reader, + imageproj, + audioproj, + ).to(dtype=weight_dtype) + + # get noise scheduler + train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg) + + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError as exc: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) from exc + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list( + filter(lambda p: p.requires_grad, net.parameters())) + logger.info(f"Total trainable params {len(trainable_params)}") + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # Scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + # get data loader + train_dataset = TalkingVideoDataset( + img_size=(cfg.data.train_width, cfg.data.train_height), + sample_rate=cfg.data.sample_rate, + n_sample_frames=cfg.data.n_sample_frames, + n_motion_frames=cfg.data.n_motion_frames, + audio_margin=cfg.data.audio_margin, + data_meta_paths=cfg.data.train_meta_paths, + wav2vec_cfg=cfg.wav2vec_config, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. 
+ if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + exp_name, + init_kwargs={"mlflow": {"run_name": run_time}}, + ) + # dump config file + mlflow.log_dict( + OmegaConf.to_container( + cfg), "config.yaml" + ) + logger.info(f"save config to {save_dir}") + OmegaConf.save( + cfg, os.path.join(save_dir, "config.yaml") + ) + + # Train! + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + logger.info(f"Loading checkpoint from {checkpoint_dir}") + global_step = load_checkpoint(cfg, checkpoint_dir, accelerator) + first_epoch = global_step // num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + for _ in range(first_epoch, num_train_epochs): + train_loss = 0.0 + t_data_start = time.time() + for _, batch in enumerate(train_dataloader): + t_data = time.time() - t_data_start + with accelerator.accumulate(net): + # Convert videos to latent space + pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype) + + pixel_values_face_mask = batch["pixel_values_face_mask"] + pixel_values_face_mask = get_attention_mask( + pixel_values_face_mask, weight_dtype + ) + pixel_values_lip_mask = batch["pixel_values_lip_mask"] + pixel_values_lip_mask = get_attention_mask( + pixel_values_lip_mask, weight_dtype + ) + pixel_values_full_mask = batch["pixel_values_full_mask"] + pixel_values_full_mask = get_attention_mask( + pixel_values_full_mask, weight_dtype + ) + + with torch.no_grad(): + video_length = pixel_values_vid.shape[1] + pixel_values_vid = rearrange( + pixel_values_vid, "b f c h w -> (b f) c h w" + ) + latents = vae.encode(pixel_values_vid).latent_dist.sample() + latents = rearrange( + latents, "(b f) c h w -> b c f h w", f=video_length + ) + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + if cfg.noise_offset > 0: + noise += cfg.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1, 1), + device=latents.device, + ) + + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint( + 0, + train_noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + # mask for face locator + pixel_values_mask = ( + batch["pixel_values_mask"].unsqueeze( + 1).to(dtype=weight_dtype) + ) + pixel_values_mask = repeat( + pixel_values_mask, + "b f c h w -> b (repeat f) c h w", + repeat=video_length, + ) + pixel_values_mask = pixel_values_mask.transpose( + 1, 2) + + uncond_img_fwd = random.random() < cfg.uncond_img_ratio + uncond_audio_fwd = random.random() < cfg.uncond_audio_ratio + + start_frame = random.random() < cfg.start_ratio + 
pixel_values_ref_img = batch["pixel_values_ref_img"].to( + dtype=weight_dtype + ) + # initialize the motion frames as zero maps + if start_frame: + pixel_values_ref_img[:, 1:] = 0.0 + + ref_img_and_motion = rearrange( + pixel_values_ref_img, "b f c h w -> (b f) c h w" + ) + + with torch.no_grad(): + ref_image_latents = vae.encode( + ref_img_and_motion + ).latent_dist.sample() + ref_image_latents = ref_image_latents * 0.18215 + image_prompt_embeds = batch["face_emb"].to( + dtype=imageproj.dtype, device=imageproj.device + ) + + # add noise + noisy_latents = train_noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Get the target for loss depending on the prediction type + if train_noise_scheduler.prediction_type == "epsilon": + target = noise + elif train_noise_scheduler.prediction_type == "v_prediction": + target = train_noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {train_noise_scheduler.prediction_type}" + ) + + # ---- Forward!!! ----- + model_pred = net( + noisy_latents=noisy_latents, + timesteps=timesteps, + ref_image_latents=ref_image_latents, + face_emb=image_prompt_embeds, + mask=pixel_values_mask, + full_mask=pixel_values_full_mask, + face_mask=pixel_values_face_mask, + lip_mask=pixel_values_lip_mask, + audio_emb=batch["audio_tensor"].to( + dtype=weight_dtype), + uncond_img_fwd=uncond_img_fwd, + uncond_audio_fwd=uncond_audio_fwd, + ) + + if cfg.snr_gamma == 0: + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean", + ) + else: + snr = compute_snr(train_noise_scheduler, timesteps) + if train_noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean", + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ).mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather( + loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + trainable_params, + cfg.solver.max_grad_norm, + ) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + reference_control_reader.clear() + reference_control_writer.clear() + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % cfg.val.validation_steps == 0 or global_step==1: + if accelerator.is_main_process: + generator = torch.Generator(device=accelerator.device) + generator.manual_seed(cfg.seed) + + log_validation( + accelerator=accelerator, + vae=vae, + net=net, + scheduler=val_noise_scheduler, + width=cfg.data.train_width, + height=cfg.data.train_height, + clip_length=cfg.data.n_sample_frames, + cfg=cfg, + save_dir=validation_dir, + global_step=global_step, + times=cfg.single_inference_times if cfg.single_inference_times is not None else None, + face_analysis_model_path=cfg.face_analysis_model_path + ) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "td": f"{t_data:.2f}s", + } + t_data_start = time.time() + progress_bar.set_postfix(**logs) + + if ( + global_step % cfg.checkpointing_steps == 0 + or global_step == cfg.solver.max_train_steps + ): + # save model + save_path = os.path.join( + checkpoint_dir, f"checkpoint-{global_step}") + if accelerator.is_main_process: + delete_additional_ckpt(checkpoint_dir, 30) + accelerator.wait_for_everyone() + accelerator.save_state(save_path) + + # save model weight + unwrap_net = accelerator.unwrap_model(net) + if accelerator.is_main_process: + save_checkpoint( + unwrap_net, + module_dir, + "net", + global_step, + total_limit=30, + ) + if global_step >= cfg.solver.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + accelerator.end_training() + + +def load_config(config_path: str) -> dict: + """ + Loads the configuration file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: The configuration dictionary. + """ + + if config_path.endswith(".yaml"): + return OmegaConf.load(config_path) + if config_path.endswith(".py"): + return import_filename(config_path).cfg + raise ValueError("Unsupported format for config file") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", type=str, default="./configs/train/stage2.yaml" + ) + args = parser.parse_args() + + try: + config = load_config(args.config) + train_stage2_process(config) + except Exception as e: + logging.error("Failed to execute the training process: %s", e)