From 0aed1868dfafecb92a644cf9e1a2f55e3c4de021 Mon Sep 17 00:00:00 2001 From: rlsu9 <147024991+rlsu9@users.noreply.github.com> Date: Tue, 31 Dec 2024 15:23:17 -0800 Subject: [PATCH] [feat]: Add format auto fixer to main branch (#124) --- .github/workflows/codespell.yml | 45 ++ .github/workflows/ruff.yml | 50 ++ .github/workflows/yapf.yml | 38 ++ README.md | 2 +- demo/gradio_web_demo.py | 54 +- env_setup.sh | 2 + .../preprocess_text_embeddings.py | 99 +-- .../data_preprocess/preprocess_vae_latents.py | 74 ++- .../preprocess_validation_text_embeddings.py | 49 +- fastvideo/dataset/__init__.py | 57 +- fastvideo/dataset/latent_datasets.py | 47 +- fastvideo/dataset/t2v_datasets.py | 94 +-- fastvideo/dataset/transform.py | 144 +++-- fastvideo/distill.py | 512 ++++++++------- fastvideo/distill/discriminator.py | 47 +- fastvideo/distill/solver.py | 122 ++-- fastvideo/distill_adv.py | 430 ++++++------- fastvideo/models/flash_attn_no_pad.py | 23 +- fastvideo/models/hunyuan/constants.py | 12 +- .../models/hunyuan/diffusion/__init__.py | 1 + .../hunyuan/diffusion/pipelines/__init__.py | 1 + .../pipelines/pipeline_hunyuan_video.py | 326 +++++----- .../hunyuan/diffusion/schedulers/__init__.py | 1 + .../scheduling_flow_match_discrete.py | 37 +- fastvideo/models/hunyuan/idle_config.py | 96 ++- fastvideo/models/hunyuan/inference.py | 121 ++-- fastvideo/models/hunyuan/modules/__init__.py | 2 +- fastvideo/models/hunyuan/modules/attenion.py | 46 +- .../models/hunyuan/modules/embed_layers.py | 36 +- .../models/hunyuan/modules/mlp_layers.py | 58 +- fastvideo/models/hunyuan/modules/models.py | 397 ++++++------ .../models/hunyuan/modules/modulate_layers.py | 16 +- .../models/hunyuan/modules/norm_layers.py | 4 +- .../models/hunyuan/modules/posemb_layers.py | 64 +- .../models/hunyuan/modules/token_refiner.py | 125 ++-- fastvideo/models/hunyuan/prompt_rewrite.py | 1 - .../models/hunyuan/text_encoder/__init__.py | 96 ++- fastvideo/models/hunyuan/utils/data_utils.py | 3 +- fastvideo/models/hunyuan/utils/file_utils.py | 12 +- fastvideo/models/hunyuan/utils/helpers.py | 4 +- ...preprocess_text_encoder_tokenizer_utils.py | 10 +- fastvideo/models/hunyuan/vae/__init__.py | 10 +- .../hunyuan/vae/autoencoder_kl_causal_3d.py | 238 ++++--- .../hunyuan/vae/unet_causal_3d_blocks.py | 288 +++++---- fastvideo/models/hunyuan/vae/vae.py | 137 ++-- .../mochi_hf/convert_diffusers_to_mochi.py | 605 ++++++++---------- .../models/mochi_hf/mochi_latents_utils.py | 60 +- fastvideo/models/mochi_hf/modeling_mochi.py | 247 ++++--- fastvideo/models/mochi_hf/norm.py | 26 +- fastvideo/models/mochi_hf/pipeline_mochi.py | 247 ++++--- fastvideo/sample/generate_synthetic.py | 72 ++- .../sample/sample_t2v_diffusers_hunyuan.py | 188 +++--- fastvideo/sample/sample_t2v_hunyuan.py | 122 ++-- fastvideo/sample/sample_t2v_mochi.py | 84 ++- fastvideo/sample/sample_t2v_mochi_no_sp.py | 22 +- fastvideo/train.py | 368 +++++------ fastvideo/utils/checkpoint.py | 161 ++--- fastvideo/utils/communications.py | 106 +-- fastvideo/utils/dataset_utils.py | 97 ++- fastvideo/utils/env_utils.py | 8 +- fastvideo/utils/fsdp_util.py | 57 +- fastvideo/utils/load.py | 144 +++-- fastvideo/utils/logging_.py | 4 +- fastvideo/utils/optimizer.py | 13 +- fastvideo/utils/parallel_states.py | 11 +- fastvideo/utils/validation.py | 176 +++-- format.sh | 237 +++++++ predict.py | 83 ++- requirements-lint.txt | 14 + scripts/huggingface/download_hf.py | 16 +- 70 files changed, 3825 insertions(+), 3374 deletions(-) create mode 100644 .github/workflows/codespell.yml create mode 100644 
.github/workflows/ruff.yml create mode 100644 .github/workflows/yapf.yml create mode 100644 format.sh create mode 100644 requirements-lint.txt diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml new file mode 100644 index 0000000..1d62393 --- /dev/null +++ b/.github/workflows/codespell.yml @@ -0,0 +1,45 @@ +name: codespell + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - "**/*.md" + - "**/*.rst" + - pyproject.toml + - requirements-lint.txt + - .github/workflows/codespell.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - "**/*.md" + - "**/*.rst" + - pyproject.toml + - requirements-lint.txt + - .github/workflows/codespell.yml + +jobs: + codespell: + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' # or any version you need + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + - name: Spelling check with codespell + run: | + # Refer to the above environment variable here + codespell --toml pyproject.toml $CODESPELL_EXCLUDES diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..a2f556b --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,50 @@ +name: ruff + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - pyproject.toml + - requirements-lint.txt + - .github/workflows/matchers/ruff.json + - .github/workflows/ruff.yml + pull_request: + branches: + - main + # This workflow is only relevant when one of the following files changes. + # However, we have github configured to expect and require this workflow + # to run and pass before github will auto-merge a pull request. Until github + # allows more flexible auto-merge policy, we can just run this on every PR. + # It doesn't take that long to run, anyway. + #paths: + # - "**/*.py" + # - pyproject.toml + # - requirements-lint.txt + # - .github/workflows/matchers/ruff.json + # - .github/workflows/ruff.yml + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' # or any version you need + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + - name: Analysing the code with ruff + run: | + ruff check . + - name: Run isort + run: | + isort . 
--check-only diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml new file mode 100644 index 0000000..3908e95 --- /dev/null +++ b/.github/workflows/yapf.yml @@ -0,0 +1,38 @@ +name: yapf + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/yapf.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/yapf.yml + +jobs: + yapf: + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' # or any version you need + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install yapf==0.32.0 + pip install toml==0.10.2 + - name: Running yapf + run: | + yapf --diff --recursive . diff --git a/README.md b/README.md index 31346c7..5991d55 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Ensure your data is prepared and preprocessed in the format specified in [data_p ```bash python scripts/huggingface/download_hf.py --repo_id=FastVideo/Mochi-Black-Myth --local_dir=data/Mochi-Black-Myth --repo_type=dataset ``` -Download the original model weights as specificed in [Distill Section](#-distill): +Download the original model weights as specified in [Distill Section](#-distill): Then you can run the finetune with: ``` diff --git a/demo/gradio_web_demo.py b/demo/gradio_web_demo.py index 1d4b596..8396e3d 100644 --- a/demo/gradio_web_demo.py +++ b/demo/gradio_web_demo.py @@ -1,13 +1,15 @@ +import argparse +import os +import tempfile + import gradio as gr import torch -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel from diffusers import FlowMatchEulerDiscreteScheduler from diffusers.utils import export_to_video + from fastvideo.distill.solver import PCMFMScheduler -import tempfile -import os -import argparse +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline def init_args(): @@ -21,7 +23,9 @@ def init_args(): parser.add_argument("--model_path", type=str, default="data/mochi") parser.add_argument("--seed", type=int, default=12345) parser.add_argument("--transformer_path", type=str, default=None) - parser.add_argument("--scheduler_type", type=str, default="pcm_linear_quadratic") + parser.add_argument("--scheduler_type", + type=str, + default="pcm_linear_quadratic") parser.add_argument("--lora_checkpoint_dir", type=str, default=None) parser.add_argument("--shift", type=float, default=8.0) parser.add_argument("--num_euler_timesteps", type=int, default=50) @@ -32,7 +36,6 @@ def init_args(): def load_model(args): - device = "cuda" if torch.cuda.is_available() else "cpu" if args.scheduler_type == "euler": scheduler = FlowMatchEulerDiscreteScheduler() else: @@ -47,15 +50,15 @@ def load_model(args): ) if args.transformer_path: - transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) + transformer = MochiTransformer3DModel.from_pretrained( + args.transformer_path) else: transformer = MochiTransformer3DModel.from_pretrained( - args.model_path, subfolder="transformer/" - ) + args.model_path, subfolder="transformer/") - pipe = MochiPipeline.from_pretrained( - args.model_path, transformer=transformer, scheduler=scheduler - ) + pipe = MochiPipeline.from_pretrained(args.model_path, + 
transformer=transformer, + scheduler=scheduler) pipe.enable_vae_tiling() # pipe.to(device) # if args.cpu_offload: @@ -76,7 +79,7 @@ def generate_video( randomize_seed=False, ): if randomize_seed: - seed = torch.randint(0, 1000000, (1,)).item() + seed = torch.randint(0, 1000000, (1, )).item() generator = torch.Generator(device="cuda").manual_seed(seed) @@ -134,9 +137,11 @@ def generate_video( step=32, value=args.height, ) - width = gr.Slider( - label="Width", minimum=256, maximum=1024, step=32, value=args.width - ) + width = gr.Slider(label="Width", + minimum=256, + maximum=1024, + step=32, + value=args.width) with gr.Row(): num_frames = gr.Slider( @@ -159,9 +164,8 @@ def generate_video( ) with gr.Row(): - use_negative_prompt = gr.Checkbox( - label="Use negative prompt", value=False - ) + use_negative_prompt = gr.Checkbox(label="Use negative prompt", + value=False) negative_prompt = gr.Text( label="Negative prompt", max_lines=1, @@ -169,9 +173,11 @@ def generate_video( visible=False, ) - seed = gr.Slider( - label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed - ) + seed = gr.Slider(label="Seed", + minimum=0, + maximum=1000000, + step=1, + value=args.seed) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) seed_output = gr.Number(label="Used Seed") @@ -201,4 +207,4 @@ def generate_video( ) if __name__ == "__main__": - demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) \ No newline at end of file + demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860) diff --git a/env_setup.sh b/env_setup.sh index dc5598c..132a751 100755 --- a/env_setup.sh +++ b/env_setup.sh @@ -6,5 +6,7 @@ pip install torch==2.5.0 torchvision --index-url https://download.pytorch.org/wh # install FA2 and diffusers pip install packaging ninja && pip install flash-attn==2.7.0.post2 --no-build-isolation +pip install -r requirements-lint.txt + # install fastvideo pip install -e . 
diff --git a/fastvideo/data_preprocess/preprocess_text_embeddings.py b/fastvideo/data_preprocess/preprocess_text_embeddings.py index 6e29c29..909bd38 100644 --- a/fastvideo/data_preprocess/preprocess_text_embeddings.py +++ b/fastvideo/data_preprocess/preprocess_text_embeddings.py @@ -1,30 +1,34 @@ import argparse -import torch -from accelerate.logging import get_logger -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from diffusers.utils import export_to_video import json import os -import torch.distributed as dist -logger = get_logger(__name__) -from torch.utils.data import Dataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import DataLoader -from fastvideo.utils.load import load_text_encoder, load_vae +import torch +import torch.distributed as dist +from accelerate.logging import get_logger +from diffusers.utils import export_to_video from diffusers.video_processor import VideoProcessor +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm +from fastvideo.utils.load import load_text_encoder, load_vae + +logger = get_logger(__name__) + class T5dataset(Dataset): + def __init__( - self, json_path, vae_debug, + self, + json_path, + vae_debug, ): self.json_path = json_path self.vae_debug = vae_debug with open(self.json_path, "r") as f: train_dataset = json.load(f) - self.train_dataset = sorted(train_dataset, key=lambda x: x["latent_path"]) + self.train_dataset = sorted(train_dataset, + key=lambda x: x["latent_path"]) def __getitem__(self, idx): caption = self.train_dataset[idx]["caption"] @@ -32,15 +36,17 @@ def __getitem__(self, idx): length = self.train_dataset[idx]["length"] if self.vae_debug: latents = torch.load( - os.path.join( - args.output_dir, "latent", self.train_dataset[idx]["latent_path"] - ), + os.path.join(args.output_dir, "latent", + self.train_dataset[idx]["latent_path"]), map_location="cpu", ) else: latents = [] - return dict(caption=caption, latents=latents, filename=filename, length=length) + return dict(caption=caption, + latents=latents, + filename=filename, + length=length) def __len__(self): return len(self.train_dataset) @@ -54,25 +60,31 @@ def main(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.cuda.set_device(local_rank) if not dist.is_initialized(): - dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) videoprocessor = VideoProcessor(vae_scale_factor=8) os.makedirs(args.output_dir, exist_ok=True) os.makedirs(os.path.join(args.output_dir, "video"), exist_ok=True) os.makedirs(os.path.join(args.output_dir, "latent"), exist_ok=True) os.makedirs(os.path.join(args.output_dir, "prompt_embed"), exist_ok=True) - os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "prompt_attention_mask"), + exist_ok=True) - latents_json_path = os.path.join(args.output_dir, "videos2caption_temp.json") + latents_json_path = os.path.join(args.output_dir, + "videos2caption_temp.json") train_dataset = T5dataset(latents_json_path, args.vae_debug) - text_encoder = load_text_encoder(args.model_type, args.model_path, device=device) + text_encoder = load_text_encoder(args.model_type, + args.model_path, + device=device) vae, autocast_type, fps = load_vae(args.model_type, args.model_path) 
vae.enable_tiling() - sampler = DistributedSampler( - train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True - ) + sampler = DistributedSampler(train_dataset, + rank=local_rank, + num_replicas=world_size, + shuffle=True) train_dataloader = DataLoader( train_dataset, sampler=sampler, @@ -85,25 +97,25 @@ def main(args): with torch.inference_mode(): with torch.autocast("cuda", dtype=autocast_type): prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt( - prompt=data["caption"], - ) + prompt=data["caption"], ) if args.vae_debug: latents = data["latents"] - video = vae.decode(latents.to(device), return_dict=False)[0] + video = vae.decode(latents.to(device), + return_dict=False)[0] video = videoprocessor.postprocess_video(video) for idx, video_name in enumerate(data["filename"]): - prompt_embed_path = os.path.join( - args.output_dir, "prompt_embed", video_name + ".pt" - ) - video_path = os.path.join( - args.output_dir, "video", video_name + ".mp4" - ) + prompt_embed_path = os.path.join(args.output_dir, + "prompt_embed", + video_name + ".pt") + video_path = os.path.join(args.output_dir, "video", + video_name + ".mp4") prompt_attention_mask_path = os.path.join( - args.output_dir, "prompt_attention_mask", video_name + ".pt" - ) + args.output_dir, "prompt_attention_mask", + video_name + ".pt") # save latent torch.save(prompt_embeds[idx], prompt_embed_path) - torch.save(prompt_attention_mask[idx], prompt_attention_mask_path) + torch.save(prompt_attention_mask[idx], + prompt_attention_mask_path) print(f"sample {video_name} saved") if args.vae_debug: export_to_video(video[idx], video_path, fps=fps) @@ -121,7 +133,8 @@ def main(args): if local_rank == 0: # os.remove(latents_json_path) all_json_data = [item for sublist in gathered_data for item in sublist] - with open(os.path.join(args.output_dir, "videos2caption.json"), "w") as f: + with open(os.path.join(args.output_dir, "videos2caption.json"), + "w") as f: json.dump(all_json_data, f, indent=4) @@ -135,7 +148,8 @@ def main(args): "--dataloader_num_workers", type=int, default=1, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + help= + "Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", ) parser.add_argument( "--train_batch_size", @@ -143,13 +157,16 @@ def main(args): default=1, help="Batch size (per device) for the training dataloader.", ) - parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--text_encoder_name", + type=str, + default="google/t5-v1_1-xxl") parser.add_argument("--cache_dir", type=str, default="./cache_dir") parser.add_argument( "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--vae_debug", action="store_true") args = parser.parse_args() diff --git a/fastvideo/data_preprocess/preprocess_vae_latents.py b/fastvideo/data_preprocess/preprocess_vae_latents.py index c56abb1..166206a 100644 --- a/fastvideo/data_preprocess/preprocess_vae_latents.py +++ b/fastvideo/data_preprocess/preprocess_vae_latents.py @@ -1,19 +1,17 @@ -from fastvideo.dataset import getdataset -from torch.utils.data import DataLoader -from fastvideo.utils.dataset_utils import Collate import argparse -import torch -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration import json import os -from diffusers import AutoencoderKLMochi + +import torch import torch.distributed as dist +from accelerate.logging import get_logger +from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from fastvideo.utils.load import load_vae from tqdm import tqdm +from fastvideo.dataset import getdataset +from fastvideo.utils.load import load_vae + logger = get_logger(__name__) @@ -22,9 +20,10 @@ def main(args): world_size = int(os.getenv("WORLD_SIZE", 1)) print("world_size", world_size, "local rank", local_rank) train_dataset = getdataset(args) - sampler = DistributedSampler( - train_dataset, rank=local_rank, num_replicas=world_size, shuffle=True - ) + sampler = DistributedSampler(train_dataset, + rank=local_rank, + num_replicas=world_size, + shuffle=True) train_dataloader = DataLoader( train_dataset, sampler=sampler, @@ -32,12 +31,14 @@ def main(args): num_workers=args.dataloader_num_workers, ) - encoder_device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") + encoder_device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") torch.cuda.set_device(local_rank) if not dist.is_initialized(): - dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) vae, autocast_type, fps = load_vae(args.model_type, args.model_path) vae.enable_tiling() os.makedirs(args.output_dir, exist_ok=True) @@ -47,14 +48,12 @@ def main(args): for _, data in tqdm(enumerate(train_dataloader), disable=local_rank != 0): with torch.inference_mode(): with torch.autocast("cuda", dtype=autocast_type): - latents = vae.encode(data["pixel_values"].to(encoder_device))[ - "latent_dist" - ].sample() + latents = vae.encode(data["pixel_values"].to( + encoder_device))["latent_dist"].sample() for idx, video_path in enumerate(data["path"]): video_name = os.path.basename(video_path).split(".")[0] - latent_path = os.path.join( - args.output_dir, "latent", video_name + ".pt" - ) + latent_path = os.path.join(args.output_dir, "latent", + video_name + 
".pt") torch.save(latents[idx].to(torch.bfloat16), latent_path) item = {} item["length"] = latents[idx].shape[1] @@ -68,7 +67,8 @@ def main(args): dist.all_gather_object(gathered_data, local_data) if local_rank == 0: all_json_data = [item for sublist in gathered_data for item in sublist] - with open(os.path.join(args.output_dir, "videos2caption_temp.json"), "w") as f: + with open(os.path.join(args.output_dir, "videos2caption_temp.json"), + "w") as f: json.dump(all_json_data, f, indent=4) @@ -83,7 +83,8 @@ def main(args): "--dataloader_num_workers", type=int, default=1, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + help= + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", ) parser.add_argument( "--train_batch_size", @@ -91,12 +92,15 @@ def main(args): default=16, help="Batch size (per device) for the training dataloader.", ) - parser.add_argument( - "--num_latent_t", type=int, default=28, help="Number of latent timesteps." - ) + parser.add_argument("--num_latent_t", + type=int, + default=28, + help="Number of latent timesteps.") parser.add_argument("--max_height", type=int, default=480) parser.add_argument("--max_width", type=int, default=848) - parser.add_argument("--video_length_tolerance_range", type=int, default=2.0) + parser.add_argument("--video_length_tolerance_range", + type=int, + default=2.0) parser.add_argument("--group_frame", action="store_true") # TODO parser.add_argument("--group_resolution", action="store_true") # TODO parser.add_argument("--dataset", default="t2v") @@ -106,23 +110,25 @@ def main(args): parser.add_argument("--speed_factor", type=float, default=1.0) parser.add_argument("--drop_short_ratio", type=float, default=1.0) # text encoder & vae & diffusion model - parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--text_encoder_name", + type=str, + default="google/t5-v1_1-xxl") parser.add_argument("--cache_dir", type=str, default="./cache_dir") parser.add_argument("--cfg", type=float, default=0.0) parser.add_argument( "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), + help= + ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), ) args = parser.parse_args() diff --git a/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py b/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py index 18d28be..ecfb717 100644 --- a/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py +++ b/fastvideo/data_preprocess/preprocess_validation_text_embeddings.py @@ -1,19 +1,13 @@ import argparse -import torch -from accelerate.logging import get_logger -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from diffusers.utils import export_to_video -import json import os + +import torch import torch.distributed as dist +from accelerate.logging import get_logger + +from fastvideo.utils.load import load_text_encoder logger = get_logger(__name__) -from torch.utils.data import Dataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import DataLoader -from fastvideo.utils.load import load_text_encoder, load_vae -from diffusers.video_processor import VideoProcessor -from tqdm import tqdm def main(args): @@ -24,11 +18,14 @@ def main(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.cuda.set_device(local_rank) if not dist.is_initialized(): - dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) - text_encoder = load_text_encoder(args.model_type, args.model_path, device=device) + text_encoder = load_text_encoder(args.model_type, + args.model_path, + device=device) autocast_type = torch.float16 if args.model_type == "hunyuan" else torch.bfloat16 # output_dir/validation/prompt_attention_mask # output_dir/validation/prompt_embed @@ -37,10 +34,9 @@ def main(args): os.path.join(args.output_dir, "validation", "prompt_attention_mask"), exist_ok=True, ) - os.makedirs( - os.path.join(args.output_dir, "validation", "prompt_embed"), exist_ok=True - ) - json_data = [] + os.makedirs(os.path.join(args.output_dir, "validation", "prompt_embed"), + exist_ok=True) + with open(args.validation_prompt_txt, "r", encoding="utf-8") as file: lines = file.readlines() prompts = [line.strip() for line in lines] @@ -48,12 +44,11 @@ def main(args): with torch.inference_mode(): with torch.autocast("cuda", dtype=autocast_type): prompt_embeds, prompt_attention_mask = text_encoder.encode_prompt( - prompt - ) + prompt) file_name = prompt.split(".")[0] - prompt_embed_path = os.path.join( - args.output_dir, "validation", "prompt_embed", f"{file_name}.pt" - ) + prompt_embed_path = os.path.join(args.output_dir, "validation", + "prompt_embed", + f"{file_name}.pt") prompt_attention_mask_path = os.path.join( args.output_dir, "validation", @@ -61,7 +56,8 @@ def main(args): f"{file_name}.pt", ) torch.save(prompt_embeds[0], prompt_embed_path) - torch.save(prompt_attention_mask[0], prompt_attention_mask_path) + torch.save(prompt_attention_mask[0], + prompt_attention_mask_path) print(f"sample {file_name} saved") @@ -75,7 +71,8 @@ def main(args): "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) args = parser.parse_args() main(args) diff --git a/fastvideo/dataset/__init__.py b/fastvideo/dataset/__init__.py index f28fc40..781f232 100644 --- 
a/fastvideo/dataset/__init__.py +++ b/fastvideo/dataset/__init__.py @@ -1,45 +1,34 @@ -from transformers import AutoTokenizer - from torchvision import transforms from torchvision.transforms import Lambda +from transformers import AutoTokenizer + from fastvideo.dataset.t2v_datasets import T2V_dataset -from fastvideo.dataset.latent_datasets import LatentDataset -from fastvideo.dataset.transform import ( - Normalize255, - TemporalRandomCrop, - CenterCropResizeVideo, -) +from fastvideo.dataset.transform import (CenterCropResizeVideo, Normalize255, + TemporalRandomCrop) def getdataset(args): temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x norm_fun = Lambda(lambda x: 2.0 * x - 1.0) resize_topcrop = [ - CenterCropResizeVideo((args.max_height, args.max_width), top_crop=True), + CenterCropResizeVideo((args.max_height, args.max_width), + top_crop=True), ] resize = [ CenterCropResizeVideo((args.max_height, args.max_width)), ] - transform = transforms.Compose( - [ - # Normalize255(), - *resize, - # RandomHorizontalFlipVideo(p=0.5), # in case their caption have position decription - # norm_fun - ] - ) - transform_topcrop = transforms.Compose( - [ - Normalize255(), - *resize_topcrop, - # RandomHorizontalFlipVideo(p=0.5), # in case their caption have position decription - norm_fun, - ] - ) + transform = transforms.Compose([ + # Normalize255(), + *resize, + ]) + transform_topcrop = transforms.Compose([ + Normalize255(), + *resize_topcrop, + norm_fun, + ]) # tokenizer = AutoTokenizer.from_pretrained("/storage/ongoing/new/Open-Sora-Plan/cache_dir/mt5-xxl", cache_dir=args.cache_dir) - tokenizer = AutoTokenizer.from_pretrained( - args.text_encoder_name, cache_dir=args.cache_dir - ) + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, + cache_dir=args.cache_dir) if args.dataset == "t2v": return T2V_dataset( args, @@ -53,11 +42,13 @@ def getdataset(args): if __name__ == "__main__": - from accelerate import Accelerator - from fastvideo.dataset.t2v_datasets import dataset_prog import random + + from accelerate import Accelerator from tqdm import tqdm + from fastvideo.dataset.t2v_datasets import dataset_prog + args = type( "args", (), @@ -75,7 +66,8 @@ def getdataset(args): "interpolation_scale_h": 1, "interpolation_scale_w": 1, "cache_dir": "../cache_dir", - "image_data": "/storage/ongoing/new/Open-Sora-Plan-bak/7.14bak/scripts/train_data/image_data.txt", + "image_data": + "/storage/ongoing/new/Open-Sora-Plan-bak/7.14bak/scripts/train_data/image_data.txt", "video_data": "1", "train_fps": 24, "drop_short_ratio": 1.0, @@ -93,7 +85,8 @@ def getdataset(args): for idx in tqdm(range(num)): image_data = dataset_prog.img_cap_list[idx] caps = [ - i["cap"] if isinstance(i["cap"], list) else [i["cap"]] for i in image_data + i["cap"] if isinstance(i["cap"], list) else [i["cap"]] + for i in image_data ] try: caps = [[random.choice(i)] for i in caps] diff --git a/fastvideo/dataset/latent_datasets.py b/fastvideo/dataset/latent_datasets.py index fd95f3e..667ea51 100644 --- a/fastvideo/dataset/latent_datasets.py +++ b/fastvideo/dataset/latent_datasets.py @@ -1,13 +1,18 @@ -import torch -from torch.utils.data import Dataset import json import os import random +import torch +from torch.utils.data import Dataset + class LatentDataset(Dataset): + def __init__( - self, json_path, num_latent_t, cfg_rate, + self, + json_path, + num_latent_t, + cfg_rate, ): # data_merge_path: video_dir, latent_dir, prompt_embed_dir, json_path self.json_path = json_path @@ -15,10 +20,10 @@ def __init__( self.datase_dir_path = 
os.path.dirname(json_path) self.video_dir = os.path.join(self.datase_dir_path, "video") self.latent_dir = os.path.join(self.datase_dir_path, "latent") - self.prompt_embed_dir = os.path.join(self.datase_dir_path, "prompt_embed") - self.prompt_attention_mask_dir = os.path.join( - self.datase_dir_path, "prompt_attention_mask" - ) + self.prompt_embed_dir = os.path.join(self.datase_dir_path, + "prompt_embed") + self.prompt_attention_mask_dir = os.path.join(self.datase_dir_path, + "prompt_attention_mask") with open(self.json_path, "r") as f: self.data_anno = json.load(f) # json.load(f) already keeps the order @@ -36,14 +41,15 @@ def __init__( def __getitem__(self, idx): latent_file = self.data_anno[idx]["latent_path"] prompt_embed_file = self.data_anno[idx]["prompt_embed_path"] - prompt_attention_mask_file = self.data_anno[idx]["prompt_attention_mask"] + prompt_attention_mask_file = self.data_anno[idx][ + "prompt_attention_mask"] # load latent = torch.load( os.path.join(self.latent_dir, latent_file), map_location="cpu", weights_only=True, ) - latent = latent.squeeze(0)[:, -self.num_latent_t :] + latent = latent.squeeze(0)[:, -self.num_latent_t:] if random.random() < self.cfg_rate: prompt_embed = self.uncond_prompt_embed prompt_attention_mask = self.uncond_prompt_mask @@ -54,9 +60,8 @@ def __getitem__(self, idx): weights_only=True, ) prompt_attention_mask = torch.load( - os.path.join( - self.prompt_attention_mask_dir, prompt_attention_mask_file - ), + os.path.join(self.prompt_attention_mask_dir, + prompt_attention_mask_file), map_location="cpu", weights_only=True, ) @@ -89,16 +94,15 @@ def latent_collate_function(batch): 0, max_w - latent.shape[3], ), - ) - for latent in latents + ) for latent in latents ] # attn mask latent_attn_mask = torch.ones(len(latents), max_t, max_h, max_w) # set to 0 if padding for i, latent in enumerate(latents): - latent_attn_mask[i, latent.shape[1] :, :, :] = 0 - latent_attn_mask[i, :, latent.shape[2] :, :] = 0 - latent_attn_mask[i, :, :, latent.shape[3] :] = 0 + latent_attn_mask[i, latent.shape[1]:, :, :] = 0 + latent_attn_mask[i, :, latent.shape[2]:, :] = 0 + latent_attn_mask[i, :, :, latent.shape[3]:] = 0 prompt_embeds = torch.stack(prompt_embeds, dim=0) prompt_attention_masks = torch.stack(prompt_attention_masks, dim=0) @@ -107,10 +111,13 @@ def latent_collate_function(batch): if __name__ == "__main__": - dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt", num_latent_t=28) + dataset = LatentDataset("data/Mochi-Synthetic-Data/merge.txt", + num_latent_t=28) dataloader = torch.utils.data.DataLoader( - dataset, batch_size=2, shuffle=False, collate_fn=latent_collate_function - ) + dataset, + batch_size=2, + shuffle=False, + collate_fn=latent_collate_function) for latent, prompt_embed, latent_attn_mask, prompt_attention_mask in dataloader: print( latent.shape, diff --git a/fastvideo/dataset/t2v_datasets.py b/fastvideo/dataset/t2v_datasets.py index 5392a40..d973519 100644 --- a/fastvideo/dataset/t2v_datasets.py +++ b/fastvideo/dataset/t2v_datasets.py @@ -1,18 +1,18 @@ import json -import os, io, csv, math, random -import numpy as np -from einops import rearrange -from decord import VideoReader -from os.path import join as opj +import math +import os +import random from collections import Counter +from os.path import join as opj +import numpy as np import torch -from torch.utils.data.dataset import Dataset -from torch.utils.data import DataLoader, Dataset, get_worker_info -from tqdm import tqdm +import torchvision +from einops import rearrange from PIL import 
Image +from torch.utils.data import Dataset + from fastvideo.utils.dataset_utils import DecordInit -import torchvision from fastvideo.utils.logging_ import main_print @@ -27,6 +27,7 @@ def __call__(cls, *args, **kwargs): class DataSetProg(metaclass=SingletonMeta): + def __init__(self): self.cap_list = [] self.elements = [] @@ -45,7 +46,8 @@ def set_cap_list(self, num_workers, cap_list, n_elements): for i in range(self.num_workers): self.n_used_elements[i] = 0 - per_worker = int(math.ceil(len(self.elements) / float(self.num_workers))) + per_worker = int( + math.ceil(len(self.elements) / float(self.num_workers))) start = i * per_worker end = min(start + per_worker, len(self.elements)) self.worker_elements[i] = self.elements[start:end] @@ -57,8 +59,8 @@ def get_item(self, work_info): worker_id = work_info.id idx = self.worker_elements[worker_id][ - self.n_used_elements[worker_id] % len(self.worker_elements[worker_id]) - ] + self.n_used_elements[worker_id] % + len(self.worker_elements[worker_id])] self.n_used_elements[worker_id] += 1 return idx @@ -66,14 +68,19 @@ def get_item(self, work_info): dataset_prog = DataSetProg() -def filter_resolution(h, w, max_h_div_w_ratio=17 / 16, min_h_div_w_ratio=8 / 16): +def filter_resolution(h, + w, + max_h_div_w_ratio=17 / 16, + min_h_div_w_ratio=8 / 16): if h / w <= max_h_div_w_ratio and h / w >= min_h_div_w_ratio: return True return False class T2V_dataset(Dataset): - def __init__(self, args, transform, temporal_sample, tokenizer, transform_topcrop): + + def __init__(self, args, transform, temporal_sample, tokenizer, + transform_topcrop): self.data = args.data_merge_path self.num_frames = args.num_frames self.train_fps = args.train_fps @@ -92,7 +99,7 @@ def __init__(self, args, transform, temporal_sample, tokenizer, transform_topcro self.v_decoder = DecordInit() self.video_length_tolerance_range = args.video_length_tolerance_range self.support_Chinese = True - if not ("mt5" in args.text_encoder_name): + if "mt5" not in args.text_encoder_name: self.support_Chinese = False cap_list = self.get_cap_list() @@ -102,7 +109,8 @@ def __init__(self, args, transform, temporal_sample, tokenizer, transform_topcro self.lengths = self.sample_num_frames n_elements = len(cap_list) - dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, n_elements) + dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, + n_elements) print(f"video length: {len(dataset_prog.cap_list)}", flush=True) @@ -130,8 +138,7 @@ def get_video(self, idx): assert os.path.exists(video_path), f"file {video_path} do not exist!" frame_indices = dataset_prog.cap_list[idx]["sample_frame_index"] torchvision_video, _, metadata = torchvision.io.read_video( - video_path, output_format="TCHW" - ) + video_path, output_format="TCHW") video = torchvision_video[frame_indices] video = self.transform(video) video = rearrange(video, "t c h w -> c t h w") @@ -171,7 +178,8 @@ def get_video(self, idx): ) def get_image(self, idx): - image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...] + image_data = dataset_prog.cap_list[ + idx] # [{'path': path, 'cap': cap}, ...] image = Image.open(image_data["path"]).convert("RGB") # [h, w, c] image = torch.from_numpy(np.array(image)) # [h, w, c] @@ -180,20 +188,15 @@ def get_image(self, idx): # h, w = i.shape[-2:] # assert h / w <= 17 / 16 and h / w >= 8 / 16, f'Only image with a ratio (h/w) less than 17/16 and more than 8/16 are supported. 
But found ratio is {round(h / w, 2)} with the shape of {i.shape}' - image = ( - self.transform_topcrop(image) - if "human_images" in image_data["path"] - else self.transform(image) - ) # [1 C H W] -> num_img [1 C H W] + image = (self.transform_topcrop(image) if "human_images" + in image_data["path"] else self.transform(image) + ) # [1 C H W] -> num_img [1 C H W] image = image.transpose(0, 1) # [1 C H W] -> [C 1 H W] image = image.float() / 127.5 - 1.0 - caps = ( - image_data["cap"] - if isinstance(image_data["cap"], list) - else [image_data["cap"]] - ) + caps = (image_data["cap"] if isinstance(image_data["cap"], list) else + [image_data["cap"]]) caps = [random.choice(caps)] text = caps input_ids, cond_mask = [], [] @@ -247,13 +250,12 @@ def define_frame_index(self, cap_list): cnt_no_resolution += 1 continue else: - if ( - resolution.get("height", None) is None - or resolution.get("width", None) is None - ): + if (resolution.get("height", None) is None + or resolution.get("width", None) is None): cnt_no_resolution += 1 continue - height, width = i["resolution"]["height"], i["resolution"]["width"] + height, width = i["resolution"]["height"], i["resolution"][ + "width"] aspect = self.max_height / self.max_width hw_aspect_thr = 1.5 is_pick = filter_resolution( @@ -271,7 +273,7 @@ def define_frame_index(self, cap_list): i["num_frames"] = math.ceil(fps * duration) # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. if i["num_frames"] / fps > self.video_length_tolerance_range * ( - self.num_frames / self.train_fps * self.speed_factor + self.num_frames / self.train_fps * self.speed_factor ): # too long video is not suitable for this training stage (self.num_frames) cnt_too_long += 1 continue @@ -279,21 +281,19 @@ def define_frame_index(self, cap_list): # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) frame_interval = fps / self.train_fps start_frame_idx = 0 - frame_indices = np.arange( - start_frame_idx, i["num_frames"], frame_interval - ).astype(int) + frame_indices = np.arange(start_frame_idx, i["num_frames"], + frame_interval).astype(int) # comment out it to enable dynamic frames training - if ( - len(frame_indices) < self.num_frames - and random.random() < self.drop_short_ratio - ): + if (len(frame_indices) < self.num_frames + and random.random() < self.drop_short_ratio): cnt_too_short += 1 continue # too long video will be temporal-crop randomly if len(frame_indices) > self.num_frames: - begin_index, end_index = self.temporal_sample(len(frame_indices)) + begin_index, end_index = self.temporal_sample( + len(frame_indices)) frame_indices = frame_indices[begin_index:end_index] # frame_indices = frame_indices[:self.num_frames] # head crop i["sample_frame_index"] = frame_indices.tolist() @@ -309,7 +309,7 @@ def define_frame_index(self, cap_list): sample_num_frames.append(i["sample_num_frames"]) else: raise NameError( - f"Unknown file extention {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" + f"Unknown file extension {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" ) # import ipdb;ipdb.set_trace() main_print( @@ -324,14 +324,16 @@ def decord_read(self, path, frame_indices): decord_vr = self.v_decoder(path) video_data = decord_vr.get_batch(frame_indices).asnumpy() video_data = torch.from_numpy(video_data) - video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + video_data = video_data.permute(0, 3, 1, + 2) # (T, H, W, C) -> (T C H W) return video_data def read_jsons(self, 
data): cap_lists = [] with open(data, "r") as f: folder_anno = [ - i.strip().split(",") for i in f.readlines() if len(i.strip()) > 0 + i.strip().split(",") for i in f.readlines() + if len(i.strip()) > 0 ] print(folder_anno) for folder, anno in folder_anno: diff --git a/fastvideo/dataset/transform.py b/fastvideo/dataset/transform.py index 4188344..319649c 100644 --- a/fastvideo/dataset/transform.py +++ b/fastvideo/dataset/transform.py @@ -1,7 +1,8 @@ -import torch -import random import numbers -from torchvision.transforms import RandomCrop, RandomResizedCrop +import random + +import torch +from PIL import Image def _is_tensor_video_clip(clip): @@ -20,21 +21,19 @@ def center_crop_arr(pil_image, image_size): https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: - pil_image = pil_image.resize( - tuple(x // 2 for x in pil_image.size), resample=Image.BOX - ) + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), + resample=Image.BOX) scale = image_size / min(*pil_image.size) - pil_image = pil_image.resize( - tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC - ) + pil_image = pil_image.resize(tuple( + round(x * scale) for x in pil_image.size), + resample=Image.BICUBIC) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 - return Image.fromarray( - arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] - ) + return Image.fromarray(arr[crop_y:crop_y + image_size, + crop_x:crop_x + image_size]) def crop(clip, i, j, h, w): @@ -44,7 +43,7 @@ def crop(clip, i, j, h, w): """ if len(clip.size()) != 4: raise ValueError("clip should be a 4D tensor") - return clip[..., i : i + h, j : j + w] + return clip[..., i:i + h, j:j + w] def resize(clip, target_size, interpolation_mode): @@ -153,16 +152,14 @@ def random_shift_crop(clip): h, w = clip.size(-2), clip.size(-1) if h <= w: - long_edge = w short_edge = h else: - long_edge = h short_edge = w th, tw = short_edge, short_edge - i = torch.randint(0, h - th + 1, size=(1,)).item() - j = torch.randint(0, w - tw + 1, size=(1,)).item() + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() return crop(clip, i, j, th, tw) @@ -177,9 +174,8 @@ def normalize_video(clip): """ _is_tensor_video_clip(clip) if not clip.dtype == torch.uint8: - raise TypeError( - "clip tensor should have data type uint8. Got %s" % str(clip.dtype) - ) + raise TypeError("clip tensor should have data type uint8. 
Got %s" % + str(clip.dtype)) # return clip.float().permute(3, 0, 1, 2) / 255.0 return clip.float() / 255.0 @@ -217,6 +213,7 @@ def hflip(clip): class RandomCropVideo: + def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) @@ -246,8 +243,8 @@ def get_params(self, clip): if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1,)).item() - j = torch.randint(0, w - tw + 1, size=(1,)).item() + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() return i, j, th, tw @@ -256,6 +253,7 @@ def __repr__(self) -> str: class SpatialStrideCropVideo: + def __init__(self, stride): self.stride = stride @@ -288,7 +286,10 @@ class LongSideResizeVideo: """ def __init__( - self, size, skip_low_resolution=False, interpolation_mode="bilinear", + self, + size, + skip_low_resolution=False, + interpolation_mode="bilinear", ): self.size = size self.skip_low_resolution = skip_low_resolution @@ -311,9 +312,9 @@ def __call__(self, clip): else: h = int(h * self.size / w) w = self.size - resize_clip = resize( - clip, target_size=(h, w), interpolation_mode=self.interpolation_mode - ) + resize_clip = resize(clip, + target_size=(h, w), + interpolation_mode=self.interpolation_mode) return resize_clip def __repr__(self) -> str: @@ -327,12 +328,14 @@ class CenterCropResizeVideo: """ def __init__( - self, size, top_crop=False, interpolation_mode="bilinear", + self, + size, + top_crop=False, + interpolation_mode="bilinear", ): if len(size) != 2: raise ValueError( - f"size should be tuple (height, width), instead got {size}" - ) + f"size should be tuple (height, width), instead got {size}") self.size = size self.top_crop = top_crop self.interpolation_mode = interpolation_mode @@ -346,9 +349,10 @@ def __call__(self, clip): size is (T, C, crop_size, crop_size) """ # clip_center_crop = center_crop_using_short_edge(clip) - clip_center_crop = center_crop_th_tw( - clip, self.size[0], self.size[1], top_crop=self.top_crop - ) + clip_center_crop = center_crop_th_tw(clip, + self.size[0], + self.size[1], + top_crop=self.top_crop) # import ipdb;ipdb.set_trace() clip_center_crop_resize = resize( clip_center_crop, @@ -368,7 +372,9 @@ class UCFCenterCropVideo: """ def __init__( - self, size, interpolation_mode="bilinear", + self, + size, + interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: @@ -389,9 +395,9 @@ def __call__(self, clip): torch.tensor: scale resized / center cropped video clip. 
size is (T, C, crop_size, crop_size) """ - clip_resize = resize_scale( - clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode - ) + clip_resize = resize_scale(clip=clip, + target_size=self.size, + interpolation_mode=self.interpolation_mode) clip_center_crop = center_crop(clip_resize, self.size) return clip_center_crop @@ -405,7 +411,9 @@ class KineticsRandomCropResizeVideo: """ def __init__( - self, size, interpolation_mode="bilinear", + self, + size, + interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: @@ -420,13 +428,17 @@ def __init__( def __call__(self, clip): clip_random_crop = random_shift_crop(clip) - clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) + clip_resize = resize(clip_random_crop, self.size, + self.interpolation_mode) return clip_resize class CenterCropVideo: + def __init__( - self, size, interpolation_mode="bilinear", + self, + size, + interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: @@ -559,9 +571,8 @@ def __init__(self, t_stride, extra_1): def __call__(self, t, h, w): if self.extra_1: t = t - 1 - truncate_t_list = list(range(t + 1))[t // 2 :][ - :: self.t_stride - ] # need half at least + truncate_t_list = list( + range(t + 1))[t // 2:][::self.t_stride] # need half at least truncate_t = random.choice(truncate_t_list) if self.extra_1: truncate_t = truncate_t + 1 @@ -569,27 +580,26 @@ def __call__(self, t, h, w): if __name__ == "__main__": - from torchvision import transforms - import torchvision.io as io + import os + import numpy as np + import torchvision.io as io + from torchvision import transforms from torchvision.utils import save_image - import os - vframes, aframes, info = io.read_video( - filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW" - ) + vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", + pts_unit="sec", + output_format="TCHW") - trans = transforms.Compose( - [ - Normalize255(), - RandomHorizontalFlipVideo(), - UCFCenterCropVideo(512), - # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), - transforms.Normalize( - mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True - ), - ] - ) + trans = transforms.Compose([ + Normalize255(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + inplace=True), + ]) target_video_len = 32 frame_interval = 1 @@ -603,9 +613,10 @@ def __call__(self, t, h, w): # print(start_frame_ind) # print(end_frame_ind) assert end_frame_ind - start_frame_ind >= target_video_len - frame_indice = np.linspace( - start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int - ) + frame_indice = np.linspace(start_frame_ind, + end_frame_ind - 1, + target_video_len, + dtype=int) print(frame_indice) select_vframes = vframes[frame_indice] @@ -616,13 +627,14 @@ def __call__(self, t, h, w): print(select_vframes_trans.shape) print(select_vframes_trans.dtype) - select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to( - dtype=torch.uint8 - ) + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * + 255).to(dtype=torch.uint8) print(select_vframes_trans_int.dtype) print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) - io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) + io.write_video("./test.avi", + select_vframes_trans_int.permute(0, 2, 3, 1), + fps=8) for i 
in range(target_video_len): save_image( diff --git a/fastvideo/distill.py b/fastvideo/distill.py index ce5fb97..89f02e1 100644 --- a/fastvideo/distill.py +++ b/fastvideo/distill.py @@ -1,53 +1,47 @@ +# !/bin/python3 +# isort: skip_file import argparse import math import os -from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - destroy_sequence_parallel_group, - get_sequence_parallel_state, - nccl_info, -) -from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast -from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input -from fastvideo.utils.validation import log_validation import time -from torch.utils.data import DataLoader +from collections import deque +from copy import deepcopy + import torch -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, -) -from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule -import json -from torch.utils.data.distributed import DistributedSampler -from fastvideo.utils.dataset_utils import LengthGroupedSampler +import torch.distributed as dist import wandb from accelerate.utils import set_seed -from tqdm.auto import tqdm -from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing from diffusers import FlowMatchEulerDiscreteScheduler -from fastvideo.utils.load import load_transformer -from fastvideo.distill.solver import EulerSolver, extract_into_tensor -from copy import deepcopy from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version -from fastvideo.dataset.latent_datasets import LatentDataset, latent_collate_function -import torch.distributed as dist -from safetensors.torch import save_file from peft import LoraConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from fastvideo.utils.checkpoint import ( - save_checkpoint, - save_lora_checkpoint, - resume_lora_optimizer, -) +from torch.distributed.fsdp import ShardingStrategy +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from tqdm.auto import tqdm + +from fastvideo.dataset.latent_datasets import (LatentDataset, + latent_collate_function) +from fastvideo.distill.solver import EulerSolver, extract_into_tensor +from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input +from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule +from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, + save_lora_checkpoint) +from fastvideo.utils.communications import (broadcast, + sp_parallel_dataloader_wrapper) +from fastvideo.utils.dataset_utils import LengthGroupedSampler +from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, + get_dit_fsdp_kwargs) +from fastvideo.utils.load import load_transformer +from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, + get_sequence_parallel_state, + initialize_sequence_parallel_state + ) +from fastvideo.utils.validation import log_validation # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
check_min_version("0.31.0") -import time -from collections import deque def main_print(content): @@ -55,31 +49,6 @@ def main_print(content): print(content) -def save_checkpoint(transformer, rank, output_dir, step): - main_print(f"--> saving checkpoint at step {step}") - with FSDP.state_dict_type( - transformer, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - cpu_state = transformer.state_dict() - # todo move to get_state_dict - if rank <= 0: - save_dir = os.path.join(output_dir, f"checkpoint-{step}") - os.makedirs(save_dir, exist_ok=True) - # save using safetensors - weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors") - save_file(cpu_state, weight_path) - config_dict = dict(transformer.config) - if "dtype" in config_dict: - del config_dict["dtype"] # TODO - config_path = os.path.join(save_dir, "config.json") - # save dict as json - with open(config_path, "w") as f: - json.dump(config_dict, f, indent=4) - main_print(f"--> checkpoint saved at step {step}") - - def reshard_fsdp(model): for m in FSDP.fsdp_modules(model): if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD: @@ -88,18 +57,20 @@ def reshard_fsdp(model): def get_norm(model_pred, norms, gradient_accumulation_steps): fro_norm = ( - torch.linalg.matrix_norm(model_pred, ord="fro") / gradient_accumulation_steps - ) - largest_singular_value = ( - torch.linalg.matrix_norm(model_pred, ord=2) / gradient_accumulation_steps - ) - absolute_mean = torch.mean(torch.abs(model_pred)) / gradient_accumulation_steps - absolute_max = torch.max(torch.abs(model_pred)) / gradient_accumulation_steps + torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore + gradient_accumulation_steps) + largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / + gradient_accumulation_steps) + absolute_mean = torch.mean( + torch.abs(model_pred)) / gradient_accumulation_steps + absolute_max = torch.max( + torch.abs(model_pred)) / gradient_accumulation_steps dist.all_reduce(fro_norm, op=dist.ReduceOp.AVG) dist.all_reduce(largest_singular_value, op=dist.ReduceOp.AVG) dist.all_reduce(absolute_mean, op=dist.ReduceOp.AVG) - norms["fro"] += torch.mean(fro_norm).item() - norms["largest singular value"] += torch.mean(largest_singular_value).item() + norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore + norms["largest singular value"] += torch.mean( + largest_singular_value).item() norms["absolute mean"] += absolute_mean.item() norms["absolute max"] += absolute_max.item() @@ -132,7 +103,7 @@ def distill_one_step( total_loss = 0.0 optimizer.zero_grad() model_pred_norm = { - "fro": 0.0, + "fro": 0.0, # codespell:ignore "largest singular value": 0.0, "absolute mean": 0.0, "absolute max": 0.0, @@ -147,22 +118,23 @@ def distill_one_step( model_input = normalize_dit_input(model_type, latents) noise = torch.randn_like(model_input) bsz = model_input.shape[0] - index = torch.randint( - 0, num_euler_timesteps, (bsz,), device=model_input.device - ).long() + index = torch.randint(0, + num_euler_timesteps, (bsz, ), + device=model_input.device).long() if sp_size > 1: broadcast(index) # Add noise according to flow matching. 
# sigmas = get_sigmas(start_timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) sigmas = extract_into_tensor(solver.sigmas, index, model_input.shape) - sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, model_input.shape) + sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, + model_input.shape) - timesteps = (sigmas * noise_scheduler.config.num_train_timesteps).view(-1) + timesteps = (sigmas * + noise_scheduler.config.num_train_timesteps).view(-1) # if squeeze to [], unsqueeze to [1] - timesteps_prev = ( - sigmas_prev * noise_scheduler.config.num_train_timesteps - ).view(-1) + timesteps_prev = (sigmas_prev * + noise_scheduler.config.num_train_timesteps).view(-1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input # Predict the noise residual with torch.autocast("cuda", dtype=torch.bfloat16): @@ -175,14 +147,14 @@ def distill_one_step( } if hunyuan_teacher_disable_cfg: teacher_kwargs["guidance"] = torch.tensor( - [1000.0], device=noisy_model_input.device, dtype=torch.bfloat16 - ) + [1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) model_pred = transformer(**teacher_kwargs)[0] # if accelerator.is_main_process: model_pred, end_index = solver.euler_style_multiphase_pred( - noisy_model_input, model_pred, index, multiphase - ) + noisy_model_input, model_pred, index, multiphase) with torch.no_grad(): w = distill_cfg with torch.autocast("cuda", dtype=torch.bfloat16): @@ -205,10 +177,10 @@ def distill_one_step( uncond_prompt_mask.unsqueeze(0).expand(bsz, -1), return_dict=False, )[0].float() - teacher_output = cond_teacher_output + w * ( - cond_teacher_output - uncond_teacher_output - ) - x_prev = solver.euler_step(noisy_model_input, teacher_output, index) + teacher_output = cond_teacher_output + w * (cond_teacher_output - + uncond_teacher_output) + x_prev = solver.euler_step(noisy_model_input, teacher_output, + index) # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n with torch.no_grad(): @@ -231,41 +203,32 @@ def distill_one_step( )[0] target, end_index = solver.euler_style_multiphase_pred( - x_prev, target_pred, index, multiphase, True - ) + x_prev, target_pred, index, multiphase, True) huber_c = 0.001 # loss = loss.mean() - loss = ( - torch.mean( - torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c ** 2) - - huber_c - ) - / gradient_accumulation_steps - ) + loss = (torch.mean( + torch.sqrt((model_pred.float() - target.float())**2 + huber_c**2) - + huber_c) / gradient_accumulation_steps) if pred_decay_weight > 0: if pred_decay_type == "l1": pred_decay_loss = ( - torch.mean(torch.sqrt(model_pred.float() ** 2)) - * pred_decay_weight - / gradient_accumulation_steps - ) + torch.mean(torch.sqrt(model_pred.float()**2)) * + pred_decay_weight / gradient_accumulation_steps) loss += pred_decay_loss elif pred_decay_type == "l2": # essnetially k2? 
- pred_decay_loss = ( - torch.mean(model_pred.float() ** 2) - * pred_decay_weight - / gradient_accumulation_steps - ) + pred_decay_loss = (torch.mean(model_pred.float()**2) * + pred_decay_weight / + gradient_accumulation_steps) loss += pred_decay_loss else: - assert NotImplementedError("pred_decay_type is not implemented") + assert NotImplementedError( + "pred_decay_type is not implemented") # calculate model_pred norm and mean - get_norm( - model_pred.detach().float(), model_pred_norm, gradient_accumulation_steps - ) + get_norm(model_pred.detach().float(), model_pred_norm, + gradient_accumulation_steps) loss.backward() avg_loss = loss.detach().clone() @@ -275,13 +238,12 @@ def distill_one_step( # update ema if ema_transformer is not None: reshard_fsdp(ema_transformer) - for p_averaged, p_model in zip( - ema_transformer.parameters(), transformer.parameters() - ): + for p_averaged, p_model in zip(ema_transformer.parameters(), + transformer.parameters()): with torch.no_grad(): p_averaged.copy_( - torch.lerp(p_averaged.detach(), p_model.detach(), 1 - ema_decay) - ) + torch.lerp(p_averaged.detach(), p_model.detach(), + 1 - ema_decay)) grad_norm = transformer.clip_grad_norm_(max_grad_norm) optimizer.step() @@ -312,7 +274,7 @@ def main(args): if rank <= 0 and args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # For mixed precision training we cast all non-trainable weigths to half-precision + # For mixed precision training we cast all non-trainable weights to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. # Create model: @@ -360,27 +322,36 @@ def main(args): if args.use_lora: transformer.config.lora_rank = args.lora_rank transformer.config.lora_alpha = args.lora_alpha - transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + transformer.config.lora_target_modules = [ + "to_k", "to_q", "to_v", "to_out.0" + ] transformer._no_split_modules = no_split_modules - fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer) + fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"]( + transformer) - transformer = FSDP(transformer, **fsdp_kwargs,) - teacher_transformer = FSDP(teacher_transformer, **fsdp_kwargs,) + transformer = FSDP( + transformer, + **fsdp_kwargs, + ) + teacher_transformer = FSDP( + teacher_transformer, + **fsdp_kwargs, + ) if args.use_ema: - ema_transformer = FSDP(ema_transformer, **fsdp_kwargs,) - main_print(f"--> model loaded") + ema_transformer = FSDP( + ema_transformer, + **fsdp_kwargs, + ) + main_print("--> model loaded") if args.gradient_checkpointing: - apply_fsdp_checkpointing( - transformer, no_split_modules, args.selective_checkpointing - ) - apply_fsdp_checkpointing( - teacher_transformer, no_split_modules, args.selective_checkpointing - ) + apply_fsdp_checkpointing(transformer, no_split_modules, + args.selective_checkpointing) + apply_fsdp_checkpointing(teacher_transformer, no_split_modules, + args.selective_checkpointing) if args.use_ema: - apply_fsdp_checkpointing( - ema_transformer, no_split_modules, args.selective_checkpointing - ) + apply_fsdp_checkpointing(ema_transformer, no_split_modules, + args.selective_checkpointing) # Set model as trainable. 
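The EMA branch reformatted above moves each averaged parameter toward the live parameter with weight (1 - ema_decay) via torch.lerp. A toy sketch of the same update on plain nn.Linear modules; the modules and the decay value are illustrative, not taken from the patch:

import torch
import torch.nn as nn

ema_decay = 0.95  # illustrative value

model = nn.Linear(4, 4)
ema_model = nn.Linear(4, 4)
ema_model.load_state_dict(model.state_dict())

# ... after optimizer.step() has updated `model` ...

with torch.no_grad():
    for p_averaged, p_model in zip(ema_model.parameters(),
                                   model.parameters()):
        # torch.lerp(a, b, w) = a + w * (b - a), so this keeps ema_decay of
        # the running average and mixes in (1 - ema_decay) of the new weights.
        p_averaged.copy_(
            torch.lerp(p_averaged.detach(), p_model.detach(),
                       1 - ema_decay))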
transformer.train() teacher_transformer.requires_grad_(False) @@ -388,9 +359,8 @@ def main(args): ema_transformer.requires_grad_(False) noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=args.shift) if args.scheduler_type == "pcm_linear_quadratic": - linear_steps = int( - noise_scheduler.config.num_train_timesteps * args.linear_range - ) + linear_steps = int(noise_scheduler.config.num_train_timesteps * + args.linear_range) sigmas = linear_quadratic_schedule( noise_scheduler.config.num_train_timesteps, args.linear_quadratic_threshold, @@ -406,7 +376,8 @@ def main(args): ) solver.to(device) params_to_optimize = transformer.parameters() - params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize)) + params_to_optimize = list( + filter(lambda p: p.requires_grad, params_to_optimize)) optimizer = torch.optim.AdamW( params_to_optimize, @@ -419,8 +390,7 @@ def main(args): init_steps = 0 if args.resume_from_lora_checkpoint: transformer, optimizer, init_steps = resume_lora_optimizer( - transformer, args.resume_from_lora_checkpoint, optimizer - ) + transformer, args.resume_from_lora_checkpoint, optimizer) main_print(f"optimizer: {optimizer}") # todo add lr scheduler @@ -434,23 +404,19 @@ def main(args): last_epoch=init_steps - 1, ) - train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg) + train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, + args.cfg) uncond_prompt_embed = train_dataset.uncond_prompt_embed uncond_prompt_mask = train_dataset.uncond_prompt_mask - sampler = ( - LengthGroupedSampler( - args.train_batch_size, - rank=rank, - world_size=world_size, - lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else DistributedSampler( - train_dataset, rank=rank, num_replicas=world_size, shuffle=False - ) - ) + sampler = (LengthGroupedSampler( + args.train_batch_size, + rank=rank, + world_size=world_size, + lengths=train_dataset.lengths, + group_frame=args.group_frame, + group_resolution=args.group_resolution, + ) if (args.group_frame or args.group_resolution) else DistributedSampler( + train_dataset, rank=rank, num_replicas=world_size, shuffle=False)) train_dataloader = DataLoader( train_dataset, @@ -463,45 +429,43 @@ def main(args): ) num_update_steps_per_epoch = math.ceil( - len(train_dataloader) - / args.gradient_accumulation_steps - * args.sp_size - / args.train_sp_batch_size - ) - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + len(train_dataloader) / args.gradient_accumulation_steps * + args.sp_size / args.train_sp_batch_size) + args.num_train_epochs = math.ceil(args.max_train_steps / + num_update_steps_per_epoch) if rank <= 0: project = args.tracker_project_name or "fastvideo" wandb.init(project=project, config=args) # Train! 
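As a quick sanity check on the bookkeeping above and the total_batch_size computation in the next hunk, here is the same arithmetic spelled out with illustrative numbers (none of the values below come from the patch):

import math

# Illustrative run configuration, not taken from the patch.
len_train_dataloader = 1000       # len(train_dataloader) on one rank
gradient_accumulation_steps = 1
sp_size = 4                       # sequence-parallel degree
train_sp_batch_size = 1
train_batch_size = 1              # per device
world_size = 8
max_train_steps = 2000

# Mirrors the computation in main(): optimizer steps per epoch scale with
# the sequence-parallel degree and shrink with accumulation steps.
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader / gradient_accumulation_steps * sp_size /
    train_sp_batch_size)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# Effective samples per optimizer step; dividing by sp_size removes the
# replication across ranks that only hold a shard of the same sequence.
total_batch_size = (train_batch_size * world_size *
                    gradient_accumulation_steps / sp_size *
                    train_sp_batch_size)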
- total_batch_size = ( - args.train_batch_size - * world_size - * args.gradient_accumulation_steps - / args.sp_size - * args.train_sp_batch_size - ) + total_batch_size = (args.train_batch_size * world_size * + args.gradient_accumulation_steps / args.sp_size * + args.train_sp_batch_size) main_print("***** Running training *****") main_print(f" Num examples = {len(train_dataset)}") main_print(f" Dataloader size = {len(train_dataloader)}") main_print(f" Num Epochs = {args.num_train_epochs}") main_print(f" Resume training from step {init_steps}") - main_print(f" Instantaneous batch size per device = {args.train_batch_size}") + main_print( + f" Instantaneous batch size per device = {args.train_batch_size}") main_print( f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" ) - main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + main_print( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") main_print(f" Total optimization steps = {args.max_train_steps}") main_print( f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B" ) # print dtype - main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}") + main_print( + f" Master weight dtype: {transformer.parameters().__next__().dtype}") # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: - assert NotImplementedError("resume_from_checkpoint is not supported now.") + assert NotImplementedError( + "resume_from_checkpoint is not supported now.") # TODO progress_bar = tqdm( @@ -573,40 +537,47 @@ def get_num_phases(multi_phased_distill_schedule, step): step_times.append(step_time) avg_step_time = sum(step_times) / len(step_times) - progress_bar.set_postfix( - { - "loss": f"{loss:.4f}", - "step_time": f"{step_time:.2f}s", - "grad_norm": grad_norm, - "phases": num_phases, - } - ) + progress_bar.set_postfix({ + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + "phases": num_phases, + }) progress_bar.update(1) if rank <= 0: wandb.log( { - "train_loss": loss, - "learning_rate": lr_scheduler.get_last_lr()[0], - "step_time": step_time, - "avg_step_time": avg_step_time, - "grad_norm": grad_norm, - "pred_fro_norm": pred_norm["fro"], - "pred_largest_singular_value": pred_norm["largest singular value"], - "pred_absolute_mean": pred_norm["absolute mean"], - "pred_absolute_max": pred_norm["absolute max"], + "train_loss": + loss, + "learning_rate": + lr_scheduler.get_last_lr()[0], + "step_time": + step_time, + "avg_step_time": + avg_step_time, + "grad_norm": + grad_norm, + "pred_fro_norm": + pred_norm["fro"], # codespell:ignore + "pred_largest_singular_value": + pred_norm["largest singular value"], + "pred_absolute_mean": + pred_norm["absolute mean"], + "pred_absolute_max": + pred_norm["absolute max"], }, step=step, ) if step % args.checkpointing_steps == 0: if args.use_lora: # Save LoRA weights - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step - ) + save_lora_checkpoint(transformer, optimizer, rank, + args.output_dir, step) else: # Your existing checkpoint saving code if args.use_ema: - save_checkpoint(ema_transformer, rank, args.output_dir, step) + save_checkpoint(ema_transformer, rank, args.output_dir, + step) else: save_checkpoint(transformer, rank, args.output_dir, step) dist.barrier() @@ -640,11 +611,11 @@ def get_num_phases(multi_phased_distill_schedule, step): ) if 
args.use_lora: - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps - ) + save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, + args.max_train_steps) else: - save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps) + save_checkpoint(transformer, rank, args.output_dir, + args.max_train_steps) if get_sequence_parallel_state(): destroy_sequence_parallel_group() @@ -653,9 +624,10 @@ def get_num_phases(multi_phased_distill_schedule, step): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--model_type", type=str, default="mochi", help="The type of model to train." - ) + parser.add_argument("--model_type", + type=str, + default="mochi", + help="The type of model to train.") # dataset & dataloader parser.add_argument("--data_json_path", type=str, required=True) @@ -664,7 +636,8 @@ def get_num_phases(multi_phased_distill_schedule, step): "--dataloader_num_workers", type=int, default=10, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + help= + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", ) parser.add_argument( "--train_batch_size", @@ -672,9 +645,10 @@ def get_num_phases(multi_phased_distill_schedule, step): default=16, help="Batch size (per device) for the training dataloader.", ) - parser.add_argument( - "--num_latent_t", type=int, default=28, help="Number of latent timesteps." - ) + parser.add_argument("--num_latent_t", + type=int, + default=28, + help="Number of latent timesteps.") parser.add_argument("--group_frame", action="store_true") # TODO parser.add_argument("--group_resolution", action="store_true") # TODO @@ -696,14 +670,16 @@ def get_num_phases(multi_phased_distill_schedule, step): parser.add_argument("--validation_steps", type=float, default=64) parser.add_argument("--log_validation", action="store_true") parser.add_argument("--tracker_project_name", type=str, default=None) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", + type=int, + default=None, + help="A seed for reproducible training.") parser.add_argument( "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--checkpoints_total_limit", @@ -715,39 +691,37 @@ def get_num_phases(multi_phased_distill_schedule, step): "--checkpointing_steps", type=int, default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), + help= + ("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`."), ) parser.add_argument("--shift", type=float, default=1.0) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
- ), + help= + ("Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), ) parser.add_argument( "--resume_from_lora_checkpoint", type=str, default=None, - help=( - "Whether training should be resumed from a previous lora checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), + help= + ("Whether training should be resumed from a previous lora checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), + help= + ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), ) # optimizer & scheduler & Training @@ -756,25 +730,29 @@ def get_num_phases(multi_phased_distill_schedule, step): "--max_train_steps", type=int, default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + help= + "Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", + help= + "Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", + help= + "Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + help= + "Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_warmup_steps", @@ -782,41 +760,47 @@ def get_num_phases(multi_phased_distill_schedule, step): default=10, help="Number of steps for the warmup in the lr scheduler.", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) + parser.add_argument("--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm.") parser.add_argument( "--gradient_checkpointing", action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + help= + "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument("--selective_checkpointing", type=float, default=1.0) parser.add_argument( "--allow_tf32", action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), + help= + ("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), + help= + ("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument( "--use_cpu_offload", action="store_true", - help="Whether to use CPU offload for param & gradient & optimizer states.", + help= + "Whether to use CPU offload for param & gradient & optimizer states.", ) - parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") + parser.add_argument("--sp_size", + type=int, + default=1, + help="For sequence parallel") parser.add_argument( "--train_sp_batch_size", type=int, @@ -830,12 +814,14 @@ def get_num_phases(multi_phased_distill_schedule, step): default=False, help="Whether to use LoRA for finetuning.", ) - parser.add_argument( - "--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA." - ) - parser.add_argument( - "--lora_rank", type=int, default=128, help="LoRA rank parameter. " - ) + parser.add_argument("--lora_alpha", + type=int, + default=256, + help="Alpha parameter for LoRA.") + parser.add_argument("--lora_rank", + type=int, + default=128, + help="LoRA rank parameter. ") parser.add_argument("--fsdp_sharding_startegy", default="full") # lr_scheduler @@ -843,10 +829,9 @@ def get_num_phases(multi_phased_distill_schedule, step): "--lr_scheduler", type=str, default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), + help= + ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), ) parser.add_argument("--num_euler_timesteps", type=int, default=100) parser.add_argument( @@ -866,13 +851,15 @@ def get_num_phases(multi_phased_distill_schedule, step): action="store_true", help="Whether to apply the cfg_solver.", ) - parser.add_argument( - "--distill_cfg", type=float, default=3.0, help="Distillation coefficient." - ) + parser.add_argument("--distill_cfg", + type=float, + default=3.0, + help="Distillation coefficient.") # ["euler_linear_quadratic", "pcm", "pcm_linear_qudratic"] - parser.add_argument( - "--scheduler_type", type=str, default="pcm", help="The scheduler type to use." - ) + parser.add_argument("--scheduler_type", + type=str, + default="pcm", + help="The scheduler type to use.") parser.add_argument( "--linear_quadratic_threshold", type=float, @@ -885,11 +872,16 @@ def get_num_phases(multi_phased_distill_schedule, step): default=0.5, help="Range for linear quadratic scheduler.", ) - parser.add_argument( - "--weight_decay", type=float, default=0.001, help="Weight decay to apply." 
- ) - parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA.") - parser.add_argument("--multi_phased_distill_schedule", type=str, default=None) + parser.add_argument("--weight_decay", + type=float, + default=0.001, + help="Weight decay to apply.") + parser.add_argument("--use_ema", + action="store_true", + help="Whether to use EMA.") + parser.add_argument("--multi_phased_distill_schedule", + type=str, + default=None) parser.add_argument("--pred_decay_weight", type=float, default=0.0) parser.add_argument("--pred_decay_type", default="l1") parser.add_argument("--hunyuan_teacher_disable_cfg", action="store_true") diff --git a/fastvideo/distill/discriminator.py b/fastvideo/distill/discriminator.py index 58e5812..2ea7264 100644 --- a/fastvideo/distill/discriminator.py +++ b/fastvideo/distill/discriminator.py @@ -1,29 +1,11 @@ -from typing import Any, Dict, Optional, Union - -import torch import torch.nn as nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin -from diffusers.models.attention import JointTransformerBlock -from diffusers.models.attention_processor import Attention, AttentionProcessor -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import AdaLayerNormContinuous -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_version, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed -from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput -from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel +from diffusers.utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name class DiscriminatorHead(nn.Module): + def __init__(self, input_channel, output_channel=1): super().__init__() inner_channel = 1024 @@ -57,30 +39,31 @@ def forward(self, x): class Discriminator(nn.Module): + def __init__( - self, stride=8, num_h_per_head=1, adapter_channel_dims=[3072], total_layers=48, + self, + stride=8, + num_h_per_head=1, + adapter_channel_dims=[3072], + total_layers=48, ): super().__init__() adapter_channel_dims = adapter_channel_dims * (total_layers // stride) self.stride = stride self.num_h_per_head = num_h_per_head self.head_num = len(adapter_channel_dims) - self.heads = nn.ModuleList( - [ - nn.ModuleList( - [ - DiscriminatorHead(adapter_channel) - for _ in range(self.num_h_per_head) - ] - ) - for adapter_channel in adapter_channel_dims - ] - ) + self.heads = nn.ModuleList([ + nn.ModuleList([ + DiscriminatorHead(adapter_channel) + for _ in range(self.num_h_per_head) + ]) for adapter_channel in adapter_channel_dims + ]) def forward(self, features): outputs = [] def create_custom_forward(module): + def custom_forward(*inputs): return module(*inputs) diff --git a/fastvideo/distill/solver.py b/fastvideo/distill/solver.py index 7ede6fe..1b960c9 100644 --- a/fastvideo/distill/solver.py +++ b/fastvideo/distill/solver.py @@ -3,11 +3,10 @@ import numpy as np import torch - from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput, logging -from diffusers.utils.torch_utils import randn_tensor from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, logging + from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name @@ -21,7 +20,7 @@ class PCMFMSchedulerOutput(BaseOutput): def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) + return out.reshape(b, *((1, ) * (len(x_shape) - 1))) class PCMFMScheduler(SchedulerMixin, ConfigMixin): @@ -40,26 +39,28 @@ def __init__( ): if linear_quadratic: linear_steps = int(num_train_timesteps * linear_range) - sigmas = linear_quadratic_schedule( - num_train_timesteps, linear_quadratic_threshold, linear_steps - ) + sigmas = linear_quadratic_schedule(num_train_timesteps, + linear_quadratic_threshold, + linear_steps) sigmas = torch.tensor(sigmas).to(dtype=torch.float32) else: - timesteps = np.linspace( - 1, num_train_timesteps, num_train_timesteps, dtype=np.float32 - )[::-1].copy() + timesteps = np.linspace(1, + num_train_timesteps, + num_train_timesteps, + dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - self.euler_timesteps = ( - np.arange(1, pcm_timesteps + 1) * (num_train_timesteps // pcm_timesteps) - ).round().astype(np.int64) - 1 + self.euler_timesteps = (np.arange(1, pcm_timesteps + 1) * + (num_train_timesteps // + pcm_timesteps)).round().astype(np.int64) - 1 self.sigmas = sigmas.numpy()[::-1][self.euler_timesteps] self.sigmas = torch.from_numpy((self.sigmas[::-1].copy())) self.timesteps = self.sigmas * num_train_timesteps self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() @@ -118,9 +119,9 @@ def scale_noise( def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps( - self, num_inference_steps: int, device: Union[str, torch.device] = None - ): + def set_timesteps(self, + num_inference_steps: int, + device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -131,9 +132,10 @@ def set_timesteps( The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps - inference_indices = np.linspace( - 0, self.config.pcm_timesteps, num=num_inference_steps, endpoint=False - ) + inference_indices = np.linspace(0, + self.config.pcm_timesteps, + num=num_inference_steps, + endpoint=False) inference_indices = np.floor(inference_indices).astype(np.int64) inference_indices = torch.from_numpy(inference_indices).long() @@ -141,8 +143,8 @@ def set_timesteps( timesteps = self.sigmas_ * self.config.num_train_timesteps self.timesteps = timesteps.to(device=device) self.sigmas_ = torch.cat( - [self.sigmas_, torch.zeros(1, device=self.sigmas_.device)] - ) + [self.sigmas_, + torch.zeros(1, device=self.sigmas_.device)]) self._step_index = None self._begin_index = None @@ -204,18 +206,12 @@ def step( returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. 
Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) + if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor)): + raise ValueError(( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep."), ) if self.step_index is None: self._init_step_index(timestep) @@ -233,7 +229,7 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return (prev_sample, ) return PCMFMSchedulerOutput(prev_sample=prev_sample) @@ -242,19 +238,21 @@ def __len__(self): class EulerSolver: + def __init__(self, sigmas, timesteps=1000, euler_timesteps=50): self.step_ratio = timesteps // euler_timesteps - self.euler_timesteps = ( - np.arange(1, euler_timesteps + 1) * self.step_ratio - ).round().astype(np.int64) - 1 - self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist()) + self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * + self.step_ratio).round().astype(np.int64) - 1 + self.euler_timesteps_prev = np.asarray( + [0] + self.euler_timesteps[:-1].tolist()) self.sigmas = sigmas[self.euler_timesteps] self.sigmas_prev = np.asarray( [sigmas[0]] + sigmas[self.euler_timesteps[:-1]].tolist() ) # either use sigma0 or 0 self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long() - self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long() + self.euler_timesteps_prev = torch.from_numpy( + self.euler_timesteps_prev).long() self.sigmas = torch.from_numpy(self.sigmas) self.sigmas_prev = torch.from_numpy(self.sigmas_prev) @@ -267,38 +265,44 @@ def to(self, device): return self def euler_step(self, sample, model_pred, timestep_index): - sigma = extract_into_tensor(self.sigmas, timestep_index, model_pred.shape) - sigma_prev = extract_into_tensor( - self.sigmas_prev, timestep_index, model_pred.shape - ) + sigma = extract_into_tensor(self.sigmas, timestep_index, + model_pred.shape) + sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index, + model_pred.shape) x_prev = sample + (sigma_prev - sigma) * model_pred return x_prev def euler_style_multiphase_pred( - self, sample, model_pred, timestep_index, multiphase, is_target=False, + self, + sample, + model_pred, + timestep_index, + multiphase, + is_target=False, ): - inference_indices = np.linspace( - 0, len(self.euler_timesteps), num=multiphase, endpoint=False - ) + inference_indices = np.linspace(0, + len(self.euler_timesteps), + num=multiphase, + endpoint=False) inference_indices = np.floor(inference_indices).astype(np.int64) - inference_indices = ( - torch.from_numpy(inference_indices).long().to(self.euler_timesteps.device) - ) + inference_indices = (torch.from_numpy(inference_indices).long().to( + self.euler_timesteps.device)) expanded_timestep_index = timestep_index.unsqueeze(1).expand( - -1, inference_indices.size(0) - ) + -1, inference_indices.size(0)) valid_indices_mask = expanded_timestep_index >= inference_indices - last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1) + last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax( + dim=1) last_valid_index = inference_indices.size(0) - 1 - last_valid_index timestep_index_end = inference_indices[last_valid_index] if is_target: - sigma = extract_into_tensor(self.sigmas_prev, timestep_index, sample.shape) + sigma = extract_into_tensor(self.sigmas_prev, timestep_index, + sample.shape) 
else: - sigma = extract_into_tensor(self.sigmas, timestep_index, sample.shape) - sigma_prev = extract_into_tensor( - self.sigmas_prev, timestep_index_end, sample.shape - ) + sigma = extract_into_tensor(self.sigmas, timestep_index, + sample.shape) + sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index_end, + sample.shape) x_prev = sample + (sigma_prev - sigma) * model_pred return x_prev, timestep_index_end diff --git a/fastvideo/distill_adv.py b/fastvideo/distill_adv.py index e47b6f9..51a2048 100644 --- a/fastvideo/distill_adv.py +++ b/fastvideo/distill_adv.py @@ -1,67 +1,50 @@ +# !/bin/python3 +# isort: skip_file import argparse -from email.policy import strict -import logging import math import os -import shutil -from pathlib import Path -from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - destroy_sequence_parallel_group, - get_sequence_parallel_state, - nccl_info, -) -from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast -from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input -from fastvideo.utils.validation import log_validation import time -from torch.utils.data import DataLoader -import torch -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, -) -from fastvideo.utils.load import load_transformer +from collections import deque +from copy import deepcopy -from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule -import json -from torch.utils.data.distributed import DistributedSampler -from fastvideo.utils.dataset_utils import LengthGroupedSampler +import torch +import torch.distributed as dist import wandb from accelerate.utils import set_seed -from tqdm.auto import tqdm -from fastvideo.utils.fsdp_util import ( - get_dit_fsdp_kwargs, - apply_fsdp_checkpointing, - get_discriminator_fsdp_kwargs, -) -import diffusers from diffusers import FlowMatchEulerDiscreteScheduler -from fastvideo.distill.discriminator import Discriminator -from fastvideo.distill.solver import EulerSolver, extract_into_tensor -from copy import deepcopy from diffusers.optimization import get_scheduler -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel from diffusers.utils import check_min_version -from fastvideo.dataset.latent_datasets import LatentDataset, latent_collate_function -import torch.distributed as dist from peft import LoraConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from tqdm.auto import tqdm + +from fastvideo.dataset.latent_datasets import (LatentDataset, + latent_collate_function) +from fastvideo.distill.discriminator import Discriminator +from fastvideo.distill.solver import EulerSolver, extract_into_tensor +from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input +from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule from fastvideo.utils.checkpoint import ( - save_checkpoint, - save_lora_checkpoint, - resume_lora_optimizer, - resume_training, - save_checkpoint_generator_discriminator, - resume_training_generator_discriminator, -) + resume_lora_optimizer, resume_training_generator_discriminator, + save_checkpoint, save_lora_checkpoint) +from fastvideo.utils.communications import (broadcast, + sp_parallel_dataloader_wrapper) +from fastvideo.utils.dataset_utils import LengthGroupedSampler +from fastvideo.utils.fsdp_util 
import (apply_fsdp_checkpointing, + get_discriminator_fsdp_kwargs, + get_dit_fsdp_kwargs) +from fastvideo.utils.load import load_transformer from fastvideo.utils.logging_ import main_print +from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, + get_sequence_parallel_state, + initialize_sequence_parallel_state + ) +from fastvideo.utils.validation import log_validation # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.31.0") -import time -from collections import deque def gan_d_loss( @@ -100,10 +83,9 @@ def gan_d_loss( fake_outputs = discriminator(fake_features) real_outputs = discriminator(real_features) for fake_output, real_output in zip(fake_outputs, real_outputs): - loss += ( - torch.mean(weight * torch.relu(fake_output.float() + 1)) - + torch.mean(weight * torch.relu(1 - real_output.float())) - ) / (discriminator.head_num * discriminator.num_h_per_head) + loss += (torch.mean(weight * torch.relu(fake_output.float() + 1)) + + torch.mean(weight * torch.relu(1 - real_output.float()))) / ( + discriminator.head_num * discriminator.num_h_per_head) return loss @@ -127,11 +109,10 @@ def gan_g_loss( output_features_stride=discriminator_head_stride, return_dict=False, )[1] - fake_outputs = discriminator(features,) + fake_outputs = discriminator(features, ) for fake_output in fake_outputs: loss += torch.mean(weight * torch.relu(1 - fake_output.float())) / ( - discriminator.head_num * discriminator.num_h_per_head - ) + discriminator.head_num * discriminator.num_h_per_head) return loss @@ -170,20 +151,22 @@ def distill_one_step_adv( model_input = normalize_dit_input(model_type, latents) noise = torch.randn_like(model_input) bsz = model_input.shape[0] - index = torch.randint( - 0, num_euler_timesteps, (bsz,), device=model_input.device - ).long() + index = torch.randint(0, + num_euler_timesteps, (bsz, ), + device=model_input.device).long() if sp_size > 1: broadcast(index) # Add noise according to flow matching. 
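The two loss helpers reformatted above are hinge losses averaged over the discriminator heads. A compact sketch of the same objectives with dummy per-head logits standing in for the discriminator outputs (shapes and head counts are illustrative):

import torch

weight = 1.0  # same role as the `weight` argument in gan_d_loss / gan_g_loss
head_num, num_h_per_head = 6, 1

# Stand-ins for per-head discriminator logits on fake and real features.
fake_outputs = [torch.randn(2, 1) for _ in range(head_num)]
real_outputs = [torch.randn(2, 1) for _ in range(head_num)]

# Discriminator side: push real logits above +1 and fake logits below -1.
d_loss = 0.0
for fake_output, real_output in zip(fake_outputs, real_outputs):
    d_loss += (torch.mean(weight * torch.relu(fake_output.float() + 1)) +
               torch.mean(weight * torch.relu(1 - real_output.float()))) / (
                   head_num * num_h_per_head)

# Generator side: push the fake logits back up past +1.
g_loss = 0.0
for fake_output in fake_outputs:
    g_loss += torch.mean(weight * torch.relu(1 - fake_output.float())) / (
        head_num * num_h_per_head)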
# sigmas = get_sigmas(start_timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) sigmas = extract_into_tensor(solver.sigmas, index, model_input.shape) - sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, model_input.shape) + sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, + model_input.shape) timesteps = (sigmas * noise_scheduler.config.num_train_timesteps).view(-1) # if squeeze to [], unsqueeze to [1] - timesteps_prev = (sigmas_prev * noise_scheduler.config.num_train_timesteps).view(-1) + timesteps_prev = (sigmas_prev * + noise_scheduler.config.num_train_timesteps).view(-1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input # Predict the noise residual @@ -198,10 +181,8 @@ def distill_one_step_adv( # if accelerator.is_main_process: model_pred, end_index = solver.euler_style_multiphase_pred( - noisy_model_input, model_pred, index, multiphase - ) + noisy_model_input, model_pred, index, multiphase) - weighting = 1.0 # # simplified flow matching aka 0-rectified flow matching loss # # target = model_input - noise # target = model_input @@ -210,15 +191,17 @@ def distill_one_step_adv( adv_index[i] = torch.randint( end_index[i].item(), end_index[i].item() + num_euler_timesteps // multiphase, - (1,), + (1, ), dtype=end_index.dtype, device=end_index.device, ) - sigmas_end = extract_into_tensor(solver.sigmas_prev, end_index, model_input.shape) - sigmas_adv = extract_into_tensor(solver.sigmas_prev, adv_index, model_input.shape) - timesteps_end = (sigmas_end * noise_scheduler.config.num_train_timesteps).view(-1) - timesteps_adv = (sigmas_adv * noise_scheduler.config.num_train_timesteps).view(-1) + sigmas_end = extract_into_tensor(solver.sigmas_prev, end_index, + model_input.shape) + sigmas_adv = extract_into_tensor(solver.sigmas_prev, adv_index, + model_input.shape) + timesteps_adv = (sigmas_adv * + noise_scheduler.config.num_train_timesteps).view(-1) with torch.no_grad(): w = distill_cfg @@ -242,9 +225,8 @@ def distill_one_step_adv( uncond_prompt_mask.unsqueeze(0).expand(bsz, -1), return_dict=False, )[0].float() - teacher_output = cond_teacher_output + w * ( - cond_teacher_output - uncond_teacher_output - ) + teacher_output = cond_teacher_output + w * (cond_teacher_output - + uncond_teacher_output) x_prev = solver.euler_step(noisy_model_input, teacher_output, index) # 20.4.12. 
Get target LCM prediction on x_prev, w, c, t_n @@ -259,21 +241,19 @@ def distill_one_step_adv( )[0] target, end_index = solver.euler_style_multiphase_pred( - x_prev, target_pred, index, multiphase, True - ) + x_prev, target_pred, index, multiphase, True) - real_adv = ( - (1 - sigmas_adv) * target + (sigmas_adv - sigmas_end) * torch.randn_like(target) - ) / (1 - sigmas_end) - fake_adv = ( - (1 - sigmas_adv) * model_pred - + (sigmas_adv - sigmas_end) * torch.randn_like(model_pred) - ) / (1 - sigmas_end) + real_adv = ((1 - sigmas_adv) * target + + (sigmas_adv - sigmas_end) * torch.randn_like(target)) / ( + 1 - sigmas_end) + fake_adv = ((1 - sigmas_adv) * model_pred + + (sigmas_adv - sigmas_end) * torch.randn_like(model_pred)) / ( + 1 - sigmas_end) huber_c = 0.001 g_loss = torch.mean( - torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c ** 2) - huber_c - ) + torch.sqrt((model_pred.float() - target.float())**2 + huber_c**2) - + huber_c) discriminator.requires_grad_(False) with torch.autocast("cuda", dtype=torch.bfloat16): g_gan_loss = adv_weight * gan_g_loss( @@ -342,7 +322,7 @@ def main(args): if rank <= 0 and args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # For mixed precision training we cast all non-trainable weigths to half-precision + # For mixed precision training we cast all non-trainable weights to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. # Create model: @@ -388,35 +368,46 @@ def main(args): args.use_cpu_offload, args.master_weight_type, ) - discriminator_fsdp_kwargs = get_discriminator_fsdp_kwargs(args.master_weight_type) + discriminator_fsdp_kwargs = get_discriminator_fsdp_kwargs( + args.master_weight_type) if args.use_lora: assert args.model_type == "mochi", "LoRA is only supported for Mochi model." transformer.config.lora_rank = args.lora_rank transformer.config.lora_alpha = args.lora_alpha - transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + transformer.config.lora_target_modules = [ + "to_k", "to_q", "to_v", "to_out.0" + ] transformer._no_split_modules = no_split_modules - fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer) + fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"]( + transformer) - transformer = FSDP(transformer, **fsdp_kwargs,) - teacher_transformer = FSDP(teacher_transformer, **fsdp_kwargs,) - discriminator = FSDP(discriminator, **discriminator_fsdp_kwargs,) - main_print(f"--> model loaded") + transformer = FSDP( + transformer, + **fsdp_kwargs, + ) + teacher_transformer = FSDP( + teacher_transformer, + **fsdp_kwargs, + ) + discriminator = FSDP( + discriminator, + **discriminator_fsdp_kwargs, + ) + main_print("--> model loaded") if args.gradient_checkpointing: - apply_fsdp_checkpointing( - transformer, no_split_modules, args.selective_checkpointing - ) - apply_fsdp_checkpointing( - teacher_transformer, no_split_modules, args.selective_checkpointing - ) + apply_fsdp_checkpointing(transformer, no_split_modules, + args.selective_checkpointing) + apply_fsdp_checkpointing(teacher_transformer, no_split_modules, + args.selective_checkpointing) # Set model as trainable. 
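The hunks above combine a pseudo-Huber reconstruction term with adversarial inputs that are re-noised from sigma_end up to a randomized sigma_adv before the discriminator scores them. A minimal sketch of those two formulas with dummy tensors; shapes and sigma values are illustrative:

import torch

huber_c = 0.001
model_pred = torch.randn(2, 4, 7, 32, 32)   # student multiphase prediction
target = torch.randn_like(model_pred)       # teacher-derived target

# Pseudo-Huber reconstruction loss: roughly L2 near zero, L1 for large errors.
g_loss = torch.mean(
    torch.sqrt((model_pred.float() - target.float())**2 + huber_c**2) -
    huber_c)

# Re-noise both target and prediction to the same (randomized) noise level so
# the discriminator compares them at sigma_adv rather than at sigma_end.
sigmas_end = torch.full((2, 1, 1, 1, 1), 0.2)   # illustrative
sigmas_adv = torch.full((2, 1, 1, 1, 1), 0.5)   # illustrative, >= sigmas_end
real_adv = ((1 - sigmas_adv) * target +
            (sigmas_adv - sigmas_end) * torch.randn_like(target)) / (
                1 - sigmas_end)
fake_adv = ((1 - sigmas_adv) * model_pred +
            (sigmas_adv - sigmas_end) * torch.randn_like(model_pred)) / (
                1 - sigmas_end)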
transformer.train() teacher_transformer.requires_grad_(False) noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=args.shift) if args.scheduler_type == "pcm_linear_quadratic": sigmas = linear_quadratic_schedule( - noise_scheduler.config.num_train_timesteps, args.linear_quadratic_threshold - ) + noise_scheduler.config.num_train_timesteps, + args.linear_quadratic_threshold) sigmas = torch.tensor(sigmas).to(dtype=torch.float32) else: sigmas = noise_scheduler.sigmas @@ -427,7 +418,8 @@ def main(args): ) solver.to(device) params_to_optimize = transformer.parameters() - params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize)) + params_to_optimize = list( + filter(lambda p: p.requires_grad, params_to_optimize)) optimizer = torch.optim.AdamW( params_to_optimize, @@ -448,8 +440,7 @@ def main(args): init_steps = 0 if args.resume_from_lora_checkpoint: transformer, optimizer, init_steps = resume_lora_optimizer( - transformer, args.resume_from_lora_checkpoint, optimizer - ) + transformer, args.resume_from_lora_checkpoint, optimizer) elif args.resume_from_checkpoint: ( transformer, @@ -478,23 +469,19 @@ def main(args): last_epoch=init_steps - 1, ) - train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg) + train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, + args.cfg) uncond_prompt_embed = train_dataset.uncond_prompt_embed uncond_prompt_mask = train_dataset.uncond_prompt_mask - sampler = ( - LengthGroupedSampler( - args.train_batch_size, - rank=rank, - world_size=world_size, - lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else DistributedSampler( - train_dataset, rank=rank, num_replicas=world_size, shuffle=False - ) - ) + sampler = (LengthGroupedSampler( + args.train_batch_size, + rank=rank, + world_size=world_size, + lengths=train_dataset.lengths, + group_frame=args.group_frame, + group_resolution=args.group_resolution, + ) if (args.group_frame or args.group_resolution) else DistributedSampler( + train_dataset, rank=rank, num_replicas=world_size, shuffle=False)) train_dataloader = DataLoader( train_dataset, @@ -507,41 +494,38 @@ def main(args): ) assert args.gradient_accumulation_steps == 1 num_update_steps_per_epoch = math.ceil( - len(train_dataloader) - / args.gradient_accumulation_steps - * args.sp_size - / args.train_sp_batch_size - ) - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + len(train_dataloader) / args.gradient_accumulation_steps * + args.sp_size / args.train_sp_batch_size) + args.num_train_epochs = math.ceil(args.max_train_steps / + num_update_steps_per_epoch) if rank <= 0: project = args.tracker_project_name or "fastvideo" wandb.init(project=project, config=args) # Train! 
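The EulerSolver constructed above exposes the single-line Euler update shown in the solver.py hunk earlier in this patch. For reference, an isolated sketch of that step with illustrative tensors and sigma values:

import torch

def euler_step(sample, model_pred, sigma, sigma_prev):
    # One Euler step of the flow-matching ODE used by EulerSolver:
    # move from noise level `sigma` to `sigma_prev` along the predicted velocity.
    return sample + (sigma_prev - sigma) * model_pred

# Illustrative values only.
sample = torch.randn(2, 4, 7, 32, 32)
model_pred = torch.randn_like(sample)
sigma = torch.full((2, 1, 1, 1, 1), 0.8)
sigma_prev = torch.full((2, 1, 1, 1, 1), 0.6)
x_prev = euler_step(sample, model_pred, sigma, sigma_prev)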
- total_batch_size = ( - args.train_batch_size - * world_size - * args.gradient_accumulation_steps - / args.sp_size - * args.train_sp_batch_size - ) + total_batch_size = (args.train_batch_size * world_size * + args.gradient_accumulation_steps / args.sp_size * + args.train_sp_batch_size) main_print("***** Running training *****") main_print(f" Num examples = {len(train_dataset)}") main_print(f" Dataloader size = {len(train_dataloader)}") main_print(f" Num Epochs = {args.num_train_epochs}") main_print(f" Resume training from step {init_steps}") - main_print(f" Instantaneous batch size per device = {args.train_batch_size}") + main_print( + f" Instantaneous batch size per device = {args.train_batch_size}") main_print( f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" ) - main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + main_print( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") main_print(f" Total optimization steps = {args.max_train_steps}") main_print( f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B" ) # print dtype - main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}") + main_print( + f" Master weight dtype: {transformer.parameters().__next__().dtype}") progress_bar = tqdm( range(0, args.max_train_steps), @@ -559,6 +543,7 @@ def main(args): ) step_times = deque(maxlen=100) + # log_validation(args, transformer, device, # torch.bfloat16, 0, scheduler_type=args.scheduler_type, shift=args.shift, num_euler_timesteps=args.num_euler_timesteps, linear_quadratic_threshold=args.linear_quadratic_threshold,ema=False) def get_num_phases(multi_phased_distill_schedule, step): @@ -610,15 +595,13 @@ def get_num_phases(multi_phased_distill_schedule, step): step_times.append(step_time) avg_step_time = sum(step_times) / len(step_times) - progress_bar.set_postfix( - { - "g_loss": f"{generator_loss:.4f}", - "d_loss": f"{discriminator_loss:.4f}", - "g_grad_norm": generator_grad_norm, - "d_grad_norm": discriminator_grad_norm, - "step_time": f"{step_time:.2f}s", - } - ) + progress_bar.set_postfix({ + "g_loss": f"{generator_loss:.4f}", + "d_loss": f"{discriminator_loss:.4f}", + "g_grad_norm": generator_grad_norm, + "d_grad_norm": discriminator_grad_norm, + "step_time": f"{step_time:.2f}s", + }) progress_bar.update(1) if rank <= 0: wandb.log( @@ -637,9 +620,8 @@ def get_num_phases(multi_phased_distill_schedule, step): main_print(f"--> saving checkpoint at step {step}") if args.use_lora: # Save LoRA weights - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step - ) + save_lora_checkpoint(transformer, optimizer, rank, + args.output_dir, step) else: # Your existing checkpoint saving code # TODO @@ -652,9 +634,8 @@ def get_num_phases(multi_phased_distill_schedule, step): # args.output_dir, # step, # ) - save_checkpoint( - transformer, rank, args.output_dir, args.max_train_steps - ) + save_checkpoint(transformer, rank, args.output_dir, + args.max_train_steps) main_print(f"--> checkpoint saved at step {step}") dist.barrier() if args.log_validation and step % args.validation_steps == 0: @@ -673,11 +654,11 @@ def get_num_phases(multi_phased_distill_schedule, step): ) if args.use_lora: - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps - ) + save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, + args.max_train_steps) else: - 
save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps) + save_checkpoint(transformer, rank, args.output_dir, + args.max_train_steps) if get_sequence_parallel_state(): destroy_sequence_parallel_group() @@ -686,9 +667,10 @@ def get_num_phases(multi_phased_distill_schedule, step): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--model_type", type=str, default="mochi", help="The type of model to train." - ) + parser.add_argument("--model_type", + type=str, + default="mochi", + help="The type of model to train.") # dataset & dataloader parser.add_argument("--data_json_path", type=str, required=True) parser.add_argument("--num_frames", type=int, default=163) @@ -696,7 +678,8 @@ def get_num_phases(multi_phased_distill_schedule, step): "--dataloader_num_workers", type=int, default=10, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + help= + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", ) parser.add_argument( "--train_batch_size", @@ -704,9 +687,10 @@ def get_num_phases(multi_phased_distill_schedule, step): default=16, help="Batch size (per device) for the training dataloader.", ) - parser.add_argument( - "--num_latent_t", type=int, default=28, help="Number of latent timesteps." - ) + parser.add_argument("--num_latent_t", + type=int, + default=28, + help="Number of latent timesteps.") parser.add_argument("--group_frame", action="store_true") # TODO parser.add_argument("--group_resolution", action="store_true") # TODO @@ -725,14 +709,16 @@ def get_num_phases(multi_phased_distill_schedule, step): parser.add_argument("--validation_steps", type=float, default=64) parser.add_argument("--log_validation", action="store_true") parser.add_argument("--tracker_project_name", type=str, default=None) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", + type=int, + default=None, + help="A seed for reproducible training.") parser.add_argument( "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--checkpoints_total_limit", @@ -744,11 +730,10 @@ def get_num_phases(multi_phased_distill_schedule, step): "--checkpointing_steps", type=int, default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), + help= + ("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`."), ) parser.add_argument("--validation_prompt_dir", type=str) parser.add_argument("--shift", type=float, default=1.0) @@ -756,28 +741,27 @@ def get_num_phases(multi_phased_distill_schedule, step): "--resume_from_checkpoint", type=str, default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
- ), + help= + ("Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), ) parser.add_argument( "--resume_from_lora_checkpoint", type=str, default=None, - help=( - "Whether training should be resumed from a previous lora checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), + help= + ("Whether training should be resumed from a previous lora checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), + help= + ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), ) # optimizer & scheduler & Training @@ -786,25 +770,29 @@ def get_num_phases(multi_phased_distill_schedule, step): "--max_train_steps", type=int, default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + help= + "Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--learning_rate", type=float, default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", + help= + "Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--discriminator_learning_rate", type=float, default=1e-5, - help="Initial learning rate (after the potential warmup period) to use.", + help= + "Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + help= + "Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_warmup_steps", @@ -812,41 +800,47 @@ def get_num_phases(multi_phased_distill_schedule, step): default=10, help="Number of steps for the warmup in the lr scheduler.", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) + parser.add_argument("--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm.") parser.add_argument( "--gradient_checkpointing", action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + help= + "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument("--selective_checkpointing", type=float, default=1.0) parser.add_argument( "--allow_tf32", action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), + help= + ("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), + help= + ("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument( "--use_cpu_offload", action="store_true", - help="Whether to use CPU offload for param & gradient & optimizer states.", + help= + "Whether to use CPU offload for param & gradient & optimizer states.", ) - parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") + parser.add_argument("--sp_size", + type=int, + default=1, + help="For sequence parallel") parser.add_argument( "--train_sp_batch_size", type=int, @@ -860,19 +854,24 @@ def get_num_phases(multi_phased_distill_schedule, step): default=False, help="Whether to use LoRA for finetuning.", ) - parser.add_argument( - "--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA." - ) - parser.add_argument( - "--lora_rank", type=int, default=128, help="LoRA rank parameter. " - ) + parser.add_argument("--lora_alpha", + type=int, + default=256, + help="Alpha parameter for LoRA.") + parser.add_argument("--lora_rank", + type=int, + default=128, + help="LoRA rank parameter. ") parser.add_argument("--fsdp_sharding_startegy", default="full") - parser.add_argument("--multi_phased_distill_schedule", type=str, default=None) + parser.add_argument("--multi_phased_distill_schedule", + type=str, + default=None) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", + help= + "Number of updates steps to accumulate before performing a backward/update pass.", ) # lr_scheduler @@ -880,10 +879,9 @@ def get_num_phases(multi_phased_distill_schedule, step): "--lr_scheduler", type=str, default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), + help= + ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), ) parser.add_argument("--num_euler_timesteps", type=int, default=100) parser.add_argument( @@ -903,13 +901,15 @@ def get_num_phases(multi_phased_distill_schedule, step): action="store_true", help="Whether to apply the cfg_solver.", ) - parser.add_argument( - "--distill_cfg", type=float, default=3.0, help="Distillation coefficient." - ) + parser.add_argument("--distill_cfg", + type=float, + default=3.0, + help="Distillation coefficient.") # ["euler_linear_quadratic", "pcm", "pcm_linear_qudratic"] - parser.add_argument( - "--scheduler_type", type=str, default="pcm", help="The scheduler type to use." 
- ) + parser.add_argument("--scheduler_type", + type=str, + default="pcm", + help="The scheduler type to use.") parser.add_argument( "--adv_weight", type=float, diff --git a/fastvideo/models/flash_attn_no_pad.py b/fastvideo/models/flash_attn_no_pad.py index ff917e2..3bffc69 100644 --- a/fastvideo/models/flash_attn_no_pad.py +++ b/fastvideo/models/flash_attn_no_pad.py @@ -1,21 +1,25 @@ +from einops import rearrange from flash_attn import flash_attn_varlen_qkvpacked_func from flash_attn.bert_padding import pad_input, unpad_input -from einops import rearrange -def flash_attn_no_pad( - qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None -): +def flash_attn_no_pad(qkv, + key_padding_mask, + causal=False, + dropout_p=0.0, + softmax_scale=None): # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27 batch_size = qkv.shape[0] seqlen = qkv.shape[1] nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input( - x, key_padding_mask - ) + x, key_padding_mask) - x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) + x_unpad = rearrange(x_unpad, + "nnz (three h d) -> nnz three h d", + three=3, + h=nheads) output_unpad = flash_attn_varlen_qkvpacked_func( x_unpad, cu_seqlens, @@ -25,9 +29,8 @@ def flash_attn_no_pad( causal=causal, ) output = rearrange( - pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen - ), + pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, + batch_size, seqlen), "b s (h d) -> b s h d", h=nheads, ) diff --git a/fastvideo/models/hunyuan/constants.py b/fastvideo/models/hunyuan/constants.py index a3f8571..3dd67b2 100644 --- a/fastvideo/models/hunyuan/constants.py +++ b/fastvideo/models/hunyuan/constants.py @@ -1,4 +1,5 @@ import os + import torch __all__ = [ @@ -33,8 +34,7 @@ PROMPT_TEMPLATE_ENCODE = ( "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" -) + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") PROMPT_TEMPLATE_ENCODE_VIDEO = ( "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " "1. The main content and theme of the video." @@ -42,13 +42,15 @@ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." "4. background environment, light, style and atmosphere." "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" -) + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" PROMPT_TEMPLATE = { - "dit-llm-encode": {"template": PROMPT_TEMPLATE_ENCODE, "crop_start": 36,}, + "dit-llm-encode": { + "template": PROMPT_TEMPLATE_ENCODE, + "crop_start": 36, + }, "dit-llm-encode-video": { "template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95, diff --git a/fastvideo/models/hunyuan/diffusion/__init__.py b/fastvideo/models/hunyuan/diffusion/__init__.py index 2141aa3..754bc48 100644 --- a/fastvideo/models/hunyuan/diffusion/__init__.py +++ b/fastvideo/models/hunyuan/diffusion/__init__.py @@ -1,2 +1,3 @@ +# ruff: noqa: F401 from .pipelines import HunyuanVideoPipeline from .schedulers import FlowMatchDiscreteScheduler diff --git a/fastvideo/models/hunyuan/diffusion/pipelines/__init__.py b/fastvideo/models/hunyuan/diffusion/pipelines/__init__.py index e44cb61..fa109e1 100644 --- a/fastvideo/models/hunyuan/diffusion/pipelines/__init__.py +++ b/fastvideo/models/hunyuan/diffusion/pipelines/__init__.py @@ -1 +1,2 @@ +# ruff: noqa: F401 from .pipeline_hunyuan_video import HunyuanVideoPipeline diff --git a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py index a2abf2e..86adda3 100644 --- a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py @@ -17,41 +17,34 @@ # # ============================================================================== import inspect -from typing import Any, Callable, Dict, List, Optional, Union, Tuple -import torch -import torch.distributed as dist -import numpy as np from dataclasses import dataclass -from packaging import version +from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.configuration_utils import FrozenDict from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) +from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate, logging, + replace_example_docstring, scale_lora_layers) from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.utils import BaseOutput +from einops import rearrange + +from fastvideo.utils.communications import all_gather +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) from ...constants import PRECISION_TO_TYPE -from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D -from ...text_encoder import TextEncoder from ...modules import HYVideoDiffusionTransformer - -from einops import rearrange -from 
fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info -from fastvideo.utils.communications import all_gather, all_to_all_4D -import torch.nn.functional as F +from ...text_encoder import TextEncoder +from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -63,16 +56,14 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std( - dim=list(range(1, noise_pred_text.ndim)), keepdim=True - ) + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), + keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = ( - guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - ) + noise_cfg = (guidance_rescale * noise_pred_rescaled + + (1 - guidance_rescale) * noise_cfg) return noise_cfg @@ -113,8 +104,7 @@ def retrieve_timesteps( ) if timesteps is not None: accepts_timesteps = "timesteps" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys() - ) + inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -125,8 +115,7 @@ def retrieve_timesteps( num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys() - ) + inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -169,7 +158,9 @@ class HunyuanVideoPipeline(DiffusionPipeline): model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = ["text_encoder_2"] _exclude_from_cpu_offload = ["transformer"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", "prompt_embeds", "negative_prompt_embeds" + ] def __init__( self, @@ -193,29 +184,25 @@ def __init__( self.args = args # ========================================================================================== - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): + if (hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) + " file") + deprecate("steps_offset!=1", + "1.0.0", + deprecation_message, + standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is True - ): + if (hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -223,9 +210,10 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) + deprecate("clip_sample not set", + "1.0.0", + deprecation_message, + standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -237,8 +225,10 @@ def __init__( scheduler=scheduler, text_encoder_2=text_encoder_2, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - + 1) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor) def encode_prompt( self, @@ -303,23 +293,17 @@ def encode_prompt( else: scale_lora_layers(text_encoder.model, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer) + prompt = self.maybe_convert_prompt(prompt, + text_encoder.tokenizer) text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) if clip_skip is None: - prompt_outputs = text_encoder.encode( - text_inputs, data_type=data_type, device=device - ) + prompt_outputs = text_encoder.encode(text_inputs, + data_type=data_type, + device=device) prompt_embeds = prompt_outputs.hidden_state else: prompt_outputs = text_encoder.encode( @@ -331,23 +315,23 @@ def encode_prompt( # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] + prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. 
prompt_embeds = text_encoder.model.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds) attention_mask = prompt_outputs.attention_mask if attention_mask is not None: attention_mask = attention_mask.to(device) bs_embed, seq_len = attention_mask.shape - attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.repeat(1, + num_videos_per_prompt) attention_mask = attention_mask.view( - bs_embed * num_videos_per_prompt, seq_len - ) + bs_embed * num_videos_per_prompt, seq_len) if text_encoder is not None: prompt_embeds_dtype = text_encoder.dtype @@ -356,20 +340,21 @@ def encode_prompt( else: prompt_embeds_dtype = prompt_embeds.dtype - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, + device=device) if prompt_embeds.ndim == 2: bs_embed, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_videos_per_prompt, -1) else: bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view( - bs_embed * num_videos_per_prompt, seq_len, -1 - ) + bs_embed * num_videos_per_prompt, seq_len, -1) return ( prompt_embeds, @@ -380,7 +365,10 @@ def encode_prompt( def decode_latents(self, latents, enable_tiling=True): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + deprecate("decode_latents", + "1.0.0", + deprecation_message, + standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents if enable_tiling: @@ -439,17 +427,14 @@ def check_inputs( f"`video_length` has to be 1 or a multiple of 8 but is {video_length}." ) - if callback_steps is not None and ( - not isinstance(callback_steps, int) or callback_steps <= 0 - ): + if callback_steps is not None and (not isinstance(callback_steps, int) + or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + f" {type(callback_steps)}.") if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs - ): + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) @@ -457,15 +442,13 @@ def check_inputs( if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) + " only forward one of the two.") elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): + elif prompt is not None and (not isinstance(prompt, str) + and not isinstance(prompt, list)): raise ValueError( f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" ) @@ -481,8 +464,7 @@ def check_inputs( raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) + f" {negative_prompt_embeds.shape}.") def prepare_latents( self, @@ -510,9 +492,10 @@ def prepare_latents( ) if latents is None: - latents = randn_tensor( - shape, generator=generator, device=device, dtype=dtype - ) + latents = randn_tensor(shape, + generator=generator, + device=device, + dtype=dtype) else: latents = latents.to(device) @@ -604,7 +587,8 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -615,13 +599,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, - callback_on_step_end: Optional[ - Union[ - Callable[[int, int, Dict], None], - PipelineCallback, - MultiPipelineCallbacks, - ] - ] = None, + callback_on_step_end: Optional[Union[Callable[[int, int, Dict], + None], PipelineCallback, + MultiPipelineCallbacks, ]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], vae_ver: str = "88-4c-sd", enable_tiling: bool = False, @@ -727,7 +707,8 @@ def __call__( "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + if isinstance(callback_on_step_end, + (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet @@ -763,18 +744,12 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - device = ( - torch.device(f"cuda:{dist.get_rank()}") - if dist.is_initialized() - else self._execution_device - ) + device = (torch.device(f"cuda:{dist.get_rank()}") + if dist.is_initialized() else self._execution_device) # 3. Encode input prompt - lora_scale = ( - self.cross_attention_kwargs.get("scale", None) - if self.cross_attention_kwargs is not None - else None - ) + lora_scale = (self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None else None) ( prompt_embeds, @@ -830,14 +805,15 @@ def __call__( if prompt_mask is not None: prompt_mask = torch.cat([negative_prompt_mask, prompt_mask]) if prompt_embeds_2 is not None: - prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_embeds_2 = torch.cat( + [negative_prompt_embeds_2, prompt_embeds_2]) if prompt_mask_2 is not None: - prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2]) + prompt_mask_2 = torch.cat( + [negative_prompt_mask_2, prompt_mask_2]) # 4. 
Prepare timesteps extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs( - self.scheduler.set_timesteps, {"n_tokens": n_tokens} - ) + self.scheduler.set_timesteps, {"n_tokens": n_tokens}) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, @@ -869,27 +845,30 @@ def __call__( world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group if get_sequence_parallel_state(): - latents = rearrange( - latents, "b t (n s) h w -> b t n s h w", n=world_size - ).contiguous() + latents = rearrange(latents, + "b t (n s) h w -> b t n s h w", + n=world_size).contiguous() latents = latents[:, :, rank, :, :, :] # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_func_kwargs( - self.scheduler.step, {"generator": generator, "eta": eta}, + self.scheduler.step, + { + "generator": generator, + "eta": eta + }, ) target_dtype = PRECISION_TO_TYPE[self.args.precision] - autocast_enabled = ( - target_dtype != torch.float32 - ) and not self.args.disable_autocast + autocast_enabled = (target_dtype != + torch.float32) and not self.args.disable_autocast vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] vae_autocast_enabled = ( - vae_dtype != torch.float32 - ) and not self.args.disable_autocast + vae_dtype != torch.float32) and not self.args.disable_autocast # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = len( + timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) # if is_progress_bar: @@ -899,40 +878,33 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) + latent_model_input = (torch.cat( + [latents] * + 2) if self.do_classifier_free_guidance else latents) latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input, t) t_expand = t.repeat(latent_model_input.shape[0]) - guidance_expand = ( - torch.tensor( - [embedded_guidance_scale] * latent_model_input.shape[0], - dtype=torch.float32, - device=device, - ).to(target_dtype) - * 1000.0 - if embedded_guidance_scale is not None - else None - ) + guidance_expand = (torch.tensor( + [embedded_guidance_scale] * latent_model_input.shape[0], + dtype=torch.float32, + device=device, + ).to(target_dtype) * 1000.0 if embedded_guidance_scale + is not None else None) # predict the noise residual - with torch.autocast( - device_type="cuda", dtype=target_dtype, enabled=autocast_enabled - ): - # concat prompt_embeds_2 and prompt_embeds. Mismach fill with zeros + with torch.autocast(device_type="cuda", + dtype=target_dtype, + enabled=autocast_enabled): + # concat prompt_embeds_2 and prompt_embeds. 
Mismatch fill with zeros if prompt_embeds_2.shape[-1] != prompt_embeds.shape[-1]: prompt_embeds_2 = F.pad( prompt_embeds_2, - (0, prompt_embeds.shape[2] - prompt_embeds_2.shape[1]), + (0, prompt_embeds.shape[2] - + prompt_embeds_2.shape[1]), value=0, ).unsqueeze(1) encoder_hidden_states = torch.cat( - [prompt_embeds_2, prompt_embeds], dim=1 - ) + [prompt_embeds_2, prompt_embeds], dim=1) noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256) latent_model_input, # [2, 16, 33, 24, 42] encoder_hidden_states, @@ -940,16 +912,13 @@ def __call__( prompt_mask, # [2, 256]fpdb guidance=guidance_expand, return_dict=False, - )[ - 0 - ] + )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf @@ -960,26 +929,29 @@ def __call__( ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, + t, + latents, + **extra_step_kwargs, + return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) + "negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + (i + 1) > num_warmup_steps and + (i + 1) % self.scheduler.order == 0): if progress_bar is not None: progress_bar.update() if callback is not None and i % callback_steps == 0: @@ -1002,29 +974,25 @@ def __call__( f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}." 
) - if ( - hasattr(self.vae.config, "shift_factor") - and self.vae.config.shift_factor - ): - latents = ( - latents / self.vae.config.scaling_factor - + self.vae.config.shift_factor - ) + if (hasattr(self.vae.config, "shift_factor") + and self.vae.config.shift_factor): + latents = (latents / self.vae.config.scaling_factor + + self.vae.config.shift_factor) else: latents = latents / self.vae.config.scaling_factor - with torch.autocast( - device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled - ): + with torch.autocast(device_type="cuda", + dtype=vae_dtype, + enabled=vae_autocast_enabled): if enable_tiling: self.vae.enable_tiling() - image = self.vae.decode( - latents, return_dict=False, generator=generator - )[0] + image = self.vae.decode(latents, + return_dict=False, + generator=generator)[0] else: - image = self.vae.decode( - latents, return_dict=False, generator=generator - )[0] + image = self.vae.decode(latents, + return_dict=False, + generator=generator)[0] if expand_temporal_dim or image.shape[2] == 1: image = image.squeeze(2) diff --git a/fastvideo/models/hunyuan/diffusion/schedulers/__init__.py b/fastvideo/models/hunyuan/diffusion/schedulers/__init__.py index 14f2ba3..2238803 100644 --- a/fastvideo/models/hunyuan/diffusion/schedulers/__init__.py +++ b/fastvideo/models/hunyuan/diffusion/schedulers/__init__.py @@ -1 +1,2 @@ +# ruff: noqa: F401 from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler diff --git a/fastvideo/models/hunyuan/diffusion/schedulers/scheduling_flow_match_discrete.py b/fastvideo/models/hunyuan/diffusion/schedulers/scheduling_flow_match_discrete.py index 07f9c01..69ed531 100644 --- a/fastvideo/models/hunyuan/diffusion/schedulers/scheduling_flow_match_discrete.py +++ b/fastvideo/models/hunyuan/diffusion/schedulers/scheduling_flow_match_discrete.py @@ -20,13 +20,10 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union -import numpy as np import torch - from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput, logging from diffusers.schedulers.scheduling_utils import SchedulerMixin - +from diffusers.utils import BaseOutput, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -83,7 +80,8 @@ def __init__( self.sigmas = sigmas # the value fed to model - self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + self.timesteps = (sigmas[:-1] * + num_train_timesteps).to(dtype=torch.float32) self._step_index = None self._begin_index = None @@ -149,8 +147,7 @@ def set_timesteps( self.sigmas = sigmas self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( - dtype=torch.float32, device=device - ) + dtype=torch.float32, device=device) # Reset step index self._step_index = None @@ -177,9 +174,9 @@ def _init_step_index(self, timestep): else: self._step_index = self._begin_index - def scale_model_input( - self, sample: torch.Tensor, timestep: Optional[int] = None - ) -> torch.Tensor: + def scale_model_input(self, + sample: torch.Tensor, + timestep: Optional[int] = None) -> torch.Tensor: return sample def sd3_time_shift(self, t: torch.Tensor): @@ -217,18 +214,12 @@ def step( returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. 
from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) + if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor)): + raise ValueError(( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep."), ) if self.step_index is None: self._init_step_index(timestep) @@ -249,7 +240,7 @@ def step( self._step_index += 1 if not return_dict: - return (prev_sample,) + return (prev_sample, ) return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) diff --git a/fastvideo/models/hunyuan/idle_config.py b/fastvideo/models/hunyuan/idle_config.py index 3cde32e..f9588d1 100644 --- a/fastvideo/models/hunyuan/idle_config.py +++ b/fastvideo/models/hunyuan/idle_config.py @@ -1,11 +1,14 @@ +# ruff: noqa: F405, F403 import argparse -from .constants import * import re + +from .constants import * from .modules.models import HUNYUAN_VIDEO_CONFIG def parse_args(namespace=None): - parser = argparse.ArgumentParser(description="HunyuanVideo inference script") + parser = argparse.ArgumentParser( + description="HunyuanVideo inference script") parser = add_network_args(parser) parser = add_extra_models_args(parser) @@ -33,7 +36,8 @@ def add_network_args(parser: argparse.ArgumentParser): "--latent-channels", type=str, default=16, - help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, " + help= + "Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, " "it still needs to match the latent channels of the VAE model.", ) group.add_argument( @@ -41,13 +45,15 @@ def add_network_args(parser: argparse.ArgumentParser): type=str, default="bf16", choices=PRECISIONS, - help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.", + help= + "Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.", ) # RoPE - group.add_argument( - "--rope-theta", type=int, default=256, help="Theta used in RoPE." - ) + group.add_argument("--rope-theta", + type=int, + default=256, + help="Theta used in RoPE.") return parser @@ -98,9 +104,10 @@ def add_extra_models_args(parser: argparse.ArgumentParser): default=4096, help="Dimension of the text encoder hidden states.", ) - group.add_argument( - "--text-len", type=int, default=256, help="Maximum length of the text input." 
- ) + group.add_argument("--text-len", + type=int, + default=256, + help="Maximum length of the text input.") group.add_argument( "--tokenizer", type=str, @@ -131,7 +138,8 @@ def add_extra_models_args(parser: argparse.ArgumentParser): group.add_argument( "--apply-final-norm", action="store_true", - help="Apply final normalization to the used text encoder hidden states.", + help= + "Apply final normalization to the used text encoder hidden states.", ) # - CLIP @@ -195,7 +203,10 @@ def add_denoise_schedule_args(parser: argparse.ArgumentParser): help="If reverse, learning/sampling from t=1 -> t=0.", ) group.add_argument( - "--flow-solver", type=str, default="euler", help="Solver for flow matching.", + "--flow-solver", + type=str, + default="euler", + help="Solver for flow matching.", ) group.add_argument( "--use-linear-quadratic-schedule", @@ -221,13 +232,16 @@ def add_inference_args(parser: argparse.ArgumentParser): "--model-base", type=str, default="ckpts", - help="Root path of all the models, including t2v models and extra models.", + help= + "Root path of all the models, including t2v models and extra models.", ) group.add_argument( "--dit-weight", type=str, - default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", - help="Path to the HunyuanVideo model. If None, search the model in the args.model_root." + default= + "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", + help= + "Path to the HunyuanVideo model. If None, search the model in the args.model_root." "1. If it is a file, load the model directly." "2. If it is a directory, search the model in the directory. Support two types of models: " "1) named `pytorch_model_*.pt`" @@ -238,13 +252,15 @@ def add_inference_args(parser: argparse.ArgumentParser): type=str, default="540p", choices=["540p", "720p"], - help="Root path of all the models, including t2v models and extra models.", + help= + "Root path of all the models, including t2v models and extra models.", ) group.add_argument( "--load-key", type=str, default="module", - help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", + help= + "Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", ) group.add_argument( "--use-cpu-offload", @@ -268,7 +284,8 @@ def add_inference_args(parser: argparse.ArgumentParser): group.add_argument( "--disable-autocast", action="store_true", - help="Disable autocast for denoising loop and vae decoding in pipeline sampling.", + help= + "Disable autocast for denoising loop and vae decoding in pipeline sampling.", ) group.add_argument( "--save-path", @@ -300,7 +317,8 @@ def add_inference_args(parser: argparse.ArgumentParser): type=int, nargs="+", default=(720, 1280), - help="Video size for training. If a single value is provided, it will be used for both height " + help= + "Video size for training. If a single value is provided, it will be used for both height " "and width. If two values are provided, they will be used for height and width " "respectively.", ) @@ -308,7 +326,8 @@ def add_inference_args(parser: argparse.ArgumentParser): "--video-length", type=int, default=129, - help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1", + help= + "How many frames to sample from a video. 
if using 3d vae, the number should be 4n+1", ) # --- prompt --- group.add_argument( @@ -322,31 +341,38 @@ def add_inference_args(parser: argparse.ArgumentParser): type=str, default="auto", choices=["file", "random", "fixed", "auto"], - help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a " + help= + "Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a " "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the " "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the " "fixed `seed` value.", ) - group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + group.add_argument("--seed", + type=int, + default=None, + help="Seed for evaluation.") # Classifier-Free Guidance - group.add_argument( - "--neg-prompt", type=str, default=None, help="Negative prompt for sampling." - ) - group.add_argument( - "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale." - ) + group.add_argument("--neg-prompt", + type=str, + default=None, + help="Negative prompt for sampling.") + group.add_argument("--cfg-scale", + type=float, + default=1.0, + help="Classifier free guidance scale.") group.add_argument( "--embedded-cfg-scale", type=float, default=6.0, - help="Embeded classifier free guidance scale.", + help="Embedded classifier free guidance scale.", ) group.add_argument( "--reproduce", action="store_true", - help="Enable reproducibility by setting random seeds and deterministic algorithms.", + help= + "Enable reproducibility by setting random seeds and deterministic algorithms.", ) return parser @@ -357,10 +383,16 @@ def add_parallel_args(parser: argparse.ArgumentParser): # ======================== Model loads ======================== group.add_argument( - "--ulysses-degree", type=int, default=1, help="Ulysses degree.", + "--ulysses-degree", + type=int, + default=1, + help="Ulysses degree.", ) group.add_argument( - "--ring-degree", type=int, default=1, help="Ulysses degree.", + "--ring-degree", + type=int, + default=1, + help="Ulysses degree.", ) return parser diff --git a/fastvideo/models/hunyuan/inference.py b/fastvideo/models/hunyuan/inference.py index 825e485..17723fa 100644 --- a/fastvideo/models/hunyuan/inference.py +++ b/fastvideo/models/hunyuan/inference.py @@ -1,33 +1,27 @@ import os -import time import random -import functools -from typing import List, Optional, Tuple, Union - +import time from pathlib import Path -from loguru import logger import torch -import torch.distributed as dist -from fastvideo.models.hunyuan.constants import ( - PROMPT_TEMPLATE, - NEGATIVE_PROMPT, - PRECISION_TO_TYPE, -) -from fastvideo.models.hunyuan.vae import load_vae +from loguru import logger +from safetensors.torch import load_file as safetensors_load_file + +from fastvideo.models.hunyuan.constants import (NEGATIVE_PROMPT, + PRECISION_TO_TYPE, + PROMPT_TEMPLATE) +from fastvideo.models.hunyuan.diffusion.pipelines import HunyuanVideoPipeline +from fastvideo.models.hunyuan.diffusion.schedulers import \ + FlowMatchDiscreteScheduler from fastvideo.models.hunyuan.modules import load_model from fastvideo.models.hunyuan.text_encoder import TextEncoder from fastvideo.models.hunyuan.utils.data_utils import align_to -from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler -from fastvideo.models.hunyuan.diffusion.pipelines import HunyuanVideoPipeline -from safetensors.torch import load_file as safetensors_load_file 
-from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - nccl_info, -) +from fastvideo.models.hunyuan.vae import load_vae +from fastvideo.utils.parallel_states import nccl_info class Inference(object): + def __init__( self, args, @@ -53,18 +47,17 @@ def __init__( self.use_cpu_offload = use_cpu_offload self.args = args - self.device = ( - device - if device is not None - else "cuda" - if torch.cuda.is_available() - else "cpu" - ) + self.device = (device if device is not None else + "cuda" if torch.cuda.is_available() else "cpu") self.logger = logger self.parallel_args = parallel_args @classmethod - def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs): + def from_pretrained(cls, + pretrained_model_path, + args, + device=None, + **kwargs): """ Initialize the Inference pipeline. @@ -74,7 +67,8 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs): device (int): The device for inference. Default is 0. """ # ======================================================================== - logger.info(f"Got text-to-video model root path: {pretrained_model_path}") + logger.info( + f"Got text-to-video model root path: {pretrained_model_path}") # ==================== Initialize Distributed Environment ================ if nccl_info.sp_size > 1: @@ -91,7 +85,10 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs): # =========================== Build main model =========================== logger.info("Building model...") - factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]} + factor_kwargs = { + "device": device, + "dtype": PRECISION_TO_TYPE[args.precision] + } in_channels = args.latent_channels out_channels = args.latent_channels @@ -118,27 +115,22 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs): # Text encoder if args.prompt_template_video is not None: crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get( - "crop_start", 0 - ) + "crop_start", 0) elif args.prompt_template is not None: - crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0) + crop_start = PROMPT_TEMPLATE[args.prompt_template].get( + "crop_start", 0) else: crop_start = 0 max_length = args.text_len + crop_start # prompt_template - prompt_template = ( - PROMPT_TEMPLATE[args.prompt_template] - if args.prompt_template is not None - else None - ) + prompt_template = (PROMPT_TEMPLATE[args.prompt_template] + if args.prompt_template is not None else None) # prompt_template_video - prompt_template_video = ( - PROMPT_TEMPLATE[args.prompt_template_video] - if args.prompt_template_video is not None - else None - ) + prompt_template_video = (PROMPT_TEMPLATE[args.prompt_template_video] + if args.prompt_template_video is not None else + None) text_encoder = TextEncoder( text_encoder_type=args.text_encoder, @@ -192,7 +184,9 @@ def load_state_dict(args, model, pretrained_model_path): model_path = dit_weight / f"pytorch_model_{load_key}.pt" bare_model = True elif any(str(f).endswith("_model_states.pt") for f in files): - files = [f for f in files if str(f).endswith("_model_states.pt")] + files = [ + f for f in files if str(f).endswith("_model_states.pt") + ] model_path = files[0] if len(files) > 1: logger.warning( @@ -216,7 +210,9 @@ def load_state_dict(args, model, pretrained_model_path): model_path = dit_weight / f"pytorch_model_{load_key}.pt" bare_model = True elif any(str(f).endswith("_model_states.pt") for f in files): - files = [f for f in files if 
str(f).endswith("_model_states.pt")] + files = [ + f for f in files if str(f).endswith("_model_states.pt") + ] model_path = files[0] if len(files) > 1: logger.warning( @@ -245,13 +241,13 @@ def load_state_dict(args, model, pretrained_model_path): state_dict = safetensors_load_file(model_path) elif model_path.suffix == ".pt": # Use torch for .pt files - state_dict = torch.load( - model_path, map_location=lambda storage, loc: storage - ) + state_dict = torch.load(model_path, + map_location=lambda storage, loc: storage) else: raise ValueError(f"Unsupported file format: {model_path}") - if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict): + if bare_model == "unknown" and ("ema" in state_dict + or "module" in state_dict): bare_model = False if bare_model is False: if load_key in state_dict: @@ -259,8 +255,7 @@ def load_state_dict(args, model, pretrained_model_path): else: raise KeyError( f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint " - f"are: {list(state_dict.keys())}." - ) + f"are: {list(state_dict.keys())}.") model.load_state_dict(state_dict, strict=True) return model @@ -269,15 +264,18 @@ def parse_size(size): if isinstance(size, int): size = [size] if not isinstance(size, (list, tuple)): - raise ValueError(f"Size must be an integer or (height, width), got {size}.") + raise ValueError( + f"Size must be an integer or (height, width), got {size}.") if len(size) == 1: size = [size[0], size[0]] if len(size) != 2: - raise ValueError(f"Size must be an integer or (height, width), got {size}.") + raise ValueError( + f"Size must be an integer or (height, width), got {size}.") return size class HunyuanVideoSampler(Inference): + def __init__( self, args, @@ -403,15 +401,13 @@ def predict( ] elif isinstance(seed, int): seeds = [ - seed + i - for _ in range(batch_size) + seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt) ] elif isinstance(seed, (list, tuple)): if len(seed) == batch_size: seeds = [ - int(seed[i]) + j - for i in range(batch_size) + int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt) ] elif len(seed) == batch_size * num_videos_per_prompt: @@ -426,7 +422,9 @@ def predict( f"Seed must be an integer, a list of integers, or None, got {seed}." ) # Peiyuan: using GPU seed will cause A100 and H100 to generate different results... 
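# Aside -- a minimal, self-contained sketch (assumed example, not code from this
# repository): the sampler above derives one integer seed per requested video
# and, just below, wraps each seed in a CPU torch.Generator. CPU generators
# produce bit-identical noise regardless of which GPU later consumes it, which
# is what keeps sampling reproducible across A100 and H100.
import torch

def make_cpu_generators(seeds):
    # One deterministic generator per requested video.
    return [torch.Generator("cpu").manual_seed(int(s)) for s in seeds]

generator = make_cpu_generators([42])[0]
latents = torch.randn((1, 16, 33, 24, 42), generator=generator)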
- generator = [torch.Generator("cpu").manual_seed(seed) for seed in seeds] + generator = [ + torch.Generator("cpu").manual_seed(seed) for seed in seeds + ] out_dict["seeds"] = seeds # ======================================================================== @@ -455,7 +453,8 @@ def predict( # Arguments: prompt, new_prompt, negative_prompt # ======================================================================== if not isinstance(prompt, str): - raise TypeError(f"`prompt` must be a string, but got {type(prompt)}") + raise TypeError( + f"`prompt` must be a string, but got {type(prompt)}") prompt = [prompt.strip()] # negative prompt @@ -478,9 +477,11 @@ def predict( self.pipeline.scheduler = scheduler if "884" in self.args.vae: - latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] + latents_size = [(video_length - 1) // 4 + 1, height // 8, + width // 8] elif "888" in self.args.vae: - latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8] + latents_size = [(video_length - 1) // 8 + 1, height // 8, + width // 8] n_tokens = latents_size[0] * latents_size[1] * latents_size[2] # ======================================================================== diff --git a/fastvideo/models/hunyuan/modules/__init__.py b/fastvideo/models/hunyuan/modules/__init__.py index f54c6b9..c85b51a 100644 --- a/fastvideo/models/hunyuan/modules/__init__.py +++ b/fastvideo/models/hunyuan/modules/__init__.py @@ -1,4 +1,4 @@ -from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG +from .models import HUNYUAN_VIDEO_CONFIG, HYVideoDiffusionTransformer def load_model(args, in_channels, out_channels, factor_kwargs): diff --git a/fastvideo/models/hunyuan/modules/attenion.py b/fastvideo/models/hunyuan/modules/attenion.py index aa0cf88..a010388 100644 --- a/fastvideo/models/hunyuan/modules/attenion.py +++ b/fastvideo/models/hunyuan/modules/attenion.py @@ -1,18 +1,19 @@ -import importlib.metadata -import math - import torch -import torch.nn as nn import torch.nn.functional as F - -from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info -from fastvideo.utils.communications import all_gather, all_to_all_4D from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad +from fastvideo.utils.communications import all_gather, all_to_all_4D +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) def attention( - q, k, v, drop_rate=0, attn_mask=None, causal=False, + q, + k, + v, + drop_rate=0, + attn_mask=None, + causal=False, ): qkv = torch.stack([q, k, v], dim=2) @@ -20,9 +21,11 @@ def attention( if attn_mask is not None and attn_mask.dtype != torch.bool: attn_mask = attn_mask.bool() - x = flash_attn_no_pad( - qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None - ) + x = flash_attn_no_pad(qkv, + attn_mask, + causal=causal, + dropout_p=drop_rate, + softmax_scale=None) b, s, a, d = x.shape out = x.reshape(b, s, -1) @@ -44,8 +47,7 @@ def parallel_attention(q, k, v, img_q_len, img_kv_len, text_mask): def shrink_head(encoder_state, dim): local_heads = encoder_state.shape[dim] // nccl_info.sp_size return encoder_state.narrow( - dim, nccl_info.rank_within_group * local_heads, local_heads - ) + dim, nccl_info.rank_within_group * local_heads, local_heads) encoder_query = shrink_head(encoder_query, dim=2) encoder_key = shrink_head(encoder_key, dim=2) @@ -63,16 +65,20 @@ def shrink_head(encoder_state, dim): qkv = torch.stack([query, key, value], dim=2) attn_mask = F.pad(text_mask, (sequence_length, 0), value=True) - 
hidden_states = flash_attn_no_pad( - qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None - ) + hidden_states = flash_attn_no_pad(qkv, + attn_mask, + causal=False, + dropout_p=0.0, + softmax_scale=None) hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 - ) + (sequence_length, encoder_sequence_length), dim=1) if get_sequence_parallel_state(): - hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) - encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous() + hidden_states = all_to_all_4D(hidden_states, + scatter_dim=1, + gather_dim=2) + encoder_hidden_states = all_gather(encoder_hidden_states, + dim=2).contiguous() hidden_states = hidden_states.to(query.dtype) encoder_hidden_states = encoder_hidden_states.to(query.dtype) diff --git a/fastvideo/models/hunyuan/modules/embed_layers.py b/fastvideo/models/hunyuan/modules/embed_layers.py index 3134f37..d2cb9bb 100644 --- a/fastvideo/models/hunyuan/modules/embed_layers.py +++ b/fastvideo/models/hunyuan/modules/embed_layers.py @@ -1,7 +1,7 @@ import math + import torch import torch.nn as nn -from einops import rearrange, repeat from ..utils.helpers import to_2tuple @@ -45,7 +45,8 @@ def __init__( bias=bias, **factory_kwargs, ) - nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) + nn.init.xavier_uniform_( + self.proj.weight.view(self.proj.weight.size(0), -1)) if bias: nn.init.zeros_(self.proj.bias) @@ -66,7 +67,12 @@ class TextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + def __init__(self, + in_channels, + hidden_size, + act_layer, + dtype=None, + device=None): factory_kwargs = {"dtype": dtype, "device": device} super().__init__() self.linear_1 = nn.Linear( @@ -105,15 +111,14 @@ def timestep_embedding(t, dim, max_period=10000): .. 
ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py """ half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=t.device) + freqs = torch.exp(-math.log(max_period) * + torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding @@ -140,9 +145,10 @@ def __init__( out_size = hidden_size self.mlp = nn.Sequential( - nn.Linear( - frequency_embedding_size, hidden_size, bias=True, **factory_kwargs - ), + nn.Linear(frequency_embedding_size, + hidden_size, + bias=True, + **factory_kwargs), act_layer(), nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), ) @@ -150,8 +156,8 @@ def __init__( nn.init.normal_(self.mlp[2].weight, std=0.02) def forward(self, t): - t_freq = timestep_embedding( - t, self.frequency_embedding_size, self.max_period - ).type(self.mlp[0].weight.dtype) + t_freq = timestep_embedding(t, self.frequency_embedding_size, + self.max_period).type( + self.mlp[0].weight.dtype) t_emb = self.mlp(t_freq) return t_emb diff --git a/fastvideo/models/hunyuan/modules/mlp_layers.py b/fastvideo/models/hunyuan/modules/mlp_layers.py index ecaa08e..f84894e 100644 --- a/fastvideo/models/hunyuan/modules/mlp_layers.py +++ b/fastvideo/models/hunyuan/modules/mlp_layers.py @@ -6,8 +6,8 @@ import torch import torch.nn as nn -from .modulate_layers import modulate from ..utils.helpers import to_2tuple +from .modulate_layers import modulate class MLP(nn.Module): @@ -32,21 +32,21 @@ def __init__( hidden_channels = hidden_channels or in_channels bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + linear_layer = partial(nn.Conv2d, + kernel_size=1) if use_conv else nn.Linear - self.fc1 = linear_layer( - in_channels, hidden_channels, bias=bias[0], **factory_kwargs - ) + self.fc1 = linear_layer(in_channels, + hidden_channels, + bias=bias[0], + **factory_kwargs) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.norm = ( - norm_layer(hidden_channels, **factory_kwargs) - if norm_layer is not None - else nn.Identity() - ) - self.fc2 = linear_layer( - hidden_channels, out_features, bias=bias[1], **factory_kwargs - ) + self.norm = (norm_layer(hidden_channels, **factory_kwargs) + if norm_layer is not None else nn.Identity()) + self.fc2 = linear_layer(hidden_channels, + out_features, + bias=bias[1], + **factory_kwargs) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -66,9 +66,15 @@ class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) + self.in_layer = nn.Linear(in_dim, + hidden_dim, + bias=True, + **factory_kwargs) self.silu = nn.SiLU() - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + self.out_layer = nn.Linear(hidden_dim, + hidden_dim, + bias=True, + **factory_kwargs) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -77,16 +83,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: 
class FinalLayer(nn.Module): """The final layer of DiT.""" - def __init__( - self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None - ): + def __init__(self, + hidden_size, + patch_size, + out_channels, + act_layer, + device=None, + dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() # Just use LayerNorm for the final layer - self.norm_final = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.norm_final = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) if isinstance(patch_size, int): self.linear = nn.Linear( hidden_size, @@ -106,7 +117,10 @@ def __init__( # Here we don't distinguish between the modulate types. Just use the simple one. self.adaLN_modulation = nn.Sequential( act_layer(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + nn.Linear(hidden_size, + 2 * hidden_size, + bias=True, + **factory_kwargs), ) # Zero-initialize the modulation nn.init.zeros_(self.adaLN_modulation[1].weight) diff --git a/fastvideo/models/hunyuan/modules/models.py b/fastvideo/models/hunyuan/modules/models.py index f9e5dcb..759897e 100644 --- a/fastvideo/models/hunyuan/modules/models.py +++ b/fastvideo/models/hunyuan/modules/models.py @@ -1,29 +1,28 @@ -from typing import Any, List, Tuple, Optional, Union, Dict -from einops import rearrange +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F - -from diffusers.models import ModelMixin from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from einops import rearrange + +from fastvideo.models.hunyuan.modules.posemb_layers import \ + get_nd_rotary_pos_embed +from fastvideo.utils.parallel_states import nccl_info from .activation_layers import get_activation_layer -from .norm_layers import get_norm_layer -from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection from .attenion import parallel_attention +from .embed_layers import PatchEmbed, TextProjection, TimestepEmbedder +from .mlp_layers import MLP, FinalLayer, MLPEmbedder +from .modulate_layers import ModulateDiT, apply_gate, modulate +from .norm_layers import get_norm_layer from .posemb_layers import apply_rotary_emb -from .mlp_layers import MLP, MLPEmbedder, FinalLayer -from .modulate_layers import ModulateDiT, modulate, apply_gate from .token_refiner import SingleTokenRefiner -from fastvideo.models.hunyuan.modules.posemb_layers import get_nd_rotary_pos_embed - -from fastvideo.utils.parallel_states import nccl_info class MMDoubleStreamBlock(nn.Module): """ - A multimodal dit block with seperate modulation for + A multimodal dit block with separate modulation for text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206 (Flux.1): https://github.com/black-forest-labs/flux """ @@ -54,31 +53,31 @@ def __init__( act_layer=get_activation_layer("silu"), **factory_kwargs, ) - self.img_norm1 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.img_norm1 = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) - self.img_attn_qkv = nn.Linear( - hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs - ) + self.img_attn_qkv = nn.Linear(hidden_size, + hidden_size * 3, + bias=qkv_bias, + **factory_kwargs) qk_norm_layer = get_norm_layer(qk_norm_type) - self.img_attn_q_norm = ( - qk_norm_layer(head_dim, 
elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.img_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.img_attn_proj = nn.Linear( - hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs - ) - - self.img_norm2 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.img_attn_q_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.img_attn_k_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.img_attn_proj = nn.Linear(hidden_size, + hidden_size, + bias=qkv_bias, + **factory_kwargs) + + self.img_norm2 = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) self.img_mlp = MLP( hidden_size, mlp_hidden_dim, @@ -93,30 +92,30 @@ def __init__( act_layer=get_activation_layer("silu"), **factory_kwargs, ) - self.txt_norm1 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) - - self.txt_attn_qkv = nn.Linear( - hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs - ) - self.txt_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.txt_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.txt_attn_proj = nn.Linear( - hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs - ) - - self.txt_norm2 = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.txt_norm1 = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) + + self.txt_attn_qkv = nn.Linear(hidden_size, + hidden_size * 3, + bias=qkv_bias, + **factory_kwargs) + self.txt_attn_q_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.txt_attn_k_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.txt_attn_proj = nn.Linear(hidden_size, + hidden_size, + bias=qkv_bias, + **factory_kwargs) + + self.txt_norm2 = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) self.txt_mlp = MLP( hidden_size, mlp_hidden_dim, @@ -159,13 +158,14 @@ def forward( # Prepare image for attention. 
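# Sketch of the fused-QKV split performed just below (shapes here are assumed
# for illustration): a single Linear emits 3 * heads * head_dim features per
# token, and one einops rearrange unpacks them into per-head q/k/v tensors.
import torch
from einops import rearrange

B, L, H, D = 2, 16, 4, 8
img_qkv = torch.randn(B, L, 3 * H * D)  # output of the fused img_attn_qkv Linear
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=H)
assert img_q.shape == (B, L, H, D)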
img_modulated = self.img_norm1(img) - img_modulated = modulate( - img_modulated, shift=img_mod1_shift, scale=img_mod1_scale - ) + img_modulated = modulate(img_modulated, + shift=img_mod1_shift, + scale=img_mod1_scale) img_qkv = self.img_attn_qkv(img_modulated) - img_q, img_k, img_v = rearrange( - img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num - ) + img_q, img_k, img_v = rearrange(img_qkv, + "B L (K H D) -> K B L H D", + K=3, + H=self.heads_num) # Apply QK-Norm if needed img_q = self.img_attn_q_norm(img_q).to(img_v) img_k = self.img_attn_k_norm(img_k).to(img_v) @@ -176,15 +176,18 @@ def forward( def shrink_head(encoder_state, dim): local_heads = encoder_state.shape[dim] // nccl_info.sp_size return encoder_state.narrow( - dim, nccl_info.rank_within_group * local_heads, local_heads - ) + dim, nccl_info.rank_within_group * local_heads, + local_heads) freqs_cis = ( shrink_head(freqs_cis[0], dim=0), shrink_head(freqs_cis[1], dim=0), ) - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + img_qq, img_kk = apply_rotary_emb(img_q, + img_k, + freqs_cis, + head_first=False) assert ( img_qq.shape == img_q.shape and img_kk.shape == img_k.shape ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" @@ -192,13 +195,14 @@ def shrink_head(encoder_state, dim): # Prepare txt for attention. txt_modulated = self.txt_norm1(txt) - txt_modulated = modulate( - txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale - ) + txt_modulated = modulate(txt_modulated, + shift=txt_mod1_shift, + scale=txt_mod1_scale) txt_qkv = self.txt_attn_qkv(txt_modulated) - txt_q, txt_k, txt_v = rearrange( - txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num - ) + txt_q, txt_k, txt_v = rearrange(txt_qkv, + "B L (K H D) -> K B L H D", + K=3, + H=self.heads_num) # Apply QK-Norm if needed. txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) @@ -214,27 +218,27 @@ def shrink_head(encoder_state, dim): # attention computation end - img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:] - # Calculate the img bloks. - img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + # Calculate the img blocks. + img = img + apply_gate(self.img_attn_proj(img_attn), + gate=img_mod1_gate) img = img + apply_gate( self.img_mlp( - modulate( - self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale - ) - ), + modulate(self.img_norm2(img), + shift=img_mod2_shift, + scale=img_mod2_scale)), gate=img_mod2_gate, ) - # Calculate the txt bloks. - txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + # Calculate the txt blocks. 
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), + gate=txt_mod1_gate) txt = txt + apply_gate( self.txt_mlp( - modulate( - self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale - ) - ), + modulate(self.txt_norm2(txt), + shift=txt_mod2_shift, + scale=txt_mod2_scale)), gate=txt_mod2_gate, ) @@ -270,32 +274,27 @@ def __init__( head_dim = hidden_size // heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) self.mlp_hidden_dim = mlp_hidden_dim - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 # qkv and mlp_in - self.linear1 = nn.Linear( - hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs - ) + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, + **factory_kwargs) # proj and mlp_out - self.linear2 = nn.Linear( - hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs - ) + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, + **factory_kwargs) qk_norm_layer = get_norm_layer(qk_norm_type) - self.q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - - self.pre_norm = nn.LayerNorm( - hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs - ) + self.q_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.k_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + + self.pre_norm = nn.LayerNorm(hidden_size, + elementwise_affine=False, + eps=1e-6, + **factory_kwargs) self.mlp_act = get_activation_layer(mlp_act_type)() self.modulation = ModulateDiT( @@ -322,11 +321,14 @@ def forward( ) -> torch.Tensor: mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) - qkv, mlp = torch.split( - self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 - ) + qkv, mlp = torch.split(self.linear1(x_mod), + [3 * self.hidden_size, self.mlp_hidden_dim], + dim=-1) - q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + q, k, v = rearrange(qkv, + "B L (K H D) -> K B L H D", + K=3, + H=self.heads_num) # Apply QK-Norm if needed. 
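apply_rotary_emb is used here and defined later in this diff (posemb_layers.py). An illustrative, self-contained version of its real-valued branch (use_real=True, head_first=False), reusing the rotate_half shown in that hunk; the cos/sin construction below is a simplified stand-in for get_1d_rotary_pos_embed:

import torch

def rotate_half(x):
    # Treat the last dim as (real, imag) pairs and rotate each pair by 90 degrees.
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

def apply_rotary_emb_real(xq, xk, cos, sin):
    # (S, D) cos/sin tables broadcast to (1, S, 1, D) for (B, S, H, D) inputs.
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]
    xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
    xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    return xq_out, xk_out

B, S, H, D = 1, 4, 2, 8
xq, xk = torch.randn(2, B, S, H, D).unbind(0)
freqs = torch.outer(torch.arange(S).float(),
                    1.0 / 10000 ** (torch.arange(0, D, 2).float() / D))
cos = freqs.cos().repeat_interleave(2, dim=1)  # (S, D)
sin = freqs.sin().repeat_interleave(2, dim=1)
q_rot, k_rot = apply_rotary_emb_real(xq, xk, cos, sin)
assert q_rot.shape == xq.shape and k_rot.shape == xk.shape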
q = self.q_norm(q).to(v) @@ -335,15 +337,18 @@ def forward( def shrink_head(encoder_state, dim): local_heads = encoder_state.shape[dim] // nccl_info.sp_size return encoder_state.narrow( - dim, nccl_info.rank_within_group * local_heads, local_heads - ) + dim, nccl_info.rank_within_group * local_heads, local_heads) - freqs_cis = (shrink_head(freqs_cis[0], dim=0), shrink_head(freqs_cis[1], dim=0)) + freqs_cis = (shrink_head(freqs_cis[0], + dim=0), shrink_head(freqs_cis[1], dim=0)) img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :] - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + img_qq, img_kk = apply_rotary_emb(img_q, + img_k, + freqs_cis, + head_first=False) assert ( img_qq.shape == img_q.shape and img_kk.shape == img_k.shape ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" @@ -464,15 +469,13 @@ def __init__( pe_dim = hidden_size // heads_num if sum(rope_dim_list) != pe_dim: raise ValueError( - f"Got {rope_dim_list} but expected positional dim {pe_dim}" - ) + f"Got {rope_dim_list} but expected positional dim {pe_dim}") self.hidden_size = hidden_size self.heads_num = heads_num # image projection - self.img_in = PatchEmbed( - self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs - ) + self.img_in = PatchEmbed(self.patch_size, self.in_channels, + self.hidden_size, **factory_kwargs) # text projection if self.text_projection == "linear": @@ -492,60 +495,48 @@ def __init__( ) else: raise NotImplementedError( - f"Unsupported text_projection: {self.text_projection}" - ) + f"Unsupported text_projection: {self.text_projection}") # time modulation - self.time_in = TimestepEmbedder( - self.hidden_size, get_activation_layer("silu"), **factory_kwargs - ) + self.time_in = TimestepEmbedder(self.hidden_size, + get_activation_layer("silu"), + **factory_kwargs) # text modulation - self.vector_in = MLPEmbedder( - self.config.text_states_dim_2, self.hidden_size, **factory_kwargs - ) + self.vector_in = MLPEmbedder(self.config.text_states_dim_2, + self.hidden_size, **factory_kwargs) # guidance modulation - self.guidance_in = ( - TimestepEmbedder( - self.hidden_size, get_activation_layer("silu"), **factory_kwargs - ) - if guidance_embed - else None - ) + self.guidance_in = (TimestepEmbedder( + self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + if guidance_embed else None) # double blocks - self.double_blocks = nn.ModuleList( - [ - MMDoubleStreamBlock( - self.hidden_size, - self.heads_num, - mlp_width_ratio=mlp_width_ratio, - mlp_act_type=mlp_act_type, - qk_norm=qk_norm, - qk_norm_type=qk_norm_type, - qkv_bias=qkv_bias, - **factory_kwargs, - ) - for _ in range(mm_double_blocks_depth) - ] - ) + self.double_blocks = nn.ModuleList([ + MMDoubleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) for _ in range(mm_double_blocks_depth) + ]) # single blocks - self.single_blocks = nn.ModuleList( - [ - MMSingleStreamBlock( - self.hidden_size, - self.heads_num, - mlp_width_ratio=mlp_width_ratio, - mlp_act_type=mlp_act_type, - qk_norm=qk_norm, - qk_norm_type=qk_norm_type, - **factory_kwargs, - ) - for _ in range(mm_single_blocks_depth) - ] - ) + self.single_blocks = nn.ModuleList([ + MMSingleStreamBlock( + self.hidden_size, + 
self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) for _ in range(mm_single_blocks_depth) + ]) self.final_layer = FinalLayer( self.hidden_size, @@ -569,11 +560,13 @@ def disable_deterministic(self): def get_rotary_pos_embed(self, rope_sizes): target_ndim = 3 - ndim = 5 - 2 + head_dim = self.hidden_size // self.heads_num rope_dim_list = self.rope_dim_list if rope_dim_list is None: - rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + rope_dim_list = [ + head_dim // target_ndim for _ in range(target_ndim) + ] assert ( sum(rope_dim_list) == head_dim ), "sum(rope_dim_list) should equal to head_dim of attention layer" @@ -605,21 +598,21 @@ def forward( return_dict: bool = False, guidance=None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if guidance == None: - guidance = torch.tensor( - [6016.0], device=hidden_states.device, dtype=torch.bfloat16 - ) - out = {} + if guidance is None: + guidance = torch.tensor([6016.0], + device=hidden_states.device, + dtype=torch.bfloat16) img = x = hidden_states text_mask = encoder_attention_mask t = timestep txt = encoder_hidden_states[:, 1:] - text_states_2 = encoder_hidden_states[:, 0, : self.config.text_states_dim_2] - _, _, ot, oh, ow = x.shape + text_states_2 = encoder_hidden_states[:, 0, :self.config. + text_states_dim_2] + _, _, ot, oh, ow = x.shape # codespell:ignore tt, th, tw = ( - ot // self.patch_size[0], - oh // self.patch_size[1], - ow // self.patch_size[2], + ot // self.patch_size[0], # codespell:ignore + oh // self.patch_size[1], # codespell:ignore + ow // self.patch_size[2], # codespell:ignore ) original_tt = nccl_info.sp_size * tt freqs_cos, freqs_sin = self.get_rotary_pos_embed((original_tt, th, tw)) @@ -644,11 +637,11 @@ def forward( if self.text_projection == "linear": txt = self.txt_in(txt) elif self.text_projection == "single_refiner": - txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) + txt = self.txt_in(txt, t, + text_mask if self.use_attention_mask else None) else: raise NotImplementedError( - f"Unsupported text_projection: {self.text_projection}" - ) + f"Unsupported text_projection: {self.text_projection}") txt_seq_len = txt.shape[1] img_seq_len = img.shape[1] @@ -681,10 +674,11 @@ def forward( img = x[:, :img_seq_len, ...] # ---------------------------- Final layer ------------------------------ - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, + vec) # (N, T, patch_size ** 2 * out_channels) img = self.unpatchify(img, tt, th, tw) - assert return_dict == False, "return_dict is not supported." + assert not return_dict, "return_dict is not supported." 
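A worked example of how the patchified token grid and rotary dimensions in the forward above fit together. The numbers are illustrative only and are not read from this diff's model config:

hidden_size, heads_num = 3072, 24
head_dim = hidden_size // heads_num           # 128 per attention head
rope_dim_list = [16, 56, 56]                  # (T, H, W) split; must sum to head_dim
assert sum(rope_dim_list) == head_dim

patch_size = (1, 2, 2)
ot, oh, ow = 16, 60, 106                      # latent frames / height / width (example)
tt, th, tw = (ot // patch_size[0], oh // patch_size[1], ow // patch_size[2])
img_seq_len = tt * th * tw                    # image tokens entering the DiT blocks
print(head_dim, (tt, th, tw), img_seq_len)    # 128 (16, 30, 53) 25440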
if output_features: features_list = torch.stack(features_list, dim=0) else: @@ -708,25 +702,24 @@ def unpatchify(self, x, t, h, w): def params_count(self): counts = { - "double": sum( - [ - sum(p.numel() for p in block.img_attn_qkv.parameters()) - + sum(p.numel() for p in block.img_attn_proj.parameters()) - + sum(p.numel() for p in block.img_mlp.parameters()) - + sum(p.numel() for p in block.txt_attn_qkv.parameters()) - + sum(p.numel() for p in block.txt_attn_proj.parameters()) - + sum(p.numel() for p in block.txt_mlp.parameters()) - for block in self.double_blocks - ] - ), - "single": sum( - [ - sum(p.numel() for p in block.linear1.parameters()) - + sum(p.numel() for p in block.linear2.parameters()) - for block in self.single_blocks - ] - ), - "total": sum(p.numel() for p in self.parameters()), + "double": + sum([ + sum(p.numel() for p in block.img_attn_qkv.parameters()) + + sum(p.numel() for p in block.img_attn_proj.parameters()) + + sum(p.numel() for p in block.img_mlp.parameters()) + + sum(p.numel() for p in block.txt_attn_qkv.parameters()) + + sum(p.numel() for p in block.txt_attn_proj.parameters()) + + sum(p.numel() for p in block.txt_mlp.parameters()) + for block in self.double_blocks + ]), + "single": + sum([ + sum(p.numel() for p in block.linear1.parameters()) + + sum(p.numel() for p in block.linear2.parameters()) + for block in self.single_blocks + ]), + "total": + sum(p.numel() for p in self.parameters()), } counts["attn+mlp"] = counts["double"] + counts["single"] return counts diff --git a/fastvideo/models/hunyuan/modules/modulate_layers.py b/fastvideo/models/hunyuan/modules/modulate_layers.py index 63c66b3..c7a3636 100644 --- a/fastvideo/models/hunyuan/modules/modulate_layers.py +++ b/fastvideo/models/hunyuan/modules/modulate_layers.py @@ -18,9 +18,10 @@ def __init__( factory_kwargs = {"dtype": dtype, "device": device} super().__init__() self.act = act_layer() - self.linear = nn.Linear( - hidden_size, factor * hidden_size, bias=True, **factory_kwargs - ) + self.linear = nn.Linear(hidden_size, + factor * hidden_size, + bias=True, + **factory_kwargs) # Zero-initialize the modulation nn.init.zeros_(self.linear.weight) nn.init.zeros_(self.linear.bias) @@ -70,6 +71,7 @@ def apply_gate(x, gate=None, tanh=False): def ckpt_wrapper(module): + def ckpt_forward(*inputs): outputs = module(*inputs) return outputs @@ -77,11 +79,8 @@ def ckpt_forward(*inputs): return ckpt_forward -import torch -import torch.nn as nn - - class RMSNorm(nn.Module): + def __init__( self, dim: int, @@ -153,4 +152,5 @@ def get_norm_layer(norm_layer): elif norm_layer == "rms": return RMSNorm else: - raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + raise NotImplementedError( + f"Norm layer {norm_layer} is not implemented") diff --git a/fastvideo/models/hunyuan/modules/norm_layers.py b/fastvideo/models/hunyuan/modules/norm_layers.py index d8c73b1..0e1e590 100644 --- a/fastvideo/models/hunyuan/modules/norm_layers.py +++ b/fastvideo/models/hunyuan/modules/norm_layers.py @@ -3,6 +3,7 @@ class RMSNorm(nn.Module): + def __init__( self, dim: int, @@ -74,4 +75,5 @@ def get_norm_layer(norm_layer): elif norm_layer == "rms": return RMSNorm else: - raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + raise NotImplementedError( + f"Norm layer {norm_layer} is not implemented") diff --git a/fastvideo/models/hunyuan/modules/posemb_layers.py b/fastvideo/models/hunyuan/modules/posemb_layers.py index dfce82c..2c92471 100644 --- a/fastvideo/models/hunyuan/modules/posemb_layers.py +++ 
b/fastvideo/models/hunyuan/modules/posemb_layers.py @@ -1,10 +1,11 @@ +from typing import List, Tuple, Union + import torch -from typing import Union, Tuple, List def _to_tuple(x, dim=2): if isinstance(x, int): - return (x,) * dim + return (x, ) * dim elif len(x) == dim: return x else: @@ -29,7 +30,7 @@ def get_meshgrid_nd(start, *args, dim=2): if len(args) == 0: # start is grid_size num = _to_tuple(start, dim=dim) - start = (0,) * dim + start = (0, ) * dim stop = num elif len(args) == 1: # start is start, args[0] is stop, step is 1 @@ -108,7 +109,10 @@ def reshape_for_broadcast( x.shape[1], x.shape[-1], ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + shape = [ + d if i == 1 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) else: # freqs_cis: values in complex space @@ -126,14 +130,16 @@ def reshape_for_broadcast( x.shape[1], x.shape[-1], ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + shape = [ + d if i == 1 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] return freqs_cis.view(*shape) def rotate_half(x): - x_real, x_imag = ( - x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) - ) # [B, S, H, D//2] + x_real, x_imag = (x.float().reshape(*x.shape[:-1], -1, + 2).unbind(-1)) # [B, S, H, D//2] return torch.stack([-x_imag, x_real], dim=-1).flatten(3) @@ -171,18 +177,15 @@ def apply_rotary_emb( xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) else: # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) - xq_ = torch.view_as_complex( - xq.float().reshape(*xq.shape[:-1], -1, 2) - ) # [B, S, H, D//2] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, + 2)) # [B, S, H, D//2] freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( - xq.device - ) # [S, D//2] --> [1, S, 1, D//2] + xq.device) # [S, D//2] --> [1, S, 1, D//2] # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) - xk_ = torch.view_as_complex( - xk.float().reshape(*xk.shape[:-1], -1, 2) - ) # [B, S, H, D//2] + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, + 2)) # [B, S, H, D//2] xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) return xq_out, xk_out @@ -216,21 +219,24 @@ def get_nd_rotary_pos_embed( pos_embed (torch.Tensor): [HW, D/2] """ - grid = get_meshgrid_nd( - start, *args, dim=len(rope_dim_list) - ) # [3, W, H, D] / [2, W, H] + grid = get_meshgrid_nd(start, *args, + dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] - if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + if isinstance(theta_rescale_factor, int) or isinstance( + theta_rescale_factor, float): theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) - elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + elif isinstance(theta_rescale_factor, + list) and len(theta_rescale_factor) == 1: theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) assert len(theta_rescale_factor) == len( rope_dim_list ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" - if isinstance(interpolation_factor, int) or 
isinstance(interpolation_factor, float): + if isinstance(interpolation_factor, int) or isinstance( + interpolation_factor, float): interpolation_factor = [interpolation_factor] * len(rope_dim_list) - elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + elif isinstance(interpolation_factor, + list) and len(interpolation_factor) == 1: interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) assert len(interpolation_factor) == len( rope_dim_list @@ -292,11 +298,10 @@ def get_1d_rotary_pos_embed( # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature if theta_rescale_factor != 1.0: - theta *= theta_rescale_factor ** (dim / (dim - 2)) + theta *= theta_rescale_factor**(dim / (dim - 2)) - freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) - ) # [D/2] + freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim) + ) # [D/2] # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] if use_real: @@ -304,7 +309,6 @@ def get_1d_rotary_pos_embed( freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar( - torch.ones_like(freqs), freqs - ) # complex64 # [S, D/2] + freqs_cis = torch.polar(torch.ones_like(freqs), + freqs) # complex64 # [S, D/2] return freqs_cis diff --git a/fastvideo/models/hunyuan/modules/token_refiner.py b/fastvideo/models/hunyuan/modules/token_refiner.py index d840526..ea6cb9c 100644 --- a/fastvideo/models/hunyuan/modules/token_refiner.py +++ b/fastvideo/models/hunyuan/modules/token_refiner.py @@ -1,19 +1,19 @@ from typing import Optional -from einops import rearrange import torch import torch.nn as nn +from einops import rearrange from .activation_layers import get_activation_layer from .attenion import attention -from .norm_layers import get_norm_layer -from .embed_layers import TimestepEmbedder, TextProjection -from .attenion import attention +from .embed_layers import TextProjection, TimestepEmbedder from .mlp_layers import MLP -from .modulate_layers import modulate, apply_gate +from .modulate_layers import apply_gate +from .norm_layers import get_norm_layer class IndividualTokenRefinerBlock(nn.Module): + def __init__( self, hidden_size, @@ -33,30 +33,30 @@ def __init__( head_dim = hidden_size // heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) - self.norm1 = nn.LayerNorm( - hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs - ) - self.self_attn_qkv = nn.Linear( - hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs - ) + self.norm1 = nn.LayerNorm(hidden_size, + elementwise_affine=True, + eps=1e-6, + **factory_kwargs) + self.self_attn_qkv = nn.Linear(hidden_size, + hidden_size * 3, + bias=qkv_bias, + **factory_kwargs) qk_norm_layer = get_norm_layer(qk_norm_type) - self.self_attn_q_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.self_attn_k_norm = ( - qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) - if qk_norm - else nn.Identity() - ) - self.self_attn_proj = nn.Linear( - hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs - ) - - self.norm2 = nn.LayerNorm( - hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs - ) + self.self_attn_q_norm = (qk_norm_layer( + head_dim, 
elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.self_attn_k_norm = (qk_norm_layer( + head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm else nn.Identity()) + self.self_attn_proj = nn.Linear(hidden_size, + hidden_size, + bias=qkv_bias, + **factory_kwargs) + + self.norm2 = nn.LayerNorm(hidden_size, + elementwise_affine=True, + eps=1e-6, + **factory_kwargs) act_layer = get_activation_layer(act_type) self.mlp = MLP( in_channels=hidden_size, @@ -68,7 +68,10 @@ def __init__( self.adaLN_modulation = nn.Sequential( act_layer(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + nn.Linear(hidden_size, + 2 * hidden_size, + bias=True, + **factory_kwargs), ) # Zero-initialize the modulation nn.init.zeros_(self.adaLN_modulation[1].weight) @@ -77,14 +80,18 @@ def __init__( def forward( self, x: torch.Tensor, - c: torch.Tensor, # timestep_aware_representations + context_aware_representations + c: torch. + Tensor, # timestep_aware_representations + context_aware_representations attn_mask: torch.Tensor = None, ): gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn_qkv(norm_x) - q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + q, k, v = rearrange(qkv, + "B L (K H D) -> K B L H D", + K=3, + H=self.heads_num) # Apply QK-Norm if needed q = self.self_attn_q_norm(q).to(v) k = self.self_attn_k_norm(k).to(v) @@ -101,6 +108,7 @@ def forward( class IndividualTokenRefiner(nn.Module): + def __init__( self, hidden_size, @@ -117,25 +125,25 @@ def __init__( ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.blocks = nn.ModuleList( - [ - IndividualTokenRefinerBlock( - hidden_size=hidden_size, - heads_num=heads_num, - mlp_width_ratio=mlp_width_ratio, - mlp_drop_rate=mlp_drop_rate, - act_type=act_type, - qk_norm=qk_norm, - qk_norm_type=qk_norm_type, - qkv_bias=qkv_bias, - **factory_kwargs, - ) - for _ in range(depth) - ] - ) + self.blocks = nn.ModuleList([ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) for _ in range(depth) + ]) def forward( - self, x: torch.Tensor, c: torch.LongTensor, mask: Optional[torch.Tensor] = None, + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, ): mask = mask.clone().bool() # avoid attention weight become NaN @@ -171,17 +179,18 @@ def __init__( self.attn_mode = attn_mode assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." 
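The refiner blocks above attend over text tokens under a padding mask; the exact mask construction inside IndividualTokenRefiner.forward is elided by this hunk. A generic sketch of turning a (B, S) padding mask into a pairwise boolean self-attention mask, including the keep-one-key trick suggested by the "avoid attention weight become NaN" comment:

import torch

def build_self_attn_mask(mask: torch.Tensor) -> torch.Tensor:
    mask = mask.clone().bool()                  # (B, S), True = real token
    pair = mask[:, None, :] & mask[:, :, None]  # (B, S, S): query i may see key j
    pair[:, :, 0] = True                        # keep key 0 visible so fully padded
                                                # rows never become all-False
    return pair.unsqueeze(1)                    # (B, 1, S, S) for SDPA-style attention

m = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool)
print(build_self_attn_mask(m).shape)            # torch.Size([1, 1, 4, 4])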
- self.input_embedder = nn.Linear( - in_channels, hidden_size, bias=True, **factory_kwargs - ) + self.input_embedder = nn.Linear(in_channels, + hidden_size, + bias=True, + **factory_kwargs) act_layer = get_activation_layer(act_type) # Build timestep embedding layer - self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, + **factory_kwargs) # Build context embedding layer - self.c_embedder = TextProjection( - in_channels, hidden_size, act_layer, **factory_kwargs - ) + self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, + **factory_kwargs) self.individual_token_refiner = IndividualTokenRefiner( hidden_size=hidden_size, @@ -209,9 +218,9 @@ def forward( else: mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] context_aware_representations = (x * mask_float).sum( - dim=1 - ) / mask_float.sum(dim=1) - context_aware_representations = self.c_embedder(context_aware_representations) + dim=1) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder( + context_aware_representations) c = timestep_aware_representations + context_aware_representations x = self.input_embedder(x) diff --git a/fastvideo/models/hunyuan/prompt_rewrite.py b/fastvideo/models/hunyuan/prompt_rewrite.py index 72840b3..556b8aa 100644 --- a/fastvideo/models/hunyuan/prompt_rewrite.py +++ b/fastvideo/models/hunyuan/prompt_rewrite.py @@ -16,7 +16,6 @@ input: "{input}" """ - master_mode_prompt = """Master mode - Video Recaption Task: You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description. diff --git a/fastvideo/models/hunyuan/text_encoder/__init__.py b/fastvideo/models/hunyuan/text_encoder/__init__.py index 494992b..18cd88e 100644 --- a/fastvideo/models/hunyuan/text_encoder/__init__.py +++ b/fastvideo/models/hunyuan/text_encoder/__init__.py @@ -1,14 +1,12 @@ from dataclasses import dataclass from typing import Optional, Tuple -from copy import deepcopy import torch import torch.nn as nn -from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer from transformers.utils import ModelOutput -from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH -from ..constants import PRECISION_TO_TYPE +from ..constants import PRECISION_TO_TYPE, TEXT_ENCODER_PATH, TOKENIZER_PATH def use_default(value, default): @@ -33,16 +31,16 @@ def load_text_encoder( text_encoder = CLIPTextModel.from_pretrained(text_encoder_path) text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm elif text_encoder_type == "llm": - text_encoder = AutoModel.from_pretrained( - text_encoder_path, low_cpu_mem_usage=True - ) + text_encoder = AutoModel.from_pretrained(text_encoder_path, + low_cpu_mem_usage=True) text_encoder.final_layer_norm = text_encoder.norm else: raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") # from_pretrained will ensure that the model is in eval mode. 
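A standalone rerun of the masked mean pooling used in SingleTokenRefiner.forward above to build the context-aware representation; shapes are assumed as (B, S, D) features and a (B, S) mask:

import torch

x = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
mask_float = mask.float().unsqueeze(-1)                       # (B, S, 1)
pooled = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)  # (B, D)
assert pooled.shape == (2, 8)
# With no mask this reduces to x.mean(dim=1), matching the other branch.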
if text_encoder_precision is not None: - text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision]) + text_encoder = text_encoder.to( + dtype=PRECISION_TO_TYPE[text_encoder_precision]) text_encoder.requires_grad_(False) @@ -55,20 +53,22 @@ def load_text_encoder( return text_encoder, text_encoder_path -def load_tokenizer( - tokenizer_type, tokenizer_path=None, padding_side="right", logger=None -): +def load_tokenizer(tokenizer_type, + tokenizer_path=None, + padding_side="right", + logger=None): if tokenizer_path is None: tokenizer_path = TOKENIZER_PATH[tokenizer_type] if logger is not None: - logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}") + logger.info( + f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}") if tokenizer_type == "clipL": - tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77) + tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, + max_length=77) elif tokenizer_type == "llm": - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, padding_side=padding_side - ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, + padding_side=padding_side) else: raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") @@ -100,6 +100,7 @@ class TextEncoderModelOutput(ModelOutput): class TextEncoder(nn.Module): + def __init__( self, text_encoder_type: str, @@ -124,20 +125,16 @@ def __init__( self.max_length = max_length self.precision = text_encoder_precision self.model_path = text_encoder_path - self.tokenizer_type = ( - tokenizer_type if tokenizer_type is not None else text_encoder_type - ) - self.tokenizer_path = ( - tokenizer_path if tokenizer_path is not None else text_encoder_path - ) + self.tokenizer_type = (tokenizer_type if tokenizer_type is not None + else text_encoder_type) + self.tokenizer_path = (tokenizer_path if tokenizer_path is not None + else text_encoder_path) self.use_attention_mask = use_attention_mask if prompt_template_video is not None: - assert ( - use_attention_mask is True - ), "Attention mask is True required when training videos." - self.input_max_length = ( - input_max_length if input_max_length is not None else max_length - ) + assert (use_attention_mask is True + ), "Attention mask is True required when training videos." 
+ self.input_max_length = (input_max_length if input_max_length + is not None else max_length) self.prompt_template = prompt_template self.prompt_template_video = prompt_template_video self.hidden_state_skip_layer = hidden_state_skip_layer @@ -153,8 +150,7 @@ def __init__( ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}" assert "{}" in str(self.prompt_template["template"]), ( "`prompt_template['template']` must contain a placeholder `{}` for the input text, " - f"got {self.prompt_template['template']}" - ) + f"got {self.prompt_template['template']}") self.use_video_template = self.prompt_template_video is not None if self.use_video_template: @@ -165,8 +161,7 @@ def __init__( ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}" assert "{}" in str(self.prompt_template_video["template"]), ( "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, " - f"got {self.prompt_template_video['template']}" - ) + f"got {self.prompt_template_video['template']}") if "t5" in text_encoder_type: self.output_key = output_key or "last_hidden_state" @@ -175,7 +170,8 @@ def __init__( elif "llm" in text_encoder_type or "glm" in text_encoder_type: self.output_key = output_key or "last_hidden_state" else: - raise ValueError(f"Unsupported text encoder type: {text_encoder_type}") + raise ValueError( + f"Unsupported text encoder type: {text_encoder_type}") self.model, self.model_path = load_text_encoder( text_encoder_type=self.text_encoder_type, @@ -205,7 +201,7 @@ def apply_text_to_template(text, template, prevent_empty_text=True): Args: text (str): Input text. template (str or list): Template string or list of chat conversation. - prevent_empty_text (bool): If Ture, we will prevent the user text from being empty + prevent_empty_text (bool): If True, we will prevent the user text from being empty by adding a space. Defaults to True. """ if isinstance(template, str): @@ -266,7 +262,8 @@ def text2tokens(self, text, data_type="image"): **kwargs, ) else: - raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}") + raise ValueError( + f"Unsupported tokenize_input_type: {tokenize_input_type}") def encode( self, @@ -294,14 +291,13 @@ def encode( return_texts (bool): Whether to return the decoded texts. Defaults to False. """ device = self.model.device if device is None else device - use_attention_mask = use_default(use_attention_mask, self.use_attention_mask) - hidden_state_skip_layer = use_default( - hidden_state_skip_layer, self.hidden_state_skip_layer - ) + use_attention_mask = use_default(use_attention_mask, + self.use_attention_mask) + hidden_state_skip_layer = use_default(hidden_state_skip_layer, + self.hidden_state_skip_layer) do_sample = use_default(do_sample, not self.reproduce) - attention_mask = ( - batch_encoding["attention_mask"].to(device) if use_attention_mask else None - ) + attention_mask = (batch_encoding["attention_mask"].to(device) + if use_attention_mask else None) outputs = self.model( input_ids=batch_encoding["input_ids"].to(device), attention_mask=attention_mask, @@ -309,11 +305,13 @@ def encode( or hidden_state_skip_layer is not None, ) if hidden_state_skip_layer is not None: - last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] + last_hidden_state = outputs.hidden_states[-( + hidden_state_skip_layer + 1)] # Real last hidden state already has layer norm applied. So here we only apply it # for intermediate layers. 
if hidden_state_skip_layer > 0 and self.apply_final_norm: - last_hidden_state = self.model.final_layer_norm(last_hidden_state) + last_hidden_state = self.model.final_layer_norm( + last_hidden_state) else: last_hidden_state = outputs[self.output_key] @@ -327,14 +325,12 @@ def encode( raise ValueError(f"Unsupported data type: {data_type}") if crop_start > 0: last_hidden_state = last_hidden_state[:, crop_start:] - attention_mask = ( - attention_mask[:, crop_start:] if use_attention_mask else None - ) + attention_mask = (attention_mask[:, crop_start:] + if use_attention_mask else None) if output_hidden_states: - return TextEncoderModelOutput( - last_hidden_state, attention_mask, outputs.hidden_states - ) + return TextEncoderModelOutput(last_hidden_state, attention_mask, + outputs.hidden_states) return TextEncoderModelOutput(last_hidden_state, attention_mask) def forward( diff --git a/fastvideo/models/hunyuan/utils/data_utils.py b/fastvideo/models/hunyuan/utils/data_utils.py index 583a903..5241181 100644 --- a/fastvideo/models/hunyuan/utils/data_utils.py +++ b/fastvideo/models/hunyuan/utils/data_utils.py @@ -1,9 +1,8 @@ -import numpy as np import math def align_to(value, alignment): - """align hight, width according to alignment + """align height, width according to alignment Args: value (int): height or width diff --git a/fastvideo/models/hunyuan/utils/file_utils.py b/fastvideo/models/hunyuan/utils/file_utils.py index fbb0058..c87a95e 100644 --- a/fastvideo/models/hunyuan/utils/file_utils.py +++ b/fastvideo/models/hunyuan/utils/file_utils.py @@ -1,11 +1,11 @@ import os from pathlib import Path -from einops import rearrange +import imageio +import numpy as np import torch import torchvision -import numpy as np -import imageio +from einops import rearrange CODE_SUFFIXES = { ".py", # Python codes @@ -45,7 +45,11 @@ def safe_file(path): return path -def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): +def save_videos_grid(videos: torch.Tensor, + path: str, + rescale=False, + n_rows=1, + fps=24): """save videos by video tensor copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 diff --git a/fastvideo/models/hunyuan/utils/helpers.py b/fastvideo/models/hunyuan/utils/helpers.py index 8768812..7b5b873 100644 --- a/fastvideo/models/hunyuan/utils/helpers.py +++ b/fastvideo/models/hunyuan/utils/helpers.py @@ -1,9 +1,9 @@ import collections.abc - from itertools import repeat def _ntuple(n): + def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): x = tuple(x) @@ -25,7 +25,7 @@ def as_tuple(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) if x is None or isinstance(x, (int, float, str)): - return (x,) + return (x, ) else: raise ValueError(f"Unknown type {type(x)}") diff --git a/fastvideo/models/hunyuan/utils/preprocess_text_encoder_tokenizer_utils.py b/fastvideo/models/hunyuan/utils/preprocess_text_encoder_tokenizer_utils.py index 71b874c..1a6f46c 100644 --- a/fastvideo/models/hunyuan/utils/preprocess_text_encoder_tokenizer_utils.py +++ b/fastvideo/models/hunyuan/utils/preprocess_text_encoder_tokenizer_utils.py @@ -1,16 +1,16 @@ import argparse + import torch -from transformers import ( - AutoProcessor, - LlavaForConditionalGeneration, -) +from transformers import AutoProcessor, LlavaForConditionalGeneration def preprocess_text_encoder_tokenizer(args): processor = AutoProcessor.from_pretrained(args.input_dir) model = 
LlavaForConditionalGeneration.from_pretrained( - args.input_dir, torch_dtype=torch.float16, low_cpu_mem_usage=True, + args.input_dir, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, ).to(0) model.language_model.save_pretrained(f"{args.output_dir}") diff --git a/fastvideo/models/hunyuan/vae/__init__.py b/fastvideo/models/hunyuan/vae/__init__.py index da71425..f570423 100644 --- a/fastvideo/models/hunyuan/vae/__init__.py +++ b/fastvideo/models/hunyuan/vae/__init__.py @@ -2,8 +2,8 @@ import torch +from ..constants import PRECISION_TO_TYPE, VAE_PATH from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D -from ..constants import VAE_PATH, PRECISION_TO_TYPE def load_vae( @@ -14,7 +14,7 @@ def load_vae( logger=None, device=None, ): - """the fucntion to load the 3D VAE model + """the function to load the 3D VAE model Args: vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy". @@ -31,7 +31,8 @@ def load_vae( logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") config = AutoencoderKLCausal3D.load_config(vae_path) if sample_size: - vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) + vae = AutoencoderKLCausal3D.from_config(config, + sample_size=sample_size) else: vae = AutoencoderKLCausal3D.from_config(config) @@ -43,7 +44,8 @@ def load_vae( ckpt = ckpt["state_dict"] if any(k.startswith("vae.") for k in ckpt.keys()): ckpt = { - k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.") + k.replace("vae.", ""): v + for k, v in ckpt.items() if k.startswith("vae.") } vae.load_state_dict(ckpt) diff --git a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py index fa43248..3d3b530 100644 --- a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py +++ b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py @@ -16,12 +16,11 @@ # Modified from diffusers==0.29.2 # # ============================================================================== -from typing import Dict, Optional, Tuple, Union from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn - from diffusers.configuration_utils import ConfigMixin, register_to_config try: @@ -30,26 +29,17 @@ except ImportError: # Use this to be compatible with the original diffusers. 
from diffusers.loaders.single_file_model import ( - FromOriginalModelMixin as FromOriginalVAEMixin, - ) -from diffusers.utils.accelerate_utils import apply_forward_hook + FromOriginalModelMixin as FromOriginalVAEMixin, ) + from diffusers.models.attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - Attention, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, -) + ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, + AttentionProcessor, AttnAddedKVProcessor, AttnProcessor) from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.modeling_utils import ModelMixin -from .vae import ( - DecoderCausal3D, - BaseOutput, - DecoderOutput, - DiagonalGaussianDistribution, - EncoderCausal3D, -) +from diffusers.utils.accelerate_utils import apply_forward_hook + +from .vae import (BaseOutput, DecoderCausal3D, DecoderOutput, + DiagonalGaussianDistribution, EncoderCausal3D) @dataclass @@ -73,9 +63,9 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",), - up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",), - block_out_channels: Tuple[int] = (64,), + down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D", ), + up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D", ), + block_out_channels: Tuple[int] = (64, ), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, @@ -119,12 +109,12 @@ def __init__( mid_block_add_attention=mid_block_add_attention, ) - self.quant_conv = nn.Conv3d( - 2 * latent_channels, 2 * latent_channels, kernel_size=1 - ) - self.post_quant_conv = nn.Conv3d( - latent_channels, latent_channels, kernel_size=1 - ) + self.quant_conv = nn.Conv3d(2 * latent_channels, + 2 * latent_channels, + kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, + latent_channels, + kernel_size=1) self.use_slicing = False self.use_spatial_tiling = False @@ -135,14 +125,11 @@ def __init__( self.tile_latent_min_tsize = sample_tsize // time_compression_ratio self.tile_sample_min_size = self.config.sample_size - sample_size = ( - self.config.sample_size[0] - if isinstance(self.config.sample_size, (list, tuple)) - else self.config.sample_size - ) + sample_size = (self.config.sample_size[0] if isinstance( + self.config.sample_size, + (list, tuple)) else self.config.sample_size) self.tile_latent_min_size = int( - sample_size / (2 ** (len(self.config.block_out_channels) - 1)) - ) + sample_size / (2**(len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 def _set_gradient_checkpointing(self, module, value=False): @@ -210,11 +197,11 @@ def fn_recursive_add_processors( ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor( - return_deprecated_lora=True - ) + return_deprecated_lora=True) for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + fn_recursive_add_processors(f"{name}.{sub_name}", child, + processors) return processors @@ -249,17 +236,18 @@ def set_attn_processor( f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, + processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor, _remove_lora=_remove_lora) else: - module.set_processor( - processor.pop(f"{name}.processor"), _remove_lora=_remove_lora - ) + module.set_processor(processor.pop(f"{name}.processor"), + _remove_lora=_remove_lora) for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + fn_recursive_attn_processor(f"{name}.{sub_name}", child, + processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) @@ -269,15 +257,11 @@ def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ - if all( - proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS - for proc in self.attn_processors.values() - ): + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS + for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() - elif all( - proc.__class__ in CROSS_ATTENTION_PROCESSORS - for proc in self.attn_processors.values() - ): + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS + for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( @@ -288,7 +272,9 @@ def set_default_attn_processor(self): @apply_forward_hook def encode( - self, x: torch.FloatTensor, return_dict: bool = True + self, + x: torch.FloatTensor, + return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images/videos into latents. @@ -308,9 +294,8 @@ def encode( return self.temporal_tiled_encode(x, return_dict=return_dict) if self.use_spatial_tiling and ( - x.shape[-1] > self.tile_sample_min_size - or x.shape[-2] > self.tile_sample_min_size - ): + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size): return self.spatial_tiled_encode(x, return_dict=return_dict) if self.use_slicing and x.shape[0] > 1: @@ -323,12 +308,14 @@ def encode( posterior = DiagonalGaussianDistribution(moments) if not return_dict: - return (posterior,) + return (posterior, ) return AutoencoderKLOutput(latent_dist=posterior) def _decode( - self, z: torch.FloatTensor, return_dict: bool = True + self, + z: torch.FloatTensor, + return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: assert len(z.shape) == 5, "The input tensor should have 5 dimensions." @@ -336,23 +323,23 @@ def _decode( return self.temporal_tiled_decode(z, return_dict=return_dict) if self.use_spatial_tiling and ( - z.shape[-1] > self.tile_latent_min_size - or z.shape[-2] > self.tile_latent_min_size - ): + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size): return self.spatial_tiled_decode(z, return_dict=return_dict) z = self.post_quant_conv(z) dec = self.decoder(z) if not return_dict: - return (dec,) + return (dec, ) return DecoderOutput(sample=dec) @apply_forward_hook - def decode( - self, z: torch.FloatTensor, return_dict: bool = True, generator=None - ) -> Union[DecoderOutput, torch.FloatTensor]: + def decode(self, + z: torch.FloatTensor, + return_dict: bool = True, + generator=None) -> Union[DecoderOutput, torch.FloatTensor]: """ Decode a batch of images/videos. 
@@ -368,44 +355,40 @@ def decode( """ if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [ + self._decode(z_slice).sample for z_slice in z.split(1) + ] decoded = torch.cat(decoded_slices) else: decoded = self._decode(z).sample if not return_dict: - return (decoded,) + return (decoded, ) return DecoderOutput(sample=decoded) - def blend_v( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: + def blend_v(self, a: torch.Tensor, b: torch.Tensor, + blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) for y in range(blend_extent): b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( - 1 - y / blend_extent - ) + b[:, :, :, y, :] * (y / blend_extent) + 1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) return b - def blend_h( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: + def blend_h(self, a: torch.Tensor, b: torch.Tensor, + blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) for x in range(blend_extent): b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( - 1 - x / blend_extent - ) + b[:, :, :, :, x] * (x / blend_extent) + 1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) return b - def blend_t( - self, a: torch.Tensor, b: torch.Tensor, blend_extent: int - ) -> torch.Tensor: + def blend_t(self, a: torch.Tensor, b: torch.Tensor, + blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) for x in range(blend_extent): b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * ( - 1 - x / blend_extent - ) + b[:, :, x, :, :] * (x / blend_extent) + 1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent) return b def spatial_tiled_encode( @@ -432,8 +415,10 @@ def spatial_tiled_encode( If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + overlap_size = int(self.tile_sample_min_size * + (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * + self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent # Split video into tiles and encode them separately. @@ -441,13 +426,8 @@ def spatial_tiled_encode( for i in range(0, x.shape[-2], overlap_size): row = [] for j in range(0, x.shape[-1], overlap_size): - tile = x[ - :, - :, - :, - i : i + self.tile_sample_min_size, - j : j + self.tile_sample_min_size, - ] + tile = x[:, :, :, i:i + self.tile_sample_min_size, + j:j + self.tile_sample_min_size, ] tile = self.encoder(tile) tile = self.quant_conv(tile) row.append(tile) @@ -471,13 +451,14 @@ def spatial_tiled_encode( posterior = DiagonalGaussianDistribution(moments) if not return_dict: - return (posterior,) + return (posterior, ) return AutoencoderKLOutput(latent_dist=posterior) - def spatial_tiled_decode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: + def spatial_tiled_decode(self, + z: torch.FloatTensor, + return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: r""" Decode a batch of images/videos using a tiled decoder. 
@@ -491,8 +472,10 @@ def spatial_tiled_decode( If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + overlap_size = int(self.tile_latent_min_size * + (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * + self.tile_overlap_factor) row_limit = self.tile_sample_min_size - blend_extent # Split z into overlapping tiles and decode them separately. @@ -501,13 +484,8 @@ def spatial_tiled_decode( for i in range(0, z.shape[-2], overlap_size): row = [] for j in range(0, z.shape[-1], overlap_size): - tile = z[ - :, - :, - :, - i : i + self.tile_latent_min_size, - j : j + self.tile_latent_min_size, - ] + tile = z[:, :, :, i:i + self.tile_latent_min_size, + j:j + self.tile_latent_min_size, ] tile = self.post_quant_conv(tile) decoded = self.decoder(tile) row.append(decoded) @@ -527,27 +505,28 @@ def spatial_tiled_decode( dec = torch.cat(result_rows, dim=-2) if not return_dict: - return (dec,) + return (dec, ) return DecoderOutput(sample=dec) - def temporal_tiled_encode( - self, x: torch.FloatTensor, return_dict: bool = True - ) -> AutoencoderKLOutput: + def temporal_tiled_encode(self, + x: torch.FloatTensor, + return_dict: bool = True) -> AutoencoderKLOutput: B, C, T, H, W = x.shape - overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor) + overlap_size = int(self.tile_sample_min_tsize * + (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_tsize * + self.tile_overlap_factor) t_limit = self.tile_latent_min_tsize - blend_extent # Split the video into tiles and encode them separately. row = [] for i in range(0, T, overlap_size): - tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :] + tile = x[:, :, i:i + self.tile_sample_min_tsize + 1, :, :] if self.use_spatial_tiling and ( - tile.shape[-1] > self.tile_sample_min_size - or tile.shape[-2] > self.tile_sample_min_size - ): + tile.shape[-1] > self.tile_sample_min_size + or tile.shape[-2] > self.tile_sample_min_size): tile = self.spatial_tiled_encode(tile, return_moments=True) else: tile = self.encoder(tile) @@ -561,34 +540,37 @@ def temporal_tiled_encode( tile = self.blend_t(row[i - 1], tile, blend_extent) result_row.append(tile[:, :, :t_limit, :, :]) else: - result_row.append(tile[:, :, : t_limit + 1, :, :]) + result_row.append(tile[:, :, :t_limit + 1, :, :]) moments = torch.cat(result_row, dim=2) posterior = DiagonalGaussianDistribution(moments) if not return_dict: - return (posterior,) + return (posterior, ) return AutoencoderKLOutput(latent_dist=posterior) - def temporal_tiled_decode( - self, z: torch.FloatTensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: + def temporal_tiled_decode(self, + z: torch.FloatTensor, + return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: # Split z into overlapping tiles and decode them separately. 
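Worked numbers for the tiling arithmetic in the spatial tiled decode path above, using assumed sizes (not read from this diff's VAE config); the blend_* helpers then linearly cross-fade the resulting overlap:

tile_latent_min_size = 32      # latent-space tile edge (assumed)
tile_sample_min_size = 256     # pixel-space tile edge (assumed)
tile_overlap_factor = 0.25

overlap_size = int(tile_latent_min_size * (1 - tile_overlap_factor))  # 24: latent stride between tiles
blend_extent = int(tile_sample_min_size * tile_overlap_factor)        # 64: pixels cross-faded per seam
row_limit = tile_sample_min_size - blend_extent                       # 192: pixels kept from each tile
print(overlap_size, blend_extent, row_limit)                          # 24 64 192
# Per seam, blend_h/blend_v/blend_t compute a linear ramp:
#   b[..., k] = a[..., -blend_extent + k] * (1 - k / blend_extent) + b[..., k] * (k / blend_extent)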
B, C, T, H, W = z.shape - overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + overlap_size = int(self.tile_latent_min_tsize * + (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_tsize * + self.tile_overlap_factor) t_limit = self.tile_sample_min_tsize - blend_extent row = [] for i in range(0, T, overlap_size): - tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :] + tile = z[:, :, i:i + self.tile_latent_min_tsize + 1, :, :] if self.use_spatial_tiling and ( - tile.shape[-1] > self.tile_latent_min_size - or tile.shape[-2] > self.tile_latent_min_size - ): - decoded = self.spatial_tiled_decode(tile, return_dict=True).sample + tile.shape[-1] > self.tile_latent_min_size + or tile.shape[-2] > self.tile_latent_min_size): + decoded = self.spatial_tiled_decode(tile, + return_dict=True).sample else: tile = self.post_quant_conv(tile) decoded = self.decoder(tile) @@ -601,11 +583,11 @@ def temporal_tiled_decode( tile = self.blend_t(row[i - 1], tile, blend_extent) result_row.append(tile[:, :, :t_limit, :, :]) else: - result_row.append(tile[:, :, : t_limit + 1, :, :]) + result_row.append(tile[:, :, :t_limit + 1, :, :]) dec = torch.cat(result_row, dim=2) if not return_dict: - return (dec,) + return (dec, ) return DecoderOutput(sample=dec) @@ -637,7 +619,7 @@ def forward( if return_posterior: return (dec, posterior) else: - return (dec,) + return (dec, ) if return_posterior: return DecoderOutput2(sample=dec, posterior=posterior) else: diff --git a/fastvideo/models/hunyuan/vae/unet_causal_3d_blocks.py b/fastvideo/models/hunyuan/vae/unet_causal_3d_blocks.py index f0eb6eb..37dce39 100644 --- a/fastvideo/models/hunyuan/vae/unet_causal_3d_blocks.py +++ b/fastvideo/models/hunyuan/vae/unet_causal_3d_blocks.py @@ -21,27 +21,29 @@ import torch import torch.nn.functional as F -from torch import nn -from einops import rearrange - -from diffusers.utils import logging from diffusers.models.activations import get_activation -from diffusers.models.attention_processor import SpatialNorm -from diffusers.models.attention_processor import Attention -from diffusers.models.normalization import AdaGroupNorm -from diffusers.models.normalization import RMSNorm +from diffusers.models.attention_processor import Attention, SpatialNorm +from diffusers.models.normalization import AdaGroupNorm, RMSNorm +from diffusers.utils import logging +from einops import rearrange +from torch import nn logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_causal_attention_mask( - n_frame: int, n_hw: int, dtype, device, batch_size: int = None -): +def prepare_causal_attention_mask(n_frame: int, + n_hw: int, + dtype, + device, + batch_size: int = None): seq_len = n_frame * n_hw - mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + mask = torch.full((seq_len, seq_len), + float("-inf"), + dtype=dtype, + device=device) for i in range(seq_len): i_frame = i // n_hw - mask[i, : (i_frame + 1) * n_hw] = 0 + mask[i, :(i_frame + 1) * n_hw] = 0 if batch_size is not None: mask = mask.unsqueeze(0).expand(batch_size, -1, -1) return mask @@ -76,9 +78,12 @@ def __init__( ) # W, H, T self.time_causal_padding = padding - self.conv = nn.Conv3d( - chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs - ) + self.conv = nn.Conv3d(chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs) def forward(self, x): x = F.pad(x, 
self.time_causal_padding, mode=self.pad_mode) @@ -91,20 +96,20 @@ class UpsampleCausal3D(nn.Module): """ def __init__( - self, - channels: int, - use_conv: bool = False, - use_conv_transpose: bool = False, - out_channels: Optional[int] = None, - name: str = "conv", - kernel_size: Optional[int] = None, - padding=1, - norm_type=None, - eps=None, - elementwise_affine=None, - bias=True, - interpolate=True, - upsample_factor=(2, 2, 2), + self, + channels: int, + use_conv: bool = False, + use_conv_transpose: bool = False, + out_channels: Optional[int] = None, + name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, + upsample_factor=(2, 2, 2), ): super().__init__() self.channels = channels @@ -130,9 +135,10 @@ def __init__( elif use_conv: if kernel_size is None: kernel_size = 3 - conv = CausalConv3d( - self.channels, self.out_channels, kernel_size=kernel_size, bias=bias - ) + conv = CausalConv3d(self.channels, + self.out_channels, + kernel_size=kernel_size, + bias=bias) if name == "conv": self.conv = conv @@ -169,14 +175,14 @@ def forward( first_h, other_h = hidden_states.split((1, T - 1), dim=2) if output_size is None: if T > 1: - other_h = F.interpolate( - other_h, scale_factor=self.upsample_factor, mode="nearest" - ) + other_h = F.interpolate(other_h, + scale_factor=self.upsample_factor, + mode="nearest") first_h = first_h.squeeze(2) - first_h = F.interpolate( - first_h, scale_factor=self.upsample_factor[1:], mode="nearest" - ) + first_h = F.interpolate(first_h, + scale_factor=self.upsample_factor[1:], + mode="nearest") first_h = first_h.unsqueeze(2) else: raise NotImplementedError @@ -254,15 +260,15 @@ def __init__( else: self.conv = conv - def forward( - self, hidden_states: torch.FloatTensor, scale: float = 1.0 - ) -> torch.FloatTensor: + def forward(self, + hidden_states: torch.FloatTensor, + scale: float = 1.0) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels if self.norm is not None: - hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute( - 0, 3, 1, 2 - ) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, + 1)).permute( + 0, 3, 1, 2) assert hidden_states.shape[1] == self.channels @@ -319,25 +325,31 @@ def __init__( groups_out = groups if self.time_embedding_norm == "ada_group": - self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + self.norm1 = AdaGroupNorm(temb_channels, + in_channels, + groups, + eps=eps) elif self.time_embedding_norm == "spatial": self.norm1 = SpatialNorm(in_channels, temb_channels) else: - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True - ) + self.norm1 = torch.nn.GroupNorm(num_groups=groups, + num_channels=in_channels, + eps=eps, + affine=True) - self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1) + self.conv1 = CausalConv3d(in_channels, + out_channels, + kernel_size=3, + stride=1) if temb_channels is not None: if self.time_embedding_norm == "default": self.time_emb_proj = linear_cls(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels) - elif ( - self.time_embedding_norm == "ada_group" - or self.time_embedding_norm == "spatial" - ): + self.time_emb_proj = linear_cls(temb_channels, + 2 * out_channels) + elif (self.time_embedding_norm == "ada_group" + or self.time_embedding_norm == "spatial"): self.time_emb_proj = None else: raise 
ValueError( @@ -347,19 +359,24 @@ def __init__( self.time_emb_proj = None if self.time_embedding_norm == "ada_group": - self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + self.norm2 = AdaGroupNorm(temb_channels, + out_channels, + groups_out, + eps=eps) elif self.time_embedding_norm == "spatial": self.norm2 = SpatialNorm(out_channels, temb_channels) else: - self.norm2 = torch.nn.GroupNorm( - num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True - ) + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, + num_channels=out_channels, + eps=eps, + affine=True) self.dropout = torch.nn.Dropout(dropout) conv_3d_out_channels = conv_3d_out_channels or out_channels - self.conv2 = CausalConv3d( - out_channels, conv_3d_out_channels, kernel_size=3, stride=1 - ) + self.conv2 = CausalConv3d(out_channels, + conv_3d_out_channels, + kernel_size=3, + stride=1) self.nonlinearity = get_activation(non_linearity) @@ -367,13 +384,12 @@ def __init__( if self.up: self.upsample = UpsampleCausal3D(in_channels, use_conv=False) elif self.down: - self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op") + self.downsample = DownsampleCausal3D(in_channels, + use_conv=False, + name="op") - self.use_in_shortcut = ( - self.in_channels != conv_3d_out_channels - if use_in_shortcut is None - else use_in_shortcut - ) + self.use_in_shortcut = (self.in_channels != conv_3d_out_channels if + use_in_shortcut is None else use_in_shortcut) self.conv_shortcut = None if self.use_in_shortcut: @@ -393,10 +409,8 @@ def forward( ) -> torch.FloatTensor: hidden_states = input_tensor - if ( - self.time_embedding_norm == "ada_group" - or self.time_embedding_norm == "spatial" - ): + if (self.time_embedding_norm == "ada_group" + or self.time_embedding_norm == "spatial"): hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) @@ -424,10 +438,8 @@ def forward( if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - if ( - self.time_embedding_norm == "ada_group" - or self.time_embedding_norm == "spatial" - ): + if (self.time_embedding_norm == "ada_group" + or self.time_embedding_norm == "spatial"): hidden_states = self.norm2(hidden_states, temb) else: hidden_states = self.norm2(hidden_states) @@ -444,7 +456,8 @@ def forward( if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + output_tensor = (input_tensor + + hidden_states) / self.output_scale_factor return output_tensor @@ -484,11 +497,9 @@ def get_down_block3d( ) attention_head_dim = num_attention_heads - down_block_type = ( - down_block_type[7:] - if down_block_type.startswith("UNetRes") - else down_block_type - ) + down_block_type = (down_block_type[7:] + if down_block_type.startswith("UNetRes") else + down_block_type) if down_block_type == "DownEncoderBlockCausal3D": return DownEncoderBlockCausal3D( num_layers=num_layers, @@ -542,9 +553,8 @@ def get_up_block3d( ) attention_head_dim = num_attention_heads - up_block_type = ( - up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - ) + up_block_type = (up_block_type[7:] + if up_block_type.startswith("UNetRes") else up_block_type) if up_block_type == "UpDecoderBlockCausal3D": return UpDecoderBlockCausal3D( num_layers=num_layers, @@ -585,15 +595,13 @@ def __init__( output_scale_factor: float = 1.0, ): super().__init__() - resnet_groups = ( - resnet_groups if resnet_groups 
is not None else min(in_channels // 4, 32) - ) + resnet_groups = (resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32)) self.add_attention = add_attention if attn_groups is None: - attn_groups = ( - resnet_groups if resnet_time_scale_shift == "default" else None - ) + attn_groups = (resnet_groups + if resnet_time_scale_shift == "default" else None) # there is always at least one resnet resnets = [ @@ -628,17 +636,14 @@ def __init__( rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=attn_groups, - spatial_norm_dim=( - temb_channels - if resnet_time_scale_shift == "spatial" - else None - ), + spatial_norm_dim=(temb_channels + if resnet_time_scale_shift + == "spatial" else None), residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, - ) - ) + )) else: attentions.append(None) @@ -654,35 +659,41 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - ) - ) + )) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward( - self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None - ) -> torch.FloatTensor: + def forward(self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: B, C, T, H, W = hidden_states.shape - hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c") + hidden_states = rearrange(hidden_states, + "b c f h w -> b (f h w) c") attention_mask = prepare_causal_attention_mask( - T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B - ) - hidden_states = attn( - hidden_states, temb=temb, attention_mask=attention_mask - ) - hidden_states = rearrange( - hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W - ) + T, + H * W, + hidden_states.dtype, + hidden_states.device, + batch_size=B) + hidden_states = attn(hidden_states, + temb=temb, + attention_mask=attention_mask) + hidden_states = rearrange(hidden_states, + "b (f h w) c -> b c f h w", + f=T, + h=H, + w=W) hidden_states = resnet(hidden_states, temb) return hidden_states class DownEncoderBlockCausal3D(nn.Module): + def __init__( self, in_channels: int, @@ -716,30 +727,27 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - ) - ) + )) self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsamplers = nn.ModuleList( - [ - DownsampleCausal3D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - stride=downsample_stride, - ) - ] - ) + self.downsamplers = nn.ModuleList([ + DownsampleCausal3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + stride=downsample_stride, + ) + ]) else: self.downsamplers = None - def forward( - self, hidden_states: torch.FloatTensor, scale: float = 1.0 - ) -> torch.FloatTensor: + def forward(self, + hidden_states: torch.FloatTensor, + scale: float = 1.0) -> torch.FloatTensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, scale=scale) @@ -751,6 +759,7 @@ def forward( class UpDecoderBlockCausal3D(nn.Module): + def __init__( self, in_channels: int, @@ -786,22 +795,19 @@ def __init__( non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - ) - ) + )) 
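The hunks above reformat prepare_causal_attention_mask and its use in UNetMidBlockCausal3D.forward, where every spatio-temporal token may attend only to tokens of its own frame and of earlier frames. The following is a minimal standalone sketch of that mask; the vectorized construction is an assumption for brevity, while the patched code fills the mask row by row.

# Minimal sketch of the frame-causal attention mask built by
# prepare_causal_attention_mask in the hunks above: tokens belonging to
# frame f may attend to every spatial token of frames 0..f and to nothing
# later. The loop-free construction below is an assumption for brevity.
import torch

def causal_frame_mask(n_frame: int, n_hw: int, dtype, device, batch_size=None):
    seq_len = n_frame * n_hw
    # frame index of every token in the flattened (frame * height * width) sequence
    frame_idx = torch.arange(seq_len, device=device) // n_hw
    # allowed[i, j] is True when token j's frame is not later than token i's frame
    allowed = frame_idx.unsqueeze(1) >= frame_idx.unsqueeze(0)
    mask = torch.where(allowed,
                       torch.zeros((), dtype=dtype, device=device),
                       torch.full((), float("-inf"), dtype=dtype, device=device))
    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask

# Example: 3 frames of a 2x2 latent grid -> a 12x12 mask where the first
# 4 rows only see the first 4 columns, the next 4 rows see 8 columns, etc.
m = causal_frame_mask(n_frame=3, n_hw=4, dtype=torch.float32, device="cpu")
assert torch.isinf(m[0, 4])      # frame 0 cannot attend to frame 1
assert m[11, 0] == 0.0           # the last frame attends to everything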
self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsamplers = nn.ModuleList( - [ - UpsampleCausal3D( - out_channels, - use_conv=True, - out_channels=out_channels, - upsample_factor=upsample_scale_factor, - ) - ] - ) + self.upsamplers = nn.ModuleList([ + UpsampleCausal3D( + out_channels, + use_conv=True, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ]) else: self.upsamplers = None diff --git a/fastvideo/models/hunyuan/vae/vae.py b/fastvideo/models/hunyuan/vae/vae.py index 9792f95..117da9a 100644 --- a/fastvideo/models/hunyuan/vae/vae.py +++ b/fastvideo/models/hunyuan/vae/vae.py @@ -4,16 +4,12 @@ import numpy as np import torch import torch.nn as nn - +from diffusers.models.attention_processor import SpatialNorm from diffusers.utils import BaseOutput, is_torch_version from diffusers.utils.torch_utils import randn_tensor -from diffusers.models.attention_processor import SpatialNorm -from .unet_causal_3d_blocks import ( - CausalConv3d, - UNetMidBlockCausal3D, - get_down_block3d, - get_up_block3d, -) + +from .unet_causal_3d_blocks import (CausalConv3d, UNetMidBlockCausal3D, + get_down_block3d, get_up_block3d) @dataclass @@ -38,8 +34,8 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",), - block_out_channels: Tuple[int, ...] = (64,), + down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", ), + block_out_channels: Tuple[int, ...] = (64, ), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", @@ -51,9 +47,10 @@ def __init__( super().__init__() self.layers_per_block = layers_per_block - self.conv_in = CausalConv3d( - in_channels, block_out_channels[0], kernel_size=3, stride=1 - ) + self.conv_in = CausalConv3d(in_channels, + block_out_channels[0], + kernel_size=3, + stride=1) self.mid_block = None self.down_blocks = nn.ModuleList([]) @@ -63,29 +60,33 @@ def __init__( input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_spatial_downsample_layers = int( + np.log2(spatial_compression_ratio)) num_time_downsample_layers = int(np.log2(time_compression_ratio)) if time_compression_ratio == 4: - add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_spatial_downsample = bool( + i < num_spatial_downsample_layers) add_time_downsample = bool( - i >= (len(block_out_channels) - 1 - num_time_downsample_layers) - and not is_final_block - ) + i >= + (len(block_out_channels) - 1 - num_time_downsample_layers) + and not is_final_block) else: raise ValueError( f"Unsupported time_compression_ratio: {time_compression_ratio}." 
) downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) - downsample_stride_T = (2,) if add_time_downsample else (1,) - downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + downsample_stride_T = (2, ) if add_time_downsample else (1, ) + downsample_stride = tuple(downsample_stride_T + + downsample_stride_HW) down_block = get_down_block3d( down_block_type, num_layers=self.layers_per_block, in_channels=input_channel, out_channels=output_channel, - add_downsample=bool(add_spatial_downsample or add_time_downsample), + add_downsample=bool(add_spatial_downsample + or add_time_downsample), downsample_stride=downsample_stride, resnet_eps=1e-6, downsample_padding=0, @@ -110,19 +111,20 @@ def __init__( ) # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 - ) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = CausalConv3d( - block_out_channels[-1], conv_out_channels, kernel_size=3 - ) + self.conv_out = CausalConv3d(block_out_channels[-1], + conv_out_channels, + kernel_size=3) def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: r"""The forward method of the `EncoderCausal3D` class.""" - assert len(sample.shape) == 5, "The input tensor should have 5 dimensions" + assert len( + sample.shape) == 5, "The input tensor should have 5 dimensions" sample = self.conv_in(sample) @@ -150,8 +152,8 @@ def __init__( self, in_channels: int = 3, out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",), - block_out_channels: Tuple[int, ...] = (64,), + up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", ), + block_out_channels: Tuple[int, ...] = (64, ), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", @@ -163,9 +165,10 @@ def __init__( super().__init__() self.layers_per_block = layers_per_block - self.conv_in = CausalConv3d( - in_channels, block_out_channels[-1], kernel_size=3, stride=1 - ) + self.conv_in = CausalConv3d(in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1) self.mid_block = None self.up_blocks = nn.ModuleList([]) @@ -177,7 +180,8 @@ def __init__( resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + resnet_time_scale_shift="default" + if norm_type == "group" else norm_type, attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, @@ -191,25 +195,25 @@ def __init__( prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_spatial_upsample_layers = int( + np.log2(spatial_compression_ratio)) num_time_upsample_layers = int(np.log2(time_compression_ratio)) if time_compression_ratio == 4: add_spatial_upsample = bool(i < num_spatial_upsample_layers) add_time_upsample = bool( i >= len(block_out_channels) - 1 - num_time_upsample_layers - and not is_final_block - ) + and not is_final_block) else: raise ValueError( f"Unsupported time_compression_ratio: {time_compression_ratio}." 
) - upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) - upsample_scale_factor_T = (2,) if add_time_upsample else (1,) - upsample_scale_factor = tuple( - upsample_scale_factor_T + upsample_scale_factor_HW - ) + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, + 1) + upsample_scale_factor_T = (2, ) if add_time_upsample else (1, ) + upsample_scale_factor = tuple(upsample_scale_factor_T + + upsample_scale_factor_HW) up_block = get_up_block3d( up_block_type, num_layers=self.layers_per_block + 1, @@ -230,13 +234,17 @@ def __init__( # out if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + self.conv_norm_out = SpatialNorm(block_out_channels[0], + temb_channels) else: self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=1e-6) self.conv_act = nn.SiLU() - self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + self.conv_out = CausalConv3d(block_out_channels[0], + out_channels, + kernel_size=3) self.gradient_checkpointing = False @@ -246,7 +254,8 @@ def forward( latent_embeds: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: r"""The forward method of the `DecoderCausal3D` class.""" - assert len(sample.shape) == 5, "The input tensor should have 5 dimensions." + assert len( + sample.shape) == 5, "The input tensor should have 5 dimensions." sample = self.conv_in(sample) @@ -254,6 +263,7 @@ def forward( if self.training and self.gradient_checkpointing: def create_custom_forward(module): + def custom_forward(*inputs): return module(*inputs) @@ -280,15 +290,14 @@ def custom_forward(*inputs): else: # middle sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds - ) + create_custom_forward(self.mid_block), sample, + latent_embeds) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), sample, latent_embeds - ) + create_custom_forward(up_block), sample, latent_embeds) else: # middle sample = self.mid_block(sample, latent_embeds) @@ -310,6 +319,7 @@ def custom_forward(*inputs): class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): if parameters.ndim == 3: dim = 2 # (B, L, C) @@ -325,10 +335,13 @@ def __init__(self, parameters: torch.Tensor, deterministic: bool = False): self.var = torch.exp(self.logvar) if self.deterministic: self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) + self.mean, + device=self.parameters.device, + dtype=self.parameters.dtype) - def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + def sample( + self, + generator: Optional[torch.Generator] = None) -> torch.FloatTensor: # make sure sample is on the same device as the parameters and has same dtype sample = randn_tensor( self.mean.shape, @@ -351,22 +364,20 @@ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: ) else: return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=reduce_dim, ) - def nll( - self, sample: torch.Tensor, dims: Tuple[int, ...] 
= [1, 2, 3] - ) -> torch.Tensor: + def nll(self, + sample: torch.Tensor, + dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, dim=dims, ) diff --git a/fastvideo/models/mochi_hf/convert_diffusers_to_mochi.py b/fastvideo/models/mochi_hf/convert_diffusers_to_mochi.py index 319605e..7382016 100644 --- a/fastvideo/models/mochi_hf/convert_diffusers_to_mochi.py +++ b/fastvideo/models/mochi_hf/convert_diffusers_to_mochi.py @@ -1,19 +1,23 @@ -import torch import argparse -from safetensors.torch import save_file import os +import torch +from safetensors.torch import save_file + parser = argparse.ArgumentParser() parser.add_argument("--diffusers_path", required=True, type=str) -parser.add_argument( - "--transformer_path", type=str, default=None, help="Path to save transformer model" -) -parser.add_argument( - "--vae_encoder_path", type=str, default=None, help="Path to save VAE encoder model" -) -parser.add_argument( - "--vae_decoder_path", type=str, default=None, help="Path to save VAE decoder model" -) +parser.add_argument("--transformer_path", + type=str, + default=None, + help="Path to save transformer model") +parser.add_argument("--vae_encoder_path", + type=str, + default=None, + help="Path to save VAE encoder model") +parser.add_argument("--vae_decoder_path", + type=str, + default=None, + help="Path to save VAE decoder model") args = parser.parse_args() @@ -36,49 +40,35 @@ def convert_diffusers_transformer_to_mochi(state_dict): # Convert patch_embed new_state_dict["x_embedder.proj.weight"] = original_state_dict.pop( - "patch_embed.proj.weight" - ) + "patch_embed.proj.weight") new_state_dict["x_embedder.proj.bias"] = original_state_dict.pop( - "patch_embed.proj.bias" - ) + "patch_embed.proj.bias") # Convert time_embed new_state_dict["t_embedder.mlp.0.weight"] = original_state_dict.pop( - "time_embed.timestep_embedder.linear_1.weight" - ) + "time_embed.timestep_embedder.linear_1.weight") new_state_dict["t_embedder.mlp.0.bias"] = original_state_dict.pop( - "time_embed.timestep_embedder.linear_1.bias" - ) + "time_embed.timestep_embedder.linear_1.bias") new_state_dict["t_embedder.mlp.2.weight"] = original_state_dict.pop( - "time_embed.timestep_embedder.linear_2.weight" - ) + "time_embed.timestep_embedder.linear_2.weight") new_state_dict["t_embedder.mlp.2.bias"] = original_state_dict.pop( - "time_embed.timestep_embedder.linear_2.bias" - ) + "time_embed.timestep_embedder.linear_2.bias") new_state_dict["t5_y_embedder.to_kv.weight"] = original_state_dict.pop( - "time_embed.pooler.to_kv.weight" - ) + "time_embed.pooler.to_kv.weight") new_state_dict["t5_y_embedder.to_kv.bias"] = original_state_dict.pop( - "time_embed.pooler.to_kv.bias" - ) + "time_embed.pooler.to_kv.bias") new_state_dict["t5_y_embedder.to_q.weight"] = original_state_dict.pop( - "time_embed.pooler.to_q.weight" - ) + "time_embed.pooler.to_q.weight") new_state_dict["t5_y_embedder.to_q.bias"] = original_state_dict.pop( - "time_embed.pooler.to_q.bias" - ) + "time_embed.pooler.to_q.bias") new_state_dict["t5_y_embedder.to_out.weight"] = original_state_dict.pop( - "time_embed.pooler.to_out.weight" - ) + "time_embed.pooler.to_out.weight") new_state_dict["t5_y_embedder.to_out.bias"] = original_state_dict.pop( - "time_embed.pooler.to_out.bias" - ) + "time_embed.pooler.to_out.bias") 
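The vae.py hunks above also touch DiagonalGaussianDistribution, whose constructor splits the encoder output into mean and log-variance and exposes sample/kl/nll. The sketch below is a compact standalone version of that behaviour; the chunking along dim=1 (dim=2 for 3-D input), the [-30, 20] clamp on the log-variance, and the mean + std * noise reparameterization follow the standard diffusers convention and are assumptions for the parts of the class not shown in the hunk.

# Standalone sketch of the diagonal Gaussian used by the causal VAE above.
# Assumes the usual diffusers convention: `parameters` stacks mean and
# log-variance along the channel dimension, logvar is clamped to [-30, 20],
# and sampling uses the reparameterization trick (mean + std * noise).
import torch

class DiagonalGaussian:
    def __init__(self, parameters: torch.Tensor):
        dim = 2 if parameters.ndim == 3 else 1           # (B, L, C) vs (B, C, ...)
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)

    def sample(self, generator=None):
        noise = torch.randn(self.mean.shape,
                            generator=generator,
                            device=self.mean.device,
                            dtype=self.mean.dtype)
        return self.mean + self.std * noise

    def kl(self):
        # KL(q || N(0, I)) summed over all non-batch dimensions
        reduce_dims = list(range(1, self.mean.ndim))
        return 0.5 * torch.sum(
            self.mean.pow(2) + self.var - 1.0 - self.logvar, dim=reduce_dims)

# Example: a (B, 2*C, T, H, W) tensor from the encoder head
posterior = DiagonalGaussian(torch.randn(2, 8, 3, 4, 4))   # C = 4 latent channels
z = posterior.sample()                                      # (2, 4, 3, 4, 4)
loss_kl = posterior.kl().mean()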
new_state_dict["t5_yproj.weight"] = original_state_dict.pop( - "time_embed.caption_proj.weight" - ) + "time_embed.caption_proj.weight") new_state_dict["t5_yproj.bias"] = original_state_dict.pop( - "time_embed.caption_proj.bias" - ) + "time_embed.caption_proj.bias") # Convert transformer blocks num_layers = 48 @@ -88,26 +78,24 @@ def convert_diffusers_transformer_to_mochi(state_dict): # norm1 new_state_dict[new_prefix + "mod_x.weight"] = original_state_dict.pop( - block_prefix + "norm1.linear.weight" - ) + block_prefix + "norm1.linear.weight") new_state_dict[new_prefix + "mod_x.bias"] = original_state_dict.pop( - block_prefix + "norm1.linear.bias" - ) + block_prefix + "norm1.linear.bias") if i < num_layers - 1: - new_state_dict[new_prefix + "mod_y.weight"] = original_state_dict.pop( - block_prefix + "norm1_context.linear.weight" - ) - new_state_dict[new_prefix + "mod_y.bias"] = original_state_dict.pop( - block_prefix + "norm1_context.linear.bias" - ) + new_state_dict[new_prefix + + "mod_y.weight"] = original_state_dict.pop( + block_prefix + "norm1_context.linear.weight") + new_state_dict[new_prefix + + "mod_y.bias"] = original_state_dict.pop( + block_prefix + "norm1_context.linear.bias") else: - new_state_dict[new_prefix + "mod_y.weight"] = original_state_dict.pop( - block_prefix + "norm1_context.linear_1.weight" - ) - new_state_dict[new_prefix + "mod_y.bias"] = original_state_dict.pop( - block_prefix + "norm1_context.linear_1.bias" - ) + new_state_dict[new_prefix + + "mod_y.weight"] = original_state_dict.pop( + block_prefix + "norm1_context.linear_1.weight") + new_state_dict[new_prefix + + "mod_y.bias"] = original_state_dict.pop( + block_prefix + "norm1_context.linear_1.bias") # Visual attention q = original_state_dict.pop(block_prefix + "attn1.to_q.weight") @@ -116,18 +104,18 @@ def convert_diffusers_transformer_to_mochi(state_dict): qkv_weight = torch.cat([q, k, v], dim=0) new_state_dict[new_prefix + "attn.qkv_x.weight"] = qkv_weight - new_state_dict[new_prefix + "attn.q_norm_x.weight"] = original_state_dict.pop( - block_prefix + "attn1.norm_q.weight" - ) - new_state_dict[new_prefix + "attn.k_norm_x.weight"] = original_state_dict.pop( - block_prefix + "attn1.norm_k.weight" - ) - new_state_dict[new_prefix + "attn.proj_x.weight"] = original_state_dict.pop( - block_prefix + "attn1.to_out.0.weight" - ) - new_state_dict[new_prefix + "attn.proj_x.bias"] = original_state_dict.pop( - block_prefix + "attn1.to_out.0.bias" - ) + new_state_dict[new_prefix + + "attn.q_norm_x.weight"] = original_state_dict.pop( + block_prefix + "attn1.norm_q.weight") + new_state_dict[new_prefix + + "attn.k_norm_x.weight"] = original_state_dict.pop( + block_prefix + "attn1.norm_k.weight") + new_state_dict[new_prefix + + "attn.proj_x.weight"] = original_state_dict.pop( + block_prefix + "attn1.to_out.0.weight") + new_state_dict[new_prefix + + "attn.proj_x.bias"] = original_state_dict.pop( + block_prefix + "attn1.to_out.0.bias") # Context attention q = original_state_dict.pop(block_prefix + "attn1.add_q_proj.weight") @@ -136,48 +124,46 @@ def convert_diffusers_transformer_to_mochi(state_dict): qkv_weight = torch.cat([q, k, v], dim=0) new_state_dict[new_prefix + "attn.qkv_y.weight"] = qkv_weight - new_state_dict[new_prefix + "attn.q_norm_y.weight"] = original_state_dict.pop( - block_prefix + "attn1.norm_added_q.weight" - ) - new_state_dict[new_prefix + "attn.k_norm_y.weight"] = original_state_dict.pop( - block_prefix + "attn1.norm_added_k.weight" - ) + new_state_dict[new_prefix + + "attn.q_norm_y.weight"] = 
original_state_dict.pop( + block_prefix + "attn1.norm_added_q.weight") + new_state_dict[new_prefix + + "attn.k_norm_y.weight"] = original_state_dict.pop( + block_prefix + "attn1.norm_added_k.weight") if i < num_layers - 1: - new_state_dict[new_prefix + "attn.proj_y.weight"] = original_state_dict.pop( - block_prefix + "attn1.to_add_out.weight" - ) - new_state_dict[new_prefix + "attn.proj_y.bias"] = original_state_dict.pop( - block_prefix + "attn1.to_add_out.bias" - ) + new_state_dict[new_prefix + + "attn.proj_y.weight"] = original_state_dict.pop( + block_prefix + "attn1.to_add_out.weight") + new_state_dict[new_prefix + + "attn.proj_y.bias"] = original_state_dict.pop( + block_prefix + "attn1.to_add_out.bias") # MLP new_state_dict[new_prefix + "mlp_x.w1.weight"] = reverse_proj_gate( - original_state_dict.pop(block_prefix + "ff.net.0.proj.weight") - ) - new_state_dict[new_prefix + "mlp_x.w2.weight"] = original_state_dict.pop( - block_prefix + "ff.net.2.weight" - ) + original_state_dict.pop(block_prefix + "ff.net.0.proj.weight")) + new_state_dict[new_prefix + + "mlp_x.w2.weight"] = original_state_dict.pop( + block_prefix + "ff.net.2.weight") if i < num_layers - 1: new_state_dict[new_prefix + "mlp_y.w1.weight"] = reverse_proj_gate( - original_state_dict.pop(block_prefix + "ff_context.net.0.proj.weight") - ) - new_state_dict[new_prefix + "mlp_y.w2.weight"] = original_state_dict.pop( - block_prefix + "ff_context.net.2.weight" - ) + original_state_dict.pop(block_prefix + + "ff_context.net.0.proj.weight")) + new_state_dict[new_prefix + + "mlp_y.w2.weight"] = original_state_dict.pop( + block_prefix + "ff_context.net.2.weight") # Output layers new_state_dict["final_layer.mod.weight"] = reverse_scale_shift( - original_state_dict.pop("norm_out.linear.weight"), dim=0 - ) + original_state_dict.pop("norm_out.linear.weight"), dim=0) new_state_dict["final_layer.mod.bias"] = reverse_scale_shift( - original_state_dict.pop("norm_out.linear.bias"), dim=0 - ) + original_state_dict.pop("norm_out.linear.bias"), dim=0) new_state_dict["final_layer.linear.weight"] = original_state_dict.pop( - "proj_out.weight" - ) - new_state_dict["final_layer.linear.bias"] = original_state_dict.pop("proj_out.bias") + "proj_out.weight") + new_state_dict["final_layer.linear.bias"] = original_state_dict.pop( + "proj_out.bias") - new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies") + new_state_dict["pos_frequencies"] = original_state_dict.pop( + "pos_frequencies") print("Remaining Keys:", original_state_dict.keys()) @@ -193,307 +179,270 @@ def convert_diffusers_vae_to_mochi(state_dict): prefix = "encoder." 
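A pattern that repeats throughout convert_diffusers_transformer_to_mochi (and again in the VAE conversion below) is fusing separate to_q/to_k/to_v projection weights from the diffusers checkpoint into a single stacked qkv matrix with torch.cat(..., dim=0). The small self-contained illustration below shows why that fusion is lossless; the key names here are shortened placeholders rather than the real checkpoint keys.

# Illustration of the q/k/v fusion used throughout the converter above:
# three separate Linear weights become one stacked qkv matrix, and the
# fused projection reproduces the three original projections.
# The dictionary keys below are shortened placeholders for the real
# diffusers / Mochi checkpoint keys.
import torch

def fuse_qkv(state_dict, prefix):
    q = state_dict.pop(f"{prefix}.to_q.weight")    # (inner_dim, query_dim)
    k = state_dict.pop(f"{prefix}.to_k.weight")
    v = state_dict.pop(f"{prefix}.to_v.weight")
    return torch.cat([q, k, v], dim=0)             # (3 * inner_dim, query_dim)

query_dim, inner_dim = 16, 16
sd = {f"attn1.to_{n}.weight": torch.randn(inner_dim, query_dim) for n in "qkv"}
x = torch.randn(2, query_dim)

q_ref = x @ sd["attn1.to_q.weight"].T              # reference projection
qkv_w = fuse_qkv(sd, "attn1")
q_fused, k_fused, v_fused = (x @ qkv_w.T).chunk(3, dim=-1)
assert torch.allclose(q_ref, q_fused)              # fused weight is equivalent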
encoder_state_dict["layers.0.weight"] = original_state_dict.pop( - f"{prefix}proj_in.weight" - ) + f"{prefix}proj_in.weight") encoder_state_dict["layers.0.bias"] = original_state_dict.pop( - f"{prefix}proj_in.bias" - ) + f"{prefix}proj_in.bias") # Convert block_in for i in range(3): - encoder_state_dict[f"layers.{i+1}.stack.0.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight" - ) - encoder_state_dict[f"layers.{i+1}.stack.0.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias" - ) - encoder_state_dict[f"layers.{i+1}.stack.2.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv1.conv.weight" - ) - encoder_state_dict[f"layers.{i+1}.stack.2.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv1.conv.bias" - ) - encoder_state_dict[f"layers.{i+1}.stack.3.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight" - ) - encoder_state_dict[f"layers.{i+1}.stack.3.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias" - ) - encoder_state_dict[f"layers.{i+1}.stack.5.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv2.conv.weight" - ) - encoder_state_dict[f"layers.{i+1}.stack.5.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv2.conv.bias" - ) + encoder_state_dict[ + f"layers.{i+1}.stack.0.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight") + encoder_state_dict[ + f"layers.{i+1}.stack.0.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias") + encoder_state_dict[ + f"layers.{i+1}.stack.2.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv1.conv.weight") + encoder_state_dict[ + f"layers.{i+1}.stack.2.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv1.conv.bias") + encoder_state_dict[ + f"layers.{i+1}.stack.3.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight") + encoder_state_dict[ + f"layers.{i+1}.stack.3.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias") + encoder_state_dict[ + f"layers.{i+1}.stack.5.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv2.conv.weight") + encoder_state_dict[ + f"layers.{i+1}.stack.5.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv2.conv.bias") # Convert down_blocks down_block_layers = [3, 4, 6] for block in range(3): encoder_state_dict[ - f"layers.{block+4}.layers.0.weight" - ] = original_state_dict.pop(f"{prefix}down_blocks.{block}.conv_in.conv.weight") - encoder_state_dict[f"layers.{block+4}.layers.0.bias"] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.conv_in.conv.bias" - ) + f"layers.{block+4}.layers.0.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.conv_in.conv.weight") + encoder_state_dict[ + f"layers.{block+4}.layers.0.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.conv_in.conv.bias") for i in range(down_block_layers[block]): # Convert resnets encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.0.weight" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" - ) + f"layers.{block+4}.layers.{i+1}.stack.0.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.0.bias" - ] = 
original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias" - ) + f"layers.{block+4}.layers.{i+1}.stack.0.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.2.weight" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight" - ) + f"layers.{block+4}.layers.{i+1}.stack.2.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.2.bias" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias" - ) + f"layers.{block+4}.layers.{i+1}.stack.2.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias") encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.3.weight" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" - ) + f"layers.{block+4}.layers.{i+1}.stack.3.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.3.bias" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias" - ) + f"layers.{block+4}.layers.{i+1}.stack.3.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.5.weight" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight" - ) + f"layers.{block+4}.layers.{i+1}.stack.5.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.stack.5.bias" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias" - ) + f"layers.{block+4}.layers.{i+1}.stack.5.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias") # Convert attentions q = original_state_dict.pop( - f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight" - ) + f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight") k = original_state_dict.pop( - f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight" - ) + f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight") v = original_state_dict.pop( - f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight" - ) + f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight") qkv_weight = torch.cat([q, k, v], dim=0) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight" - ] = qkv_weight + f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight"] = qkv_weight encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight" - ) + f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight" + ) encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias" - ) + f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias" + ) encoder_state_dict[ - 
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight" - ) + f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight") encoder_state_dict[ - f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias" - ] = original_state_dict.pop( - f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias" - ) + f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"] = original_state_dict.pop( + f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias") # Convert block_out for i in range(3): - encoder_state_dict[f"layers.{i+7}.stack.0.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight" - ) - encoder_state_dict[f"layers.{i+7}.stack.0.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias" - ) - encoder_state_dict[f"layers.{i+7}.stack.2.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv1.conv.weight" - ) - encoder_state_dict[f"layers.{i+7}.stack.2.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv1.conv.bias" - ) - encoder_state_dict[f"layers.{i+7}.stack.3.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight" - ) - encoder_state_dict[f"layers.{i+7}.stack.3.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias" - ) - encoder_state_dict[f"layers.{i+7}.stack.5.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv2.conv.weight" - ) - encoder_state_dict[f"layers.{i+7}.stack.5.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv2.conv.bias" - ) - - q = original_state_dict.pop(f"{prefix}block_out.attentions.{i}.to_q.weight") - k = original_state_dict.pop(f"{prefix}block_out.attentions.{i}.to_k.weight") - v = original_state_dict.pop(f"{prefix}block_out.attentions.{i}.to_v.weight") + encoder_state_dict[ + f"layers.{i+7}.stack.0.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight") + encoder_state_dict[ + f"layers.{i+7}.stack.0.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias") + encoder_state_dict[ + f"layers.{i+7}.stack.2.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv1.conv.weight") + encoder_state_dict[ + f"layers.{i+7}.stack.2.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv1.conv.bias") + encoder_state_dict[ + f"layers.{i+7}.stack.3.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight") + encoder_state_dict[ + f"layers.{i+7}.stack.3.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias") + encoder_state_dict[ + f"layers.{i+7}.stack.5.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv2.conv.weight") + encoder_state_dict[ + f"layers.{i+7}.stack.5.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv2.conv.bias") + + q = original_state_dict.pop( + f"{prefix}block_out.attentions.{i}.to_q.weight") + k = original_state_dict.pop( + f"{prefix}block_out.attentions.{i}.to_k.weight") + v = original_state_dict.pop( + f"{prefix}block_out.attentions.{i}.to_v.weight") qkv_weight = torch.cat([q, k, v], dim=0) - encoder_state_dict[f"layers.{i+7}.attn_block.attn.qkv.weight"] = qkv_weight + encoder_state_dict[ + f"layers.{i+7}.attn_block.attn.qkv.weight"] = qkv_weight 
encoder_state_dict[ - f"layers.{i+7}.attn_block.attn.out.weight" - ] = original_state_dict.pop(f"{prefix}block_out.attentions.{i}.to_out.0.weight") + f"layers.{i+7}.attn_block.attn.out.weight"] = original_state_dict.pop( + f"{prefix}block_out.attentions.{i}.to_out.0.weight") encoder_state_dict[ - f"layers.{i+7}.attn_block.attn.out.bias" - ] = original_state_dict.pop(f"{prefix}block_out.attentions.{i}.to_out.0.bias") + f"layers.{i+7}.attn_block.attn.out.bias"] = original_state_dict.pop( + f"{prefix}block_out.attentions.{i}.to_out.0.bias") encoder_state_dict[ - f"layers.{i+7}.attn_block.norm.weight" - ] = original_state_dict.pop(f"{prefix}block_out.norms.{i}.norm_layer.weight") + f"layers.{i+7}.attn_block.norm.weight"] = original_state_dict.pop( + f"{prefix}block_out.norms.{i}.norm_layer.weight") encoder_state_dict[ - f"layers.{i+7}.attn_block.norm.bias" - ] = original_state_dict.pop(f"{prefix}block_out.norms.{i}.norm_layer.bias") + f"layers.{i+7}.attn_block.norm.bias"] = original_state_dict.pop( + f"{prefix}block_out.norms.{i}.norm_layer.bias") # Convert output layers encoder_state_dict["output_norm.weight"] = original_state_dict.pop( - f"{prefix}norm_out.norm_layer.weight" - ) + f"{prefix}norm_out.norm_layer.weight") encoder_state_dict["output_norm.bias"] = original_state_dict.pop( - f"{prefix}norm_out.norm_layer.bias" - ) + f"{prefix}norm_out.norm_layer.bias") encoder_state_dict["output_proj.weight"] = original_state_dict.pop( - f"{prefix}proj_out.weight" - ) + f"{prefix}proj_out.weight") # Convert decoder prefix = "decoder." decoder_state_dict["blocks.0.0.weight"] = original_state_dict.pop( - f"{prefix}conv_in.weight" - ) + f"{prefix}conv_in.weight") decoder_state_dict["blocks.0.0.bias"] = original_state_dict.pop( - f"{prefix}conv_in.bias" - ) + f"{prefix}conv_in.bias") # Convert block_in for i in range(3): - decoder_state_dict[f"blocks.0.{i+1}.stack.0.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.0.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.2.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv1.conv.weight" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.2.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv1.conv.bias" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.3.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.3.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.5.weight"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv2.conv.weight" - ) - decoder_state_dict[f"blocks.0.{i+1}.stack.5.bias"] = original_state_dict.pop( - f"{prefix}block_in.resnets.{i}.conv2.conv.bias" - ) + decoder_state_dict[ + f"blocks.0.{i+1}.stack.0.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight") + decoder_state_dict[ + f"blocks.0.{i+1}.stack.0.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias") + decoder_state_dict[ + f"blocks.0.{i+1}.stack.2.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv1.conv.weight") + decoder_state_dict[ + f"blocks.0.{i+1}.stack.2.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv1.conv.bias") + decoder_state_dict[ + 
f"blocks.0.{i+1}.stack.3.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight") + decoder_state_dict[ + f"blocks.0.{i+1}.stack.3.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias") + decoder_state_dict[ + f"blocks.0.{i+1}.stack.5.weight"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv2.conv.weight") + decoder_state_dict[ + f"blocks.0.{i+1}.stack.5.bias"] = original_state_dict.pop( + f"{prefix}block_in.resnets.{i}.conv2.conv.bias") # Convert up_blocks up_block_layers = [6, 4, 3] for block in range(3): for i in range(up_block_layers[block]): decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.0.weight" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" - ) + f"blocks.{block+1}.blocks.{i}.stack.0.weight"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" + ) decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.0.bias" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias" - ) + f"blocks.{block+1}.blocks.{i}.stack.0.bias"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias" + ) decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.2.weight" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight" - ) + f"blocks.{block+1}.blocks.{i}.stack.2.weight"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight") decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.2.bias" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias" - ) + f"blocks.{block+1}.blocks.{i}.stack.2.bias"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias") decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.3.weight" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" - ) + f"blocks.{block+1}.blocks.{i}.stack.3.weight"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" + ) decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.3.bias" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias" - ) + f"blocks.{block+1}.blocks.{i}.stack.3.bias"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias" + ) decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.5.weight" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight" - ) + f"blocks.{block+1}.blocks.{i}.stack.5.weight"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight") decoder_state_dict[ - f"blocks.{block+1}.blocks.{i}.stack.5.bias" - ] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias" - ) - decoder_state_dict[f"blocks.{block+1}.proj.weight"] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.proj.weight" - ) - decoder_state_dict[f"blocks.{block+1}.proj.bias"] = original_state_dict.pop( - f"{prefix}up_blocks.{block}.proj.bias" - ) + f"blocks.{block+1}.blocks.{i}.stack.5.bias"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias") + decoder_state_dict[ + f"blocks.{block+1}.proj.weight"] = original_state_dict.pop( + f"{prefix}up_blocks.{block}.proj.weight") + decoder_state_dict[ + f"blocks.{block+1}.proj.bias"] = original_state_dict.pop( + 
f"{prefix}up_blocks.{block}.proj.bias") # Convert block_out for i in range(3): - decoder_state_dict[f"blocks.4.{i}.stack.0.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight" - ) - decoder_state_dict[f"blocks.4.{i}.stack.0.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias" - ) - decoder_state_dict[f"blocks.4.{i}.stack.2.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv1.conv.weight" - ) - decoder_state_dict[f"blocks.4.{i}.stack.2.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv1.conv.bias" - ) - decoder_state_dict[f"blocks.4.{i}.stack.3.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight" - ) - decoder_state_dict[f"blocks.4.{i}.stack.3.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias" - ) - decoder_state_dict[f"blocks.4.{i}.stack.5.weight"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv2.conv.weight" - ) - decoder_state_dict[f"blocks.4.{i}.stack.5.bias"] = original_state_dict.pop( - f"{prefix}block_out.resnets.{i}.conv2.conv.bias" - ) + decoder_state_dict[ + f"blocks.4.{i}.stack.0.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight") + decoder_state_dict[ + f"blocks.4.{i}.stack.0.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias") + decoder_state_dict[ + f"blocks.4.{i}.stack.2.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv1.conv.weight") + decoder_state_dict[ + f"blocks.4.{i}.stack.2.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv1.conv.bias") + decoder_state_dict[ + f"blocks.4.{i}.stack.3.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight") + decoder_state_dict[ + f"blocks.4.{i}.stack.3.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias") + decoder_state_dict[ + f"blocks.4.{i}.stack.5.weight"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv2.conv.weight") + decoder_state_dict[ + f"blocks.4.{i}.stack.5.bias"] = original_state_dict.pop( + f"{prefix}block_out.resnets.{i}.conv2.conv.bias") # Convert output layers decoder_state_dict["output_proj.weight"] = original_state_dict.pop( - f"{prefix}proj_out.weight" - ) + f"{prefix}proj_out.weight") decoder_state_dict["output_proj.bias"] = original_state_dict.pop( - f"{prefix}proj_out.bias" - ) + f"{prefix}proj_out.bias") return encoder_state_dict, decoder_state_dict @@ -519,10 +468,9 @@ def main(args): transformer_path = ensure_safetensors_extension(args.transformer_path) ensure_directory_exists(transformer_path) - print(f"Converting transformer model...") + print("Converting transformer model...") transformer_state_dict = convert_diffusers_transformer_to_mochi( - pipe.transformer.state_dict() - ) + pipe.transformer.state_dict()) save_file(transformer_state_dict, transformer_path) print(f"Saved transformer to {transformer_path}") @@ -533,10 +481,9 @@ def main(args): ensure_directory_exists(encoder_path) ensure_directory_exists(decoder_path) - print(f"Converting VAE models...") + print("Converting VAE models...") encoder_state_dict, decoder_state_dict = convert_diffusers_vae_to_mochi( - pipe.vae.state_dict() - ) + pipe.vae.state_dict()) save_file(encoder_state_dict, encoder_path) print(f"Saved VAE encoder to {encoder_path}") diff --git a/fastvideo/models/mochi_hf/mochi_latents_utils.py 
b/fastvideo/models/mochi_hf/mochi_latents_utils.py index 7c9249f..b43de39 100644 --- a/fastvideo/models/mochi_hf/mochi_latents_utils.py +++ b/fastvideo/models/mochi_hf/mochi_latents_utils.py @@ -1,37 +1,33 @@ import torch -mochi_latents_mean = torch.tensor( - [ - -0.06730895953510081, - -0.038011381506090416, - -0.07477820912866141, - -0.05565264470995561, - 0.012767231469026969, - -0.04703542746246419, - 0.043896967884726704, - -0.09346305707025976, - -0.09918314763016893, - -0.008729793427399178, - -0.011931556316503654, - -0.0321993391887285, - ] -).view(1, 12, 1, 1, 1) -mochi_latents_std = torch.tensor( - [ - 0.9263795028493863, - 0.9248894543193766, - 0.9393059390890617, - 0.959253732819592, - 0.8244560132752793, - 0.917259975397747, - 0.9294154431013696, - 1.3720942357788521, - 0.881393668867029, - 0.9168315692124348, - 0.9185249279345552, - 0.9274757570805041, - ] -).view(1, 12, 1, 1, 1) +mochi_latents_mean = torch.tensor([ + -0.06730895953510081, + -0.038011381506090416, + -0.07477820912866141, + -0.05565264470995561, + 0.012767231469026969, + -0.04703542746246419, + 0.043896967884726704, + -0.09346305707025976, + -0.09918314763016893, + -0.008729793427399178, + -0.011931556316503654, + -0.0321993391887285, +]).view(1, 12, 1, 1, 1) +mochi_latents_std = torch.tensor([ + 0.9263795028493863, + 0.9248894543193766, + 0.9393059390890617, + 0.959253732819592, + 0.8244560132752793, + 0.917259975397747, + 0.9294154431013696, + 1.3720942357788521, + 0.881393668867029, + 0.9168315692124348, + 0.9185249279345552, + 0.9274757570805041, +]).view(1, 12, 1, 1, 1) mochi_scaling_factor = 1.0 diff --git a/fastvideo/models/mochi_hf/modeling_mochi.py b/fastvideo/models/mochi_hf/modeling_mochi.py index a57d0fa..330f31a 100644 --- a/fastvideo/models/mochi_hf/modeling_mochi.py +++ b/fastvideo/models/mochi_hf/modeling_mochi.py @@ -16,47 +16,33 @@ import torch import torch.nn as nn -import diffusers +import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import is_torch_version, logging -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_version, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.loaders import PeftAdapterMixin from diffusers.models.attention import FeedForward as HF_FeedForward from diffusers.models.attention_processor import Attention -from diffusers.models.embeddings import ( - MochiCombinedTimestepCaptionEmbedding, - PatchEmbed, -) -from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.embeddings import (MochiCombinedTimestepCaptionEmbedding, + PatchEmbed) from diffusers.models.modeling_utils import ModelMixin -from diffusers.loaders import PeftAdapterMixin -from fastvideo.models.mochi_hf.norm import ( - MochiLayerNormContinuous, - MochiRMSNormZero, - MochiModulatedRMSNorm, - MochiRMSNorm, -) from diffusers.models.normalization import AdaLayerNormContinuous - -from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info -from fastvideo.utils.communications import all_gather, all_to_all_4D -import torch.nn.functional as F -from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph +from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging, + scale_lora_layers, unscale_lora_layers) +from diffusers.utils.torch_utils import maybe_allow_in_graph +from liger_kernel.ops.swiglu import LigerSiLUMulFunction from 
fastvideo.models.flash_attn_no_pad import flash_attn_no_pad - -from liger_kernel.ops.swiglu import LigerSiLUMulFunction +from fastvideo.models.mochi_hf.norm import (MochiLayerNormContinuous, + MochiModulatedRMSNorm, + MochiRMSNorm, MochiRMSNormZero) +from fastvideo.utils.communications import all_gather, all_to_all_4D +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) logger = logging.get_logger(__name__) # pylint: disable=invalid-name class FeedForward(HF_FeedForward): + def __init__( self, dim: int, @@ -68,9 +54,8 @@ def __init__( inner_dim=None, bias: bool = True, ): - super().__init__( - dim, dim_out, mult, dropout, activation_fn, final_dropout, inner_dim, bias - ) + super().__init__(dim, dim_out, mult, dropout, activation_fn, + final_dropout, inner_dim, bias) assert activation_fn == "swiglu" def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -81,6 +66,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MochiAttention(nn.Module): + def __init__( self, query_dim: int, @@ -114,25 +100,26 @@ def __init__( self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) - self.add_k_proj = nn.Linear( - added_kv_proj_dim, self.inner_dim, bias=added_proj_bias - ) - self.add_v_proj = nn.Linear( - added_kv_proj_dim, self.inner_dim, bias=added_proj_bias - ) + self.add_k_proj = nn.Linear(added_kv_proj_dim, + self.inner_dim, + bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, + self.inner_dim, + bias=added_proj_bias) if self.context_pre_only is not None: - self.add_q_proj = nn.Linear( - added_kv_proj_dim, self.inner_dim, bias=added_proj_bias - ) + self.add_q_proj = nn.Linear(added_kv_proj_dim, + self.inner_dim, + bias=added_proj_bias) self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append( + nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) if not self.context_pre_only: - self.to_add_out = nn.Linear( - self.inner_dim, self.out_context_dim, bias=out_bias - ) + self.to_add_out = nn.Linear(self.inner_dim, + self.out_context_dim, + bias=out_bias) self.processor = processor @@ -212,8 +199,8 @@ def __call__( def shrink_head(encoder_state, dim): local_heads = encoder_state.shape[dim] // nccl_info.sp_size return encoder_state.narrow( - dim, nccl_info.rank_within_group * local_heads, local_heads - ) + dim, nccl_info.rank_within_group * local_heads, + local_heads) encoder_query = shrink_head(encoder_query, dim=2) encoder_key = shrink_head(encoder_key, dim=2) @@ -254,9 +241,11 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): attn_mask = encoder_attention_mask[:, :].bool() attn_mask = F.pad(attn_mask, (sequence_length, 0), value=True) - hidden_states = flash_attn_no_pad( - qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None - ) + hidden_states = flash_attn_no_pad(qkv, + attn_mask, + causal=False, + dropout_p=0.0, + softmax_scale=None) # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False) @@ -267,13 +256,13 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): # hidden_states = flex_attention(query, key, value, score_mod=no_padding_mask) if get_sequence_parallel_state(): hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 - ) + (sequence_length, encoder_sequence_length), dim=1) # B, S, H, D - 
hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) - encoder_hidden_states = all_gather( - encoder_hidden_states, dim=2 - ).contiguous() + hidden_states = all_to_all_4D(hidden_states, + scatter_dim=1, + gather_dim=2) + encoder_hidden_states = all_gather(encoder_hidden_states, + dim=2).contiguous() hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) encoder_hidden_states = encoder_hidden_states.flatten(2, 3) @@ -283,8 +272,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): hidden_states = hidden_states.to(query.dtype) hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( - (sequence_length, encoder_sequence_length), dim=1 - ) + (sequence_length, encoder_sequence_length), dim=1) # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -336,12 +324,16 @@ def __init__( self.ff_inner_dim = (4 * dim * 2) // 3 self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 - self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) + self.norm1 = MochiRMSNormZero(dim, + 4 * dim, + eps=eps, + elementwise_affine=False) if not context_pre_only: - self.norm1_context = MochiRMSNormZero( - dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False - ) + self.norm1_context = MochiRMSNormZero(dim, + 4 * pooled_projection_dim, + eps=eps, + elementwise_affine=False) else: self.norm1_context = MochiLayerNormContinuous( embedding_dim=pooled_projection_dim, @@ -365,18 +357,17 @@ def __init__( # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True self.norm2 = MochiModulatedRMSNorm(eps=eps) - self.norm2_context = ( - MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None - ) + self.norm2_context = (MochiModulatedRMSNorm( + eps=eps) if not self.context_pre_only else None) self.norm3 = MochiModulatedRMSNorm(eps) - self.norm3_context = ( - MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None - ) + self.norm3_context = (MochiModulatedRMSNorm( + eps=eps) if not self.context_pre_only else None) - self.ff = FeedForward( - dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False - ) + self.ff = FeedForward(dim, + inner_dim=self.ff_inner_dim, + activation_fn=activation_fn, + bias=False) self.ff_context = None if not context_pre_only: self.ff_context = FeedForward( @@ -399,8 +390,7 @@ def forward( output_attn=False, ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1( - hidden_states, temb - ) + hidden_states, temb) if not self.context_pre_only: ( @@ -410,7 +400,8 @@ def forward( enc_gate_mlp, ) = self.norm1_context(encoder_hidden_states, temb) else: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + norm_encoder_hidden_states = self.norm1_context( + encoder_hidden_states, temb) attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, @@ -420,28 +411,27 @@ def forward( ) hidden_states = hidden_states + self.norm2( - attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1) - ) + attn_hidden_states, + torch.tanh(gate_msa).unsqueeze(1)) norm_hidden_states = self.norm3( - hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)) - ) + hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32))) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + self.norm4( - ff_output, torch.tanh(gate_mlp).unsqueeze(1) - ) + ff_output, + torch.tanh(gate_mlp).unsqueeze(1)) if not self.context_pre_only: 
encoder_hidden_states = encoder_hidden_states + self.norm2_context( - context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1) - ) + context_attn_hidden_states, + torch.tanh(enc_gate_msa).unsqueeze(1)) norm_encoder_hidden_states = self.norm3_context( encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)), ) context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + self.norm4_context( - context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1) - ) + context_ff_output, + torch.tanh(enc_gate_mlp).unsqueeze(1)) if not output_attn: attn_hidden_states = None @@ -465,7 +455,11 @@ def __init__(self, base_height: int = 192, base_width: int = 192) -> None: self.target_area = base_height * base_width def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: - edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + edges = torch.linspace(start, + stop, + num + 1, + device=device, + dtype=dtype) return (edges[:-1] + edges[1:]) / 2 def _get_positions( @@ -476,24 +470,28 @@ def _get_positions( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - scale = (self.target_area / (height * width)) ** 0.5 - t = torch.arange(num_frames * nccl_info.sp_size, device=device, dtype=dtype) - h = self._centers( - -height * scale / 2, height * scale / 2, height, device, dtype - ) - w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + scale = (self.target_area / (height * width))**0.5 + t = torch.arange(num_frames * nccl_info.sp_size, + device=device, + dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, + device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, + dtype) grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) return positions - def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + def _create_rope(self, freqs: torch.Tensor, + pos: torch.Tensor) -> torch.Tensor: with torch.autocast(freqs.device.type, enabled=False): # Always run ROPE freqs computation in FP32 freqs = torch.einsum( - "nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32) - ) + "nd,dhf->nhf", # codespell:ignore + pos.to(torch.float32), # codespell:ignore + freqs.to(torch.float32)) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin @@ -581,24 +579,20 @@ def __init__( ) self.pos_frequencies = nn.Parameter( - torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0) - ) + torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0)) self.rope = MochiRoPE() - self.transformer_blocks = nn.ModuleList( - [ - MochiTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - pooled_projection_dim=pooled_projection_dim, - qk_norm=qk_norm, - activation_fn=activation_fn, - context_pre_only=i == num_layers - 1, - ) - for i in range(num_layers) - ] - ) + self.transformer_blocks = nn.ModuleList([ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + context_pre_only=i == num_layers - 1, + ) for i in range(num_layers) + ]) self.norm_out = AdaLayerNormContinuous( inner_dim, @@ -607,7 +601,8 @@ def __init__( eps=1e-6, 
norm_type="layer_norm", ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + self.proj_out = nn.Linear(inner_dim, + patch_size * patch_size * out_channels) self.gradient_checkpointing = False @@ -626,9 +621,8 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ) -> torch.Tensor: - assert ( - return_dict is False - ), "return_dict is not supported in MochiTransformer3DModel" + assert (return_dict is False + ), "return_dict is not supported in MochiTransformer3DModel" if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -640,10 +634,8 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if ( - attention_kwargs is not None - and attention_kwargs.get("scale", None) is not None - ): + if (attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None): logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) @@ -664,7 +656,8 @@ def forward( hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.patch_embed(hidden_states) - hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten( + 1, 2) image_rotary_emb = self.rope( self.pos_frequencies, @@ -679,14 +672,15 @@ def forward( if self.gradient_checkpointing: def create_custom_forward(module): + def custom_forward(*inputs): return module(*inputs) return custom_forward - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) + ckpt_kwargs: Dict[str, Any] = ({ + "use_reentrant": False + } if is_torch_version(">=", "1.11.0") else {}) ( hidden_states, encoder_hidden_states, @@ -716,11 +710,12 @@ def custom_forward(*inputs): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape( - batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1 - ) + hidden_states = hidden_states.reshape(batch_size, num_frames, + post_patch_height, + post_patch_width, p, p, -1) hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) - output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + output = hidden_states.reshape(batch_size, -1, num_frames, height, + width) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/fastvideo/models/mochi_hf/norm.py b/fastvideo/models/mochi_hf/norm.py index 3f7b4e6..f8fe455 100644 --- a/fastvideo/models/mochi_hf/norm.py +++ b/fastvideo/models/mochi_hf/norm.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numbers -from typing import Dict, Optional, Tuple +from typing import Tuple import torch import torch.nn as nn -import torch.nn.functional as F class MochiModulatedRMSNorm(nn.Module): + def __init__(self, eps: float): super().__init__() @@ -41,6 +40,7 @@ def forward(self, hidden_states, scale=None): class MochiRMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine=True): super().__init__() @@ -66,18 +66,27 @@ def forward(self, hidden_states): class MochiLayerNormContinuous(nn.Module): + def __init__( - self, embedding_dim: int, conditioning_embedding_dim: int, eps=1e-5, bias=True, + self, + embedding_dim: int, + conditioning_embedding_dim: int, + eps=1e-5, + bias=True, ): super().__init__() # AdaLN self.silu = nn.SiLU() - self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + self.linear_1 = nn.Linear(conditioning_embedding_dim, + embedding_dim, + bias=bias) self.norm = MochiModulatedRMSNorm(eps=eps) def forward( - self, x: torch.Tensor, conditioning_embedding: torch.Tensor, + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, ) -> torch.Tensor: input_dtype = x.dtype @@ -116,9 +125,8 @@ def forward( emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - hidden_states = self.norm( - hidden_states, (1 + scale_msa[:, None].to(torch.float32)) - ) + hidden_states = self.norm(hidden_states, + (1 + scale_msa[:, None].to(torch.float32))) hidden_states = hidden_states.to(hidden_states_dtype) return hidden_states, gate_msa, scale_mlp, gate_mlp diff --git a/fastvideo/models/mochi_hf/pipeline_mochi.py b/fastvideo/models/mochi_hf/pipeline_mochi.py index 3bb9ead..e5dc9c4 100644 --- a/fastvideo/models/mochi_hf/pipeline_mochi.py +++ b/fastvideo/models/mochi_hf/pipeline_mochi.py @@ -12,31 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect -from typing import Callable, Dict, List, Optional, Union, Any import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + import numpy as np import torch -from transformers import T5EncoderModel, T5TokenizerFast - from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import Mochi1LoraLoaderMixin from diffusers.models.autoencoders import AutoencoderKL -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel - +from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput from einops import rearrange -from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info +from transformers import T5EncoderModel, T5TokenizerFast + +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel from fastvideo.utils.communications import all_gather -from diffusers.loaders import Mochi1LoraLoaderMixin +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -44,7 +43,6 @@ else: XLA_AVAILABLE = False - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -85,13 +83,13 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): ] threshold_noise_step_diff = linear_steps - threshold_noise * num_steps quadratic_steps = num_steps - linear_steps - quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2) + quadratic_coef = threshold_noise_step_diff / (linear_steps * + quadratic_steps**2) linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( - quadratic_steps ** 2 - ) - const = quadratic_coef * (linear_steps ** 2) + quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) quadratic_sigma_schedule = [ - quadratic_coef * (i ** 2) + linear_coef * i + const + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) ] sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule @@ -137,8 +135,7 @@ def retrieve_timesteps( ) if timesteps is not None: accepts_timesteps = "timesteps" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys() - ) + inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -149,8 +146,7 @@ def retrieve_timesteps( num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys() - ) + inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -191,7 +187,9 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): 
model_cpu_offload_seq = "text_encoder->transformer->vae" _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", "prompt_embeds", "negative_prompt_embeds" + ] def __init__( self, @@ -215,13 +213,10 @@ def __init__( self.patch_size = 2 self.video_processor = VideoProcessor( - vae_scale_factor=self.vae_spatial_scale_factor - ) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length - if hasattr(self, "tokenizer") and self.tokenizer is not None - else 77 - ) + vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = (self.tokenizer.model_max_length + if hasattr(self, "tokenizer") + and self.tokenizer is not None else 77) self.default_height = 480 self.default_width = 848 @@ -252,35 +247,31 @@ def _get_t5_prompt_embeds( prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, + padding="longest", + return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, max_sequence_length - 1 : -1] - ) + untruncated_ids[:, max_sequence_length - 1:-1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) + f" {max_sequence_length} tokens: {removed_text}") prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=prompt_attention_mask - )[0] + text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - batch_size * num_videos_per_prompt, seq_len, -1 - ) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, + seq_len, -1) prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.repeat( + num_videos_per_prompt, 1) return prompt_embeds, prompt_attention_mask @@ -344,23 +335,19 @@ def encode_prompt( if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt = ( - batch_size * [negative_prompt] - if isinstance(negative_prompt, str) - else negative_prompt - ) + negative_prompt = (batch_size * [negative_prompt] if isinstance( + negative_prompt, str) else negative_prompt) - if prompt is not None and type(prompt) is not type(negative_prompt): + if prompt is not None and type(prompt) is not type( + negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) + f" {type(prompt)}.") elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) + " the batch size of `prompt`.") ( negative_prompt_embeds, @@ -397,9 +384,8 @@ def check_inputs( ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs - for k in callback_on_step_end_tensor_inputs - ): + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) @@ -407,15 +393,13 @@ def check_inputs( if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) + " only forward one of the two.") elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): + elif prompt is not None and (not isinstance(prompt, str) + and not isinstance(prompt, list)): raise ValueError( f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" ) @@ -425,10 +409,8 @@ def check_inputs( "Must provide `prompt_attention_mask` when specifying `prompt_embeds`." ) - if ( - negative_prompt_embeds is not None - and negative_prompt_attention_mask is None - ): + if (negative_prompt_embeds is not None + and negative_prompt_attention_mask is None): raise ValueError( "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`." ) @@ -438,14 +420,12 @@ def check_inputs( raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) + f" {negative_prompt_embeds.shape}.") if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: raise ValueError( "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" - f" {negative_prompt_attention_mask.shape}." - ) + f" {negative_prompt_attention_mask.shape}.") def enable_vae_slicing(self): r""" @@ -502,7 +482,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = randn_tensor(shape, + generator=generator, + device=device, + dtype=torch.float32) latents = latents.to(dtype) return latents @@ -539,7 +522,8 @@ def __call__( timesteps: List[int] = None, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, @@ -548,7 +532,8 @@ def __call__( output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], + None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, return_all_states=False, @@ -627,7 +612,8 @@ def __call__( is returned where the first element is a list with the generated images. """ - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + if isinstance(callback_on_step_end, + (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs height = height or self.default_height @@ -638,7 +624,8 @@ def __call__( prompt=prompt, height=height, width=width, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + callback_on_step_end_tensor_inputs= + callback_on_step_end_tensor_inputs, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, @@ -678,10 +665,10 @@ def __call__( device=device, ) if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], + dim=0) prompt_attention_mask = torch.cat( - [negative_prompt_attention_mask, prompt_attention_mask], dim=0 - ) + [negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels @@ -698,45 +685,52 @@ def __call__( ) world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group if get_sequence_parallel_state(): - latents = rearrange( - latents, "b t (n s) h w -> b t n s h w", n=world_size - ).contiguous() + latents = rearrange(latents, + "b t (n s) h w -> b t n s h w", + n=world_size).contiguous() latents = latents[:, :, rank, :, :, :] original_noise = copy.deepcopy(latents) # 5. 
Prepare timestep # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 threshold_noise = 0.025 - sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) + sigmas = linear_quadratic_schedule(num_inference_steps, + threshold_noise) sigmas = np.array(sigmas) # check if of type FlowMatchEulerDiscreteScheduler if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas, + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, ) else: timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, + self.scheduler, + num_inference_steps, + device, ) num_warmup_steps = max( - len(timesteps) - num_inference_steps * self.scheduler.order, 0 - ) + len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # 6. Denoising loop - self._progress_bar_config = {"disable": nccl_info.rank_within_group != 0} + self._progress_bar_config = { + "disable": nccl_info.rank_within_group != 0 + } with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - latent_model_input = ( - torch.cat([latents] * 2) - if self.do_classifier_free_guidance - else latents - ) + latent_model_input = (torch.cat( + [latents] * + 2) if self.do_classifier_free_guidance else latents) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = t.expand(latent_model_input.shape[0]).to( + latents.dtype) noise_pred = self.transformer( hidden_states=latent_model_input, @@ -752,14 +746,14 @@ def __call__( if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step( - noise_pred, t, latents.to(torch.float32), return_dict=False - )[0] + latents = self.scheduler.step(noise_pred, + t, + latents.to(torch.float32), + return_dict=False)[0] latents = latents.to(latents_dtype) if latents.dtype != latents_dtype: @@ -771,15 +765,17 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + (i + 1) > num_warmup_steps and + (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: @@ -799,36 +795,25 @@ def __call__( else: # unscale/denormalize the latents # denormalize with the mean and std if available and not None - has_latents_mean = ( - hasattr(self.vae.config, "latents_mean") - and self.vae.config.latents_mean is not None - ) - has_latents_std = ( - hasattr(self.vae.config, "latents_std") - and self.vae.config.latents_std is not None - ) + has_latents_mean = 
(hasattr(self.vae.config, "latents_mean") + and self.vae.config.latents_mean is not None) + has_latents_std = (hasattr(self.vae.config, "latents_std") + and self.vae.config.latents_std is not None) if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) + latents_mean = (torch.tensor( + self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to( + latents.device, latents.dtype)) + latents_std = (torch.tensor(self.vae.config.latents_std).view( + 1, 12, 1, 1, 1).to(latents.device, latents.dtype)) latents = ( - latents * latents_std / self.vae.config.scaling_factor - + latents_mean - ) + latents * latents_std / self.vae.config.scaling_factor + + latents_mean) else: latents = latents / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video( - video, output_type=output_type - ) + video, output_type=output_type) # Offload all models self.maybe_free_model_hooks() @@ -839,6 +824,6 @@ def __call__( return original_noise, video, latents, prompt_embeds, prompt_attention_mask if not return_dict: - return (video,) + return (video, ) return MochiPipelineOutput(frames=video) diff --git a/fastvideo/sample/generate_synthetic.py b/fastvideo/sample/generate_synthetic.py index f5186a8..42550db 100644 --- a/fastvideo/sample/generate_synthetic.py +++ b/fastvideo/sample/generate_synthetic.py @@ -1,16 +1,16 @@ +import argparse import json +import os -import torch.distributed as dist import torch -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -import os +import torch.distributed as dist from diffusers.utils import export_to_video -import argparse +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -def generate_video_and_latent( - pipe, prompt, height, width, num_frames, num_inference_steps, guidance_scale -): + +def generate_video_and_latent(pipe, prompt, height, width, num_frames, + num_inference_steps, guidance_scale): # Set the random seed for reproducibility generator = torch.Generator("cuda").manual_seed(12345) # Generate videos from the input prompt @@ -25,7 +25,8 @@ def generate_video_and_latent( output_type="latent_and_video", ) # prompt_embed has negative prompt at index 0 - return noise[0], video[0], latent[0], prompt_embed[1], prompt_attention_mask[1] + return noise[0], video[0], latent[0], prompt_embed[ + 1], prompt_attention_mask[1] # return dummy tensor to debug first # return torch.zeros(1, 3, 480, 848), torch.zeros(1, 256, 16, 16) @@ -39,19 +40,22 @@ def generate_video_and_latent( parser.add_argument("--num_inference_steps", type=int, default=64) parser.add_argument("--guidance_scale", type=float, default=4.5) parser.add_argument("--model_path", type=str, default="data/mochi") - parser.add_argument( - "--prompt_path", type=str, default="data/dummyVid/videos2caption.json" - ) - parser.add_argument("--dataset_output_dir", type=str, default="data/dummySynthetic") + parser.add_argument("--prompt_path", + type=str, + default="data/dummyVid/videos2caption.json") + parser.add_argument("--dataset_output_dir", + type=str, + default="data/dummySynthetic") args = parser.parse_args() local_rank = int(os.getenv("RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) print("world_size", world_size, "local rank", local_rank) torch.cuda.set_device(local_rank) - 
dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) if not isinstance(args.prompt_path, list): args.prompt_path = [args.prompt_path] @@ -59,7 +63,8 @@ def generate_video_and_latent( text_prompt = open(args.prompt_path[0], "r").readlines() text_prompt = [i.strip() for i in text_prompt] - pipe = MochiPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16) + pipe = MochiPipeline.from_pretrained(args.model_path, + torch_dtype=torch.bfloat16) pipe.enable_vae_tiling() pipe.enable_model_cpu_offload(gpu_id=local_rank) # make dir if not exist @@ -68,10 +73,10 @@ def generate_video_and_latent( os.makedirs(os.path.join(args.dataset_output_dir, "noise"), exist_ok=True) os.makedirs(os.path.join(args.dataset_output_dir, "video"), exist_ok=True) os.makedirs(os.path.join(args.dataset_output_dir, "latent"), exist_ok=True) - os.makedirs(os.path.join(args.dataset_output_dir, "prompt_embed"), exist_ok=True) - os.makedirs( - os.path.join(args.dataset_output_dir, "prompt_attention_mask"), exist_ok=True - ) + os.makedirs(os.path.join(args.dataset_output_dir, "prompt_embed"), + exist_ok=True) + os.makedirs(os.path.join(args.dataset_output_dir, "prompt_attention_mask"), + exist_ok=True) data = [] for i, prompt in enumerate(text_prompt): if i % world_size != local_rank: @@ -93,17 +98,17 @@ def generate_video_and_latent( ) # save latent video_name = str(i) - noise_path = os.path.join(args.dataset_output_dir, "noise", video_name + ".pt") - latent_path = os.path.join( - args.dataset_output_dir, "latent", video_name + ".pt" - ) - prompt_embed_path = os.path.join( - args.dataset_output_dir, "prompt_embed", video_name + ".pt" - ) - video_path = os.path.join(args.dataset_output_dir, "video", video_name + ".mp4") - prompt_attention_mask_path = os.path.join( - args.dataset_output_dir, "prompt_attention_mask", video_name + ".pt" - ) + noise_path = os.path.join(args.dataset_output_dir, "noise", + video_name + ".pt") + latent_path = os.path.join(args.dataset_output_dir, "latent", + video_name + ".pt") + prompt_embed_path = os.path.join(args.dataset_output_dir, + "prompt_embed", video_name + ".pt") + video_path = os.path.join(args.dataset_output_dir, "video", + video_name + ".mp4") + prompt_attention_mask_path = os.path.join(args.dataset_output_dir, + "prompt_attention_mask", + video_name + ".pt") # save latent torch.save(noise, noise_path) torch.save(latent, latent_path) @@ -127,7 +132,6 @@ def generate_video_and_latent( # save json if local_rank == 0: all_data = [item for sublist in gathered_data for item in sublist] - with open( - os.path.join(args.dataset_output_dir, "videos2caption.json"), "w" - ) as f: + with open(os.path.join(args.dataset_output_dir, "videos2caption.json"), + "w") as f: json.dump(all_data, f, indent=4) diff --git a/fastvideo/sample/sample_t2v_diffusers_hunyuan.py b/fastvideo/sample/sample_t2v_diffusers_hunyuan.py index bf9b5ef..81b1fb3 100644 --- a/fastvideo/sample/sample_t2v_diffusers_hunyuan.py +++ b/fastvideo/sample/sample_t2v_diffusers_hunyuan.py @@ -1,12 +1,14 @@ -import torch -from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, BitsAndBytesConfig -import imageio as iio -import math -import numpy as np -import io -import time import argparse +import io import os +import time + +import imageio as iio +import numpy as np +import torch +from diffusers import (BitsAndBytesConfig, HunyuanVideoPipeline, + 
HunyuanVideoTransformer3DModel) + def export_to_video_bytes(fps, frames): request = iio.core.Request("", mode="w", extension=".mp4") @@ -19,64 +21,81 @@ def export_to_video_bytes(fps, frames): out_bytes = io.BytesIO(new_bytes) return out_bytes + def export_to_video(frames, path, fps): video_bytes = export_to_video_bytes(fps, frames) video_bytes.seek(0) with open(path, "wb") as f: f.write(video_bytes.getbuffer()) + def main(args): torch.manual_seed(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" prompt_template = { - "template": ( - "<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." - "4. Background environment, light, style, atmosphere, and qualities." - "5. Camera angles, movements, and transitions used in the video." - "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - ), - "crop_start": 95, + "template": + ("<|start_header_cid|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." + "4. Background environment, light, style, atmosphere, and qualities." + "5. Camera angles, movements, and transitions used in the video." + "6. Thematic and aesthetic concepts associated with the scene, i.e. 
realistic, futuristic, fairy tale, etc<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"), + "crop_start": + 95, } - - + model_id = args.model_path if args.quantization == "nf4": - quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", llm_int8_skip_modules=["proj_out", "norm_out"]) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_type="nf4", + llm_int8_skip_modules=["proj_out", "norm_out"]) transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16, quantization_config=quantization_config - ) + model_id, + subfolder="transformer/", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config) if args.quantization == "int8": - quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"]) + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_skip_modules=["proj_out", "norm_out"]) transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16, quantization_config=quantization_config - ) + model_id, + subfolder="transformer/", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config) elif not args.quantization: transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer/" ,torch_dtype=torch.bfloat16 - ).to(device) - - print("Max vram for read transofrmer:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + model_id, subfolder="transformer/", + torch_dtype=torch.bfloat16).to(device) + + print("Max vram for read transformer:", + round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), + "GiB") torch.cuda.reset_max_memory_allocated(device) - + if not args.cpu_offload: - pipe = HunyuanVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device) + pipe = HunyuanVideoPipeline.from_pretrained( + model_id, torch_dtype=torch.bfloat16).to(device) pipe.transformer = transformer else: - pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) + pipe = HunyuanVideoPipeline.from_pretrained(model_id, + transformer=transformer, + torch_dtype=torch.bfloat16) torch.cuda.reset_max_memory_allocated(device) pipe.scheduler._shift = args.flow_shift pipe.vae.enable_tiling() if args.cpu_offload: pipe.enable_model_cpu_offload() - print("Max vram for init pipeline:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + print("Max vram for init pipeline:", + round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), + "GiB") with open(args.prompt) as f: prompts = f.readlines() - + generator = torch.Generator("cpu").manual_seed(args.seed) os.makedirs(os.path.dirname(args.output_path), exist_ok=True) torch.cuda.reset_max_memory_allocated(device) @@ -84,16 +103,22 @@ def main(args): start_time = time.perf_counter() output = pipe( prompt=prompt, - height = args.height, - width = args.width, - num_frames = args.num_frames, + height=args.height, + width=args.width, + num_frames=args.num_frames, prompt_template=prompt_template, - num_inference_steps = args.num_inference_steps, + num_inference_steps=args.num_inference_steps, generator=generator, ).frames[0] - export_to_video(output, os.path.join(args.output_path, f"{prompt[:100]}.mp4"), fps=args.fps) + 
export_to_video(output, + os.path.join(args.output_path, f"{prompt[:100]}.mp4"), + fps=args.fps) print("Time:", round(time.perf_counter() - start_time, 2), "seconds") - print("Max vram for denoise:", round(torch.cuda.max_memory_allocated(device="cuda") / 1024 ** 3, 3), "GiB") + print( + "Max vram for denoise:", + round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3), + "GiB") + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -116,10 +141,14 @@ def main(args): default="flow", help="Denoise type for noised inputs.", ) - parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") - parser.add_argument( - "--neg_prompt", type=str, default=None, help="Negative prompt for sampling." - ) + parser.add_argument("--seed", + type=int, + default=None, + help="Seed for evaluation.") + parser.add_argument("--neg_prompt", + type=str, + default=None, + help="Negative prompt for sampling.") parser.add_argument( "--guidance_scale", type=float, @@ -132,12 +161,14 @@ def main(args): default=6.0, help="Embedded classifier free guidance scale.", ) - parser.add_argument( - "--flow_shift", type=int, default=7, help="Flow shift parameter." - ) - parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference." - ) + parser.add_argument("--flow_shift", + type=int, + default=7, + help="Flow shift parameter.") + parser.add_argument("--batch_size", + type=int, + default=1, + help="Batch size for inference.") parser.add_argument( "--num_videos", type=int, @@ -148,7 +179,8 @@ def main(args): "--load-key", type=str, default="module", - help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", + help= + "Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", ) parser.add_argument( "--use-cpu-offload", @@ -158,17 +190,20 @@ def main(args): parser.add_argument( "--dit-weight", type=str, - default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", + default= + "data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", ) parser.add_argument( "--reproduce", action="store_true", - help="Enable reproducibility by setting random seeds and deterministic algorithms.", + help= + "Enable reproducibility by setting random seeds and deterministic algorithms.", ) parser.add_argument( "--disable-autocast", action="store_true", - help="Disable autocast for denoising loop and vae decoding in pipeline sampling.", + help= + "Disable autocast for denoising loop and vae decoding in pipeline sampling.", ) # Flow Matching @@ -177,13 +212,15 @@ def main(args): action="store_true", help="If reverse, learning/sampling from t=1 -> t=0.", ) - parser.add_argument( - "--flow-solver", type=str, default="euler", help="Solver for flow matching." - ) + parser.add_argument("--flow-solver", + type=str, + default="euler", + help="Solver for flow matching.") parser.add_argument( "--use-linear-quadratic-schedule", action="store_true", - help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", + help= + "Use linear quadratic schedule for flow matching. 
Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", ) parser.add_argument( "--linear-schedule-end", @@ -195,17 +232,20 @@ def main(args): # Model parameters parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill") parser.add_argument("--latent-channels", type=int, default=16) - parser.add_argument( - "--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"] - ) - parser.add_argument( - "--rope-theta", type=int, default=256, help="Theta used in RoPE." - ) + parser.add_argument("--precision", + type=str, + default="bf16", + choices=["fp32", "fp16", "bf16", "fp8"]) + parser.add_argument("--rope-theta", + type=int, + default=256, + help="Theta used in RoPE.") parser.add_argument("--vae", type=str, default="884-16c-hy") - parser.add_argument( - "--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"] - ) + parser.add_argument("--vae-precision", + type=str, + default="fp16", + choices=["fp32", "fp16", "bf16"]) parser.add_argument("--vae-tiling", action="store_true", default=True) parser.add_argument("--text-encoder", type=str, default="llm") @@ -218,10 +258,12 @@ def main(args): parser.add_argument("--text-states-dim", type=int, default=4096) parser.add_argument("--text-len", type=int, default=256) parser.add_argument("--tokenizer", type=str, default="llm") - parser.add_argument("--prompt-template", type=str, default="dit-llm-encode") - parser.add_argument( - "--prompt-template-video", type=str, default="dit-llm-encode-video" - ) + parser.add_argument("--prompt-template", + type=str, + default="dit-llm-encode") + parser.add_argument("--prompt-template-video", + type=str, + default="dit-llm-encode-video") parser.add_argument("--hidden-state-skip-layer", type=int, default=2) parser.add_argument("--apply-final-norm", action="store_true") @@ -237,4 +279,4 @@ def main(args): parser.add_argument("--text-len-2", type=int, default=77) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/fastvideo/sample/sample_t2v_hunyuan.py b/fastvideo/sample/sample_t2v_hunyuan.py index 7383b46..b5506f9 100644 --- a/fastvideo/sample/sample_t2v_hunyuan.py +++ b/fastvideo/sample/sample_t2v_hunyuan.py @@ -1,26 +1,17 @@ +import argparse import os -import imageio -import time -from einops import rearrange +from pathlib import Path +import imageio +import numpy as np import torch +import torch.distributed as dist import torchvision -import numpy as np -from pathlib import Path -from loguru import logger -from datetime import datetime -import argparse -from diffusers.utils import export_to_video +from einops import rearrange -from fastvideo.models.hunyuan.utils.file_utils import save_videos_grid from fastvideo.models.hunyuan.inference import HunyuanVideoSampler - -import torch.distributed as dist - from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - nccl_info, -) + initialize_sequence_parallel_state, nccl_info) def initialize_distributed(): @@ -28,16 +19,16 @@ def initialize_distributed(): world_size = int(os.getenv("WORLD_SIZE", 1)) print("world_size", world_size) torch.cuda.set_device(local_rank) - dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) initialize_sequence_parallel_state(world_size) def main(args): initialize_distributed() print(nccl_info.sp_size) - device = 
torch.cuda.current_device() print(args) models_root_path = Path(args.model_path) @@ -50,15 +41,11 @@ def main(args): # Load models hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained( - models_root_path, args=args - ) + models_root_path, args=args) # Get the updated args args = hunyuan_video_sampler.args - # Start sampling - samples = [] - with open(args.prompt) as f: prompts = f.readlines() @@ -84,9 +71,9 @@ def main(args): x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) outputs.append((x * 255).numpy().astype(np.uint8)) os.makedirs(os.path.dirname(args.output_path), exist_ok=True) - imageio.mimsave( - os.path.join(args.output_path, f"{prompt[:100]}.mp4"), outputs, fps=args.fps - ) + imageio.mimsave(os.path.join(args.output_path, f"{prompt[:100]}.mp4"), + outputs, + fps=args.fps) if __name__ == "__main__": @@ -109,10 +96,14 @@ def main(args): default="flow", help="Denoise type for noised inputs.", ) - parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") - parser.add_argument( - "--neg_prompt", type=str, default=None, help="Negative prompt for sampling." - ) + parser.add_argument("--seed", + type=int, + default=None, + help="Seed for evaluation.") + parser.add_argument("--neg_prompt", + type=str, + default=None, + help="Negative prompt for sampling.") parser.add_argument( "--guidance_scale", type=float, @@ -125,12 +116,14 @@ def main(args): default=6.0, help="Embedded classifier free guidance scale.", ) - parser.add_argument( - "--flow_shift", type=int, default=7, help="Flow shift parameter." - ) - parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for inference." - ) + parser.add_argument("--flow_shift", + type=int, + default=7, + help="Flow shift parameter.") + parser.add_argument("--batch_size", + type=int, + default=1, + help="Batch size for inference.") parser.add_argument( "--num_videos", type=int, @@ -141,7 +134,8 @@ def main(args): "--load-key", type=str, default="module", - help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", + help= + "Key to load the model states. 'module' for the main model, 'ema' for the EMA model.", ) parser.add_argument( "--use-cpu-offload", @@ -151,17 +145,20 @@ def main(args): parser.add_argument( "--dit-weight", type=str, - default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", + default= + "data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", ) parser.add_argument( "--reproduce", action="store_true", - help="Enable reproducibility by setting random seeds and deterministic algorithms.", + help= + "Enable reproducibility by setting random seeds and deterministic algorithms.", ) parser.add_argument( "--disable-autocast", action="store_true", - help="Disable autocast for denoising loop and vae decoding in pipeline sampling.", + help= + "Disable autocast for denoising loop and vae decoding in pipeline sampling.", ) # Flow Matching @@ -170,13 +167,15 @@ def main(args): action="store_true", help="If reverse, learning/sampling from t=1 -> t=0.", ) - parser.add_argument( - "--flow-solver", type=str, default="euler", help="Solver for flow matching." - ) + parser.add_argument("--flow-solver", + type=str, + default="euler", + help="Solver for flow matching.") parser.add_argument( "--use-linear-quadratic-schedule", action="store_true", - help="Use linear quadratic schedule for flow matching. 
Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", + help= + "Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)", ) parser.add_argument( "--linear-schedule-end", @@ -188,17 +187,20 @@ def main(args): # Model parameters parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill") parser.add_argument("--latent-channels", type=int, default=16) - parser.add_argument( - "--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16"] - ) - parser.add_argument( - "--rope-theta", type=int, default=256, help="Theta used in RoPE." - ) + parser.add_argument("--precision", + type=str, + default="bf16", + choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--rope-theta", + type=int, + default=256, + help="Theta used in RoPE.") parser.add_argument("--vae", type=str, default="884-16c-hy") - parser.add_argument( - "--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"] - ) + parser.add_argument("--vae-precision", + type=str, + default="fp16", + choices=["fp32", "fp16", "bf16"]) parser.add_argument("--vae-tiling", action="store_true", default=True) parser.add_argument("--text-encoder", type=str, default="llm") @@ -211,10 +213,12 @@ def main(args): parser.add_argument("--text-states-dim", type=int, default=4096) parser.add_argument("--text-len", type=int, default=256) parser.add_argument("--tokenizer", type=str, default="llm") - parser.add_argument("--prompt-template", type=str, default="dit-llm-encode") - parser.add_argument( - "--prompt-template-video", type=str, default="dit-llm-encode-video" - ) + parser.add_argument("--prompt-template", + type=str, + default="dit-llm-encode") + parser.add_argument("--prompt-template-video", + type=str, + default="dit-llm-encode-video") parser.add_argument("--hidden-state-skip-layer", type=int, default=2) parser.add_argument("--apply-final-norm", action="store_true") diff --git a/fastvideo/sample/sample_t2v_mochi.py b/fastvideo/sample/sample_t2v_mochi.py index b42b31f..085c10c 100644 --- a/fastvideo/sample/sample_t2v_mochi.py +++ b/fastvideo/sample/sample_t2v_mochi.py @@ -1,27 +1,17 @@ -import torch -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -import torch.distributed as dist - -from diffusers.utils import export_to_video -from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - nccl_info, -) import argparse -import os -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel import json -from typing import Optional -from safetensors.torch import save_file, load_file -from peft import set_peft_model_state_dict, inject_adapter_in_model, load_peft_weights -from peft import LoraConfig -import sys -import pdb -import copy -from typing import Dict +import os + +import torch +import torch.distributed as dist from diffusers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import convert_unet_state_dict_to_peft +from diffusers.utils import export_to_video + from fastvideo.distill.solver import PCMFMScheduler +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from fastvideo.utils.parallel_states import ( + initialize_sequence_parallel_state, nccl_info) def initialize_distributed(): @@ -29,9 +19,10 @@ def initialize_distributed(): world_size = int(os.getenv("WORLD_SIZE", 1)) print("world_size", world_size) torch.cuda.set_device(local_rank) 
- dist.init_process_group( - backend="nccl", init_method="env://", world_size=world_size, rank=local_rank - ) + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=local_rank) initialize_sequence_parallel_state(world_size) @@ -40,7 +31,7 @@ def main(args): print(nccl_info.sp_size) device = torch.cuda.current_device() # Peiyuan: GPU seed will cause A100 and H100 to produce different results ..... - weight_dtype = torch.bfloat16 + if args.scheduler_type == "euler": scheduler = FlowMatchEulerDiscreteScheduler() else: @@ -54,29 +45,33 @@ def main(args): args.linear_range, ) if args.transformer_path is not None: - transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) + transformer = MochiTransformer3DModel.from_pretrained( + args.transformer_path) else: transformer = MochiTransformer3DModel.from_pretrained( - args.model_path, subfolder="transformer/" - ) + args.model_path, subfolder="transformer/") - pipe = MochiPipeline.from_pretrained( - args.model_path, transformer=transformer, scheduler=scheduler - ) + pipe = MochiPipeline.from_pretrained(args.model_path, + transformer=transformer, + scheduler=scheduler) pipe.enable_vae_tiling() if args.lora_checkpoint_dir is not None: print(f"Loading LoRA weights from {args.lora_checkpoint_dir}") - config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json") + config_path = os.path.join(args.lora_checkpoint_dir, + "lora_config.json") with open(config_path, "r") as f: lora_config_dict = json.load(f) rank = lora_config_dict["lora_params"]["lora_rank"] lora_alpha = lora_config_dict["lora_params"]["lora_alpha"] lora_scaling = lora_alpha / rank - pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default") + pipe.load_lora_weights(args.lora_checkpoint_dir, + adapter_name="default") pipe.set_adapters(["default"], [lora_scaling]) - print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}") + print( + f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}" + ) # pipe.to(device) pipe.enable_model_cpu_offload(device) @@ -84,18 +79,13 @@ def main(args): # Generate videos from the input prompt if args.prompt_embed_path is not None: - prompt_embeds = ( - torch.load(args.prompt_embed_path, map_location="cpu", weights_only=True) - .to(device) - .unsqueeze(0) - ) - encoder_attention_mask = ( - torch.load( - args.encoder_attention_mask_path, map_location="cpu", weights_only=True - ) - .to(device) - .unsqueeze(0) - ) + prompt_embeds = (torch.load(args.prompt_embed_path, + map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) + encoder_attention_mask = (torch.load( + args.encoder_attention_mask_path, + map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) prompts = None elif args.prompt_path is not None: prompts = [line.strip() for line in open(args.prompt_path, "r")] @@ -161,7 +151,9 @@ def main(args): parser.add_argument("--prompt_embed_path", type=str, default=None) parser.add_argument("--prompt_path", type=str, default=None) parser.add_argument("--scheduler_type", type=str, default="euler") - parser.add_argument("--encoder_attention_mask_path", type=str, default=None) + parser.add_argument("--encoder_attention_mask_path", + type=str, + default=None) parser.add_argument( "--lora_checkpoint_dir", type=str, diff --git a/fastvideo/sample/sample_t2v_mochi_no_sp.py b/fastvideo/sample/sample_t2v_mochi_no_sp.py index 3fc4e5b..2a1e702 100644 --- a/fastvideo/sample/sample_t2v_mochi_no_sp.py +++ b/fastvideo/sample/sample_t2v_mochi_no_sp.py 
@@ -1,9 +1,11 @@ -import torch -from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel -from diffusers.utils import export_to_video, load_image, load_video import argparse + +import torch from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import export_to_video + +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel +from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline def main(args): @@ -12,14 +14,14 @@ def main(args): # do not invert scheduler = FlowMatchEulerDiscreteScheduler() if args.transformer_path is not None: - transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path) + transformer = MochiTransformer3DModel.from_pretrained( + args.transformer_path) else: transformer = MochiTransformer3DModel.from_pretrained( - args.model_path, subfolder="transformer/" - ) - pipe = MochiPipeline.from_pretrained( - args.model_path, transformer=transformer, scheduler=scheduler - ) + args.model_path, subfolder="transformer/") + pipe = MochiPipeline.from_pretrained(args.model_path, + transformer=transformer, + scheduler=scheduler) pipe.enable_vae_tiling() # pipe.to("cuda:1") pipe.enable_model_cpu_offload() diff --git a/fastvideo/train.py b/fastvideo/train.py index 401cb99..ecb7cb9 100644 --- a/fastvideo/train.py +++ b/fastvideo/train.py @@ -1,57 +1,45 @@ +# !/bin/python3 +# isort: skip_file import argparse -from email.policy import strict -import logging import math import os -import shutil -from pathlib import Path -from fastvideo.utils.parallel_states import ( - initialize_sequence_parallel_state, - destroy_sequence_parallel_group, - get_sequence_parallel_state, - nccl_info, -) -from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast -from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input -from fastvideo.utils.validation import log_validation import time -from torch.utils.data import DataLoader +from collections import deque + import torch -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, -) -import json -from torch.utils.data.distributed import DistributedSampler -from fastvideo.utils.dataset_utils import LengthGroupedSampler +import torch.distributed as dist import wandb from accelerate.utils import set_seed -from tqdm.auto import tqdm -from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing -from diffusers.utils import convert_unet_state_dict_to_peft from diffusers import FlowMatchEulerDiscreteScheduler -from fastvideo.utils.load import load_transformer from diffusers.optimization import get_scheduler -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel -from diffusers.utils import check_min_version -from fastvideo.dataset.latent_datasets import LatentDataset, latent_collate_function -import torch.distributed as dist -from safetensors.torch import save_file, load_file -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft +from peft import LoraConfig, set_peft_model_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from fastvideo.utils.checkpoint import ( - save_checkpoint, - save_lora_checkpoint, - resume_lora_optimizer, -) -from fastvideo.utils.logging_ import main_print +from torch.utils.data import DataLoader +from 
torch.utils.data.distributed import DistributedSampler +from tqdm.auto import tqdm + +from fastvideo.dataset.latent_datasets import (LatentDataset, + latent_collate_function) +from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from fastvideo.utils.checkpoint import (resume_lora_optimizer, save_checkpoint, + save_lora_checkpoint) +from fastvideo.utils.communications import (broadcast, + sp_parallel_dataloader_wrapper) +from fastvideo.utils.dataset_utils import LengthGroupedSampler +from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, + get_dit_fsdp_kwargs) +from fastvideo.utils.load import load_transformer +from fastvideo.utils.logging_ import main_print +from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, + get_sequence_parallel_state, + initialize_sequence_parallel_state + ) +from fastvideo.utils.validation import log_validation # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.31.0") -import time -from collections import deque def compute_density_for_timestep_sampling( @@ -74,24 +62,29 @@ def compute_density_for_timestep_sampling( u = torch.normal( mean=logit_mean, std=logit_std, - size=(batch_size,), + size=(batch_size, ), device="cpu", generator=generator, ) u = torch.nn.functional.sigmoid(u) elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu", generator=generator) - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + u = torch.rand(size=(batch_size, ), device="cpu", generator=generator) + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2)**2 - 1 + u) else: - u = torch.rand(size=(batch_size,), device="cpu", generator=generator) + u = torch.rand(size=(batch_size, ), device="cpu", generator=generator) return u -def get_sigmas(noise_scheduler, device, timesteps, n_dim=4, dtype=torch.float32): +def get_sigmas(noise_scheduler, + device, + timesteps, + n_dim=4, + dtype=torch.float32): sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(device) timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(schedule_timesteps == t).nonzero().item() + for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: @@ -137,7 +130,8 @@ def train_one_step( mode_scale=mode_scale, ) indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + timesteps = noise_scheduler.timesteps[indices].to( + device=latents.device) if sp_size > 1: # Make sure that the timesteps are the same across all sp processes. broadcast(timesteps) @@ -170,10 +164,8 @@ def train_one_step( else: target = noise - latents - loss = ( - torch.mean((model_pred.float() - target.float()) ** 2) - / gradient_accumulation_steps - ) + loss = (torch.mean((model_pred.float() - target.float())**2) / + gradient_accumulation_steps) loss.backward() @@ -209,7 +201,7 @@ def main(args): if rank <= 0 and args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # For mixed precision training we cast all non-trainable weigths to half-precision + # For mixed precision training we cast all non-trainable weights to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. 
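As background for the compute_density_for_timestep_sampling and get_sigmas hunks above: the trainer draws a value u in (0, 1) per sample (logit-normal by default), turns it into an index into the scheduler's timestep table, and looks up the matching sigma used to mix noise into the latents. A self-contained sketch of that path with a FlowMatchEulerDiscreteScheduler, as in the diff; the helper name sample_training_timesteps is made up for illustration:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler


def sample_training_timesteps(scheduler, batch_size, logit_mean=0.0,
                              logit_std=1.0, generator=None):
    # Logit-normal density: a normal draw squashed through a sigmoid lands in
    # (0, 1) and concentrates samples around the middle of the schedule.
    u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size, ),
                     device="cpu", generator=generator)
    u = torch.sigmoid(u)

    # Map the continuous draw onto the scheduler's discrete timestep table.
    indices = (u * scheduler.config.num_train_timesteps).long()
    timesteps = scheduler.timesteps[indices]

    # Same lookup get_sigmas performs: find each timestep's position and take
    # the corresponding sigma for the flow-matching interpolation.
    step_indices = [(scheduler.timesteps == t).nonzero().item()
                    for t in timesteps]
    sigmas = scheduler.sigmas[step_indices].flatten()
    return timesteps, sigmas


scheduler = FlowMatchEulerDiscreteScheduler()
t, s = sample_training_timesteps(scheduler, batch_size=4)
print(t, s)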
# Create model: @@ -236,25 +228,24 @@ def main(args): if args.resume_from_lora_checkpoint: lora_state_dict = MochiPipeline.lora_state_dict( - args.resume_from_lora_checkpoint - ) + args.resume_from_lora_checkpoint) transformer_state_dict = { f'{k.replace("transformer.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("transformer.") + for k, v in lora_state_dict.items() if k.startswith("transformer.") } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict( - transformer, transformer_state_dict, adapter_name="default" - ) + transformer_state_dict = convert_unet_state_dict_to_peft( + transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer, + transformer_state_dict, + adapter_name="default") if incompatible_keys is not None: # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", + None) if unexpected_keys: main_print( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) + f" {unexpected_keys}. ") main_print( f" Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M" @@ -273,19 +264,24 @@ def main(args): if args.use_lora: transformer.config.lora_rank = args.lora_rank transformer.config.lora_alpha = args.lora_alpha - transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + transformer.config.lora_target_modules = [ + "to_k", "to_q", "to_v", "to_out.0" + ] transformer._no_split_modules = [ no_split_module.__name__ for no_split_module in no_split_modules ] - fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer) + fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"]( + transformer) - transformer = FSDP(transformer, **fsdp_kwargs,) - main_print(f"--> model loaded") + transformer = FSDP( + transformer, + **fsdp_kwargs, + ) + main_print("--> model loaded") if args.gradient_checkpointing: - apply_fsdp_checkpointing( - transformer, no_split_modules, args.selective_checkpointing - ) + apply_fsdp_checkpointing(transformer, no_split_modules, + args.selective_checkpointing) # Set model as trainable. 
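Before the reformatted code above flips the model into train mode, it builds an auto-wrap policy from the transformer's no-split modules, wraps the model in FSDP, and applies selective activation checkpointing. A condensed sketch of how such kwargs can be assembled; ToyBlock stands in for the real transformer blocks, and the exact dtypes, sharding strategy, and offload settings used by FastVideo come from get_dit_fsdp_kwargs further down in this patch:

import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    # Stand-in for MochiTransformerBlock / MMDoubleStreamBlock in this sketch.

    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ff(x)


def build_fsdp_kwargs(no_split_modules, device_id, cpu_offload=False):
    # Each listed block class becomes its own FSDP unit; parameters outside
    # the listed classes fall under the root wrapper.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(no_split_modules),
    )
    # Keep master weights in fp32; training autocasts activations elsewhere.
    mixed_precision = MixedPrecision(param_dtype=torch.float32,
                                     reduce_dtype=torch.float32,
                                     buffer_dtype=torch.float32)
    return {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": ShardingStrategy.FULL_SHARD,
        "device_id": device_id,
        "cpu_offload": CPUOffload(offload_params=True) if cpu_offload else None,
        "limit_all_gathers": True,
    }

# FSDP(model, **build_fsdp_kwargs((ToyBlock, ), device_id=0)) only works after
# torch.distributed.init_process_group has been called, so the wrapping call
# itself is left out of this sketch.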
transformer.train() @@ -293,7 +289,8 @@ def main(args): noise_scheduler = FlowMatchEulerDiscreteScheduler() params_to_optimize = transformer.parameters() - params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize)) + params_to_optimize = list( + filter(lambda p: p.requires_grad, params_to_optimize)) optimizer = torch.optim.AdamW( params_to_optimize, @@ -306,8 +303,7 @@ def main(args): init_steps = 0 if args.resume_from_lora_checkpoint: transformer, optimizer, init_steps = resume_lora_optimizer( - transformer, args.resume_from_lora_checkpoint, optimizer - ) + transformer, args.resume_from_lora_checkpoint, optimizer) main_print(f"optimizer: {optimizer}") lr_scheduler = get_scheduler( @@ -320,21 +316,17 @@ def main(args): last_epoch=init_steps - 1, ) - train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg) - sampler = ( - LengthGroupedSampler( - args.train_batch_size, - rank=rank, - world_size=world_size, - lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else DistributedSampler( - train_dataset, rank=rank, num_replicas=world_size, shuffle=False - ) - ) + train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, + args.cfg) + sampler = (LengthGroupedSampler( + args.train_batch_size, + rank=rank, + world_size=world_size, + lengths=train_dataset.lengths, + group_frame=args.group_frame, + group_resolution=args.group_resolution, + ) if (args.group_frame or args.group_resolution) else DistributedSampler( + train_dataset, rank=rank, num_replicas=world_size, shuffle=False)) train_dataloader = DataLoader( train_dataset, @@ -347,45 +339,43 @@ def main(args): ) num_update_steps_per_epoch = math.ceil( - len(train_dataloader) - / args.gradient_accumulation_steps - * args.sp_size - / args.train_sp_batch_size - ) - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + len(train_dataloader) / args.gradient_accumulation_steps * + args.sp_size / args.train_sp_batch_size) + args.num_train_epochs = math.ceil(args.max_train_steps / + num_update_steps_per_epoch) if rank <= 0: project = args.tracker_project_name or "fastvideo" wandb.init(project=project, config=args) # Train! - total_batch_size = ( - args.train_batch_size - * world_size - * args.gradient_accumulation_steps - / args.sp_size - * args.train_sp_batch_size - ) + total_batch_size = (args.train_batch_size * world_size * + args.gradient_accumulation_steps / args.sp_size * + args.train_sp_batch_size) main_print("***** Running training *****") main_print(f" Num examples = {len(train_dataset)}") main_print(f" Dataloader size = {len(train_dataloader)}") main_print(f" Num Epochs = {args.num_train_epochs}") main_print(f" Resume training from step {init_steps}") - main_print(f" Instantaneous batch size per device = {args.train_batch_size}") + main_print( + f" Instantaneous batch size per device = {args.train_batch_size}") main_print( f" Total train batch size (w. 
data & sequence parallel, accumulation) = {total_batch_size}" ) - main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + main_print( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") main_print(f" Total optimization steps = {args.max_train_steps}") main_print( f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B" ) # print dtype - main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}") + main_print( + f" Master weight dtype: {transformer.parameters().__next__().dtype}") # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: - assert NotImplementedError("resume_from_checkpoint is not supported now.") + assert NotImplementedError( + "resume_from_checkpoint is not supported now.") # TODO progress_bar = tqdm( @@ -433,13 +423,11 @@ def main(args): step_times.append(step_time) avg_step_time = sum(step_times) / len(step_times) - progress_bar.set_postfix( - { - "loss": f"{loss:.4f}", - "step_time": f"{step_time:.2f}s", - "grad_norm": grad_norm, - } - ) + progress_bar.set_postfix({ + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + }) progress_bar.update(1) if rank <= 0: wandb.log( @@ -455,24 +443,22 @@ def main(args): if step % args.checkpointing_steps == 0: if args.use_lora: # Save LoRA weights - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, step - ) + save_lora_checkpoint(transformer, optimizer, rank, + args.output_dir, step) else: # Your existing checkpoint saving code - save_checkpoint(transformer, optimizer, rank, args.output_dir, step) + save_checkpoint(transformer, optimizer, rank, args.output_dir, + step) dist.barrier() if args.log_validation and step % args.validation_steps == 0: log_validation(args, transformer, device, torch.bfloat16, step) if args.use_lora: - save_lora_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps - ) + save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, + args.max_train_steps) else: - save_checkpoint( - transformer, optimizer, rank, args.output_dir, args.max_train_steps - ) + save_checkpoint(transformer, optimizer, rank, args.output_dir, + args.max_train_steps) if get_sequence_parallel_state(): destroy_sequence_parallel_group() @@ -480,9 +466,10 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--model_type", type=str, default="mochi", help="The type of model to train." - ) + parser.add_argument("--model_type", + type=str, + default="mochi", + help="The type of model to train.") # dataset & dataloader parser.add_argument("--data_json_path", type=str, required=True) parser.add_argument("--num_frames", type=int, default=163) @@ -490,7 +477,8 @@ def main(args): "--dataloader_num_workers", type=int, default=10, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + help= + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", ) parser.add_argument( "--train_batch_size", @@ -498,9 +486,10 @@ def main(args): default=16, help="Batch size (per device) for the training dataloader.", ) - parser.add_argument( - "--num_latent_t", type=int, default=28, help="Number of latent timesteps." 
- ) + parser.add_argument("--num_latent_t", + type=int, + default=28, + help="Number of latent timesteps.") parser.add_argument("--group_frame", action="store_true") # TODO parser.add_argument("--group_resolution", action="store_true") # TODO @@ -537,14 +526,16 @@ def main(args): parser.add_argument("--validation_steps", type=int, default=50) parser.add_argument("--log_validation", action="store_true") parser.add_argument("--tracker_project_name", type=str, default=None) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + parser.add_argument("--seed", + type=int, + default=None, + help="A seed for reproducible training.") parser.add_argument( "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--checkpoints_total_limit", @@ -556,38 +547,36 @@ def main(args): "--checkpointing_steps", type=int, default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), + help= + ("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), + help= + ("Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), ) parser.add_argument( "--resume_from_lora_checkpoint", type=str, default=None, - help=( - "Whether training should be resumed from a previous lora checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), + help= + ("Whether training should be resumed from a previous lora checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), + help= + ("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), ) # optimizer & scheduler & Training @@ -596,25 +585,29 @@ def main(args): "--max_train_steps", type=int, default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + help= + "Total number of training steps to perform. 
If provided, overrides num_train_epochs.", ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", + help= + "Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", + help= + "Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + help= + "Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_warmup_steps", @@ -622,41 +615,47 @@ def main(args): default=10, help="Number of steps for the warmup in the lr scheduler.", ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) + parser.add_argument("--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm.") parser.add_argument( "--gradient_checkpointing", action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + help= + "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument("--selective_checkpointing", type=float, default=1.0) parser.add_argument( "--allow_tf32", action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), + help= + ("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), + help= + ("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument( "--use_cpu_offload", action="store_true", - help="Whether to use CPU offload for param & gradient & optimizer states.", + help= + "Whether to use CPU offload for param & gradient & optimizer states.", ) - parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") + parser.add_argument("--sp_size", + type=int, + default=1, + help="For sequence parallel") parser.add_argument( "--train_sp_batch_size", type=int, @@ -670,12 +669,14 @@ def main(args): default=False, help="Whether to use LoRA for finetuning.", ) - parser.add_argument( - "--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA." - ) - parser.add_argument( - "--lora_rank", type=int, default=128, help="LoRA rank parameter. 
" - ) + parser.add_argument("--lora_alpha", + type=int, + default=256, + help="Alpha parameter for LoRA.") + parser.add_argument("--lora_rank", + type=int, + default=128, + help="LoRA rank parameter. ") parser.add_argument("--fsdp_sharding_startegy", default="full") parser.add_argument( @@ -700,17 +701,17 @@ def main(args): "--mode_scale", type=float, default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + help= + "Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) # lr_scheduler parser.add_argument( "--lr_scheduler", type=str, default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), + help= + ('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), ) parser.add_argument( "--lr_num_cycles", @@ -724,9 +725,10 @@ def main(args): default=1.0, help="Power factor of the polynomial scheduler.", ) - parser.add_argument( - "--weight_decay", type=float, default=0.01, help="Weight decay to apply." - ) + parser.add_argument("--weight_decay", + type=float, + default=0.01, + help="Weight decay to apply.") parser.add_argument( "--master_weight_type", type=str, diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py index 3fd8bfc..b84fd34 100644 --- a/fastvideo/utils/checkpoint.py +++ b/fastvideo/utils/checkpoint.py @@ -1,41 +1,49 @@ # import -import os import json +import os + import torch -from fastvideo.utils.logging_ import main_print -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, -) -from safetensors.torch import save_file, load_file import torch.distributed.checkpoint as dist_cp -from torch.distributed.checkpoint.default_planner import ( - DefaultSavePlanner, - DefaultLoadPlanner, -) -from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict -from torch.distributed.fsdp import FullOptimStateDictConfig -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from peft import get_peft_model_state_dict +from safetensors.torch import load_file, save_file +from torch.distributed.checkpoint.default_planner import (DefaultLoadPlanner, + DefaultSavePlanner) +from torch.distributed.checkpoint.optimizer import \ + load_sharded_optimizer_state_dict +from torch.distributed.fsdp import (FullOptimStateDictConfig, + FullStateDictConfig) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType + from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline +from fastvideo.utils.logging_ import main_print -def save_checkpoint(model, optimizer, rank, output_dir, step, discriminator=False): +def save_checkpoint(model, + optimizer, + rank, + output_dir, + step, + discriminator=False): with FSDP.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), ): cpu_state = model.state_dict() - optim_state = FSDP.optim_state_dict(model, optimizer,) + optim_state = FSDP.optim_state_dict( + model, + 
optimizer, + ) # todo move to get_state_dict save_dir = os.path.join(output_dir, f"checkpoint-{step}") os.makedirs(save_dir, exist_ok=True) # save using safetensors if rank <= 0 and not discriminator: - weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors") + weight_path = os.path.join(save_dir, + "diffusion_pytorch_model.safetensors") save_file(cpu_state, weight_path) config_dict = dict(model.config) config_path = os.path.join(save_dir, "config.json") @@ -45,19 +53,26 @@ def save_checkpoint(model, optimizer, rank, output_dir, step, discriminator=Fals optimizer_path = os.path.join(save_dir, "optimizer.pt") torch.save(optim_state, optimizer_path) else: - weight_path = os.path.join(save_dir, "discriminator_pytorch_model.safetensors") + weight_path = os.path.join(save_dir, + "discriminator_pytorch_model.safetensors") save_file(cpu_state, weight_path) optimizer_path = os.path.join(save_dir, "discriminator_optimizer.pt") torch.save(optim_state, optimizer_path) def save_checkpoint_generator_discriminator( - model, optimizer, discriminator, discriminator_optimizer, rank, output_dir, step, + model, + optimizer, + discriminator, + discriminator_optimizer, + rank, + output_dir, + step, ): with FSDP.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): cpu_state = model.state_dict() @@ -73,7 +88,8 @@ def save_checkpoint_generator_discriminator( # save dict as json with open(config_path, "w") as f: json.dump(config_dict, f, indent=4) - weight_path = os.path.join(hf_weight_dir, "diffusion_pytorch_model.safetensors") + weight_path = os.path.join(hf_weight_dir, + "diffusion_pytorch_model.safetensors") save_file(cpu_state, weight_path) main_print(f"--> saved HF weight checkpoint at path {hf_weight_dir}") @@ -97,21 +113,22 @@ def save_checkpoint_generator_discriminator( planner=DefaultSavePlanner(), ) - discriminator_fsdp_state_dir = os.path.join(save_dir, "discriminator_fsdp_state") + discriminator_fsdp_state_dir = os.path.join(save_dir, + "discriminator_fsdp_state") os.makedirs(discriminator_fsdp_state_dir, exist_ok=True) with FSDP.state_dict_type( - discriminator, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + discriminator, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), ): - optim_state = FSDP.optim_state_dict(discriminator, discriminator_optimizer) + optim_state = FSDP.optim_state_dict(discriminator, + discriminator_optimizer) model_state = discriminator.state_dict() state_dict = {"optimizer": optim_state, "model": model_state} if rank <= 0: discriminator_fsdp_state_fil = os.path.join( - discriminator_fsdp_state_dir, "discriminator_state.pt" - ) + discriminator_fsdp_state_dir, "discriminator_state.pt") torch.save(state_dict, discriminator_fsdp_state_fil) main_print("--> saved FSDP state checkpoint") @@ -128,8 +145,7 @@ def load_sharded_model(model, optimizer, model_dir, optimizer_dir): ) optim_state = optim_state["optimizer"] flattened_osd = FSDP.optim_state_dict_to_load( - model=model, optim=optimizer, optim_state_dict=optim_state - ) + model=model, optim=optimizer, optim_state_dict=optim_state) optimizer.load_state_dict(flattened_osd) dist_cp.load_state_dict( 
state_dict=weight_state_dict, @@ -144,10 +160,10 @@ def load_sharded_model(model, optimizer, model_dir, optimizer_dir): def load_full_state_model(model, optimizer, checkpoint_file, rank): with FSDP.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), ): discriminator_state = torch.load(checkpoint_file) model_state = discriminator_state["model"] @@ -157,8 +173,7 @@ def load_full_state_model(model, optimizer, checkpoint_file, rank): optim_state = None model.load_state_dict(model_state) discriminator_optim_state = FSDP.optim_state_dict_to_load( - model=model, optim=optimizer, optim_state_dict=optim_state - ) + model=model, optim=optimizer, optim_state_dict=optim_state) optimizer.load_state_dict(discriminator_optim_state) main_print( f"--> loaded discriminator and discriminator optimizer from path {checkpoint_file}" @@ -166,37 +181,35 @@ def load_full_state_model(model, optimizer, checkpoint_file, rank): return model, optimizer -def resume_training_generator_discriminator( - model, optimizer, discriminator, discriminator_optimizer, checkpoint_dir, rank -): +def resume_training_generator_discriminator(model, optimizer, discriminator, + discriminator_optimizer, + checkpoint_dir, rank): step = int(checkpoint_dir.split("-")[-1]) model_weight_dir = os.path.join(checkpoint_dir, "model_weights_state") model_optimizer_dir = os.path.join(checkpoint_dir, "model_optimizer_state") - model, optimizer = load_sharded_model( - model, optimizer, model_weight_dir, model_optimizer_dir - ) - discriminator_ckpt_file = os.path.join( - checkpoint_dir, "discriminator_fsdp_state", "discriminator_state.pt" - ) + model, optimizer = load_sharded_model(model, optimizer, model_weight_dir, + model_optimizer_dir) + discriminator_ckpt_file = os.path.join(checkpoint_dir, + "discriminator_fsdp_state", + "discriminator_state.pt") discriminator, discriminator_optimizer = load_full_state_model( - discriminator, discriminator_optimizer, discriminator_ckpt_file, rank - ) + discriminator, discriminator_optimizer, discriminator_ckpt_file, rank) return model, optimizer, discriminator, discriminator_optimizer, step def resume_training(model, optimizer, checkpoint_dir, discriminator=False): - weight_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors") + weight_path = os.path.join(checkpoint_dir, + "diffusion_pytorch_model.safetensors") if discriminator: - weight_path = os.path.join( - checkpoint_dir, "discriminator_pytorch_model.safetensors" - ) + weight_path = os.path.join(checkpoint_dir, + "discriminator_pytorch_model.safetensors") model_weights = load_file(weight_path) with FSDP.state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), ): current_state = model.state_dict() current_state.update(model_weights) @@ -207,8 +220,7 @@ def resume_training(model, optimizer, checkpoint_dir, discriminator=False): optim_path = os.path.join(checkpoint_dir, "optimizer.pt") optimizer_state_dict = torch.load(optim_path, weights_only=False) 
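The loaded full optimizer state is then remapped onto local shards with FSDP.optim_state_dict_to_load, which the next lines do. The save/resume helpers in this file all follow one pattern; a compact sketch of that round trip is below, under the assumption of an already initialized process group and an FSDP-wrapped model (function names are illustrative):

import os

import torch
from safetensors.torch import load_file, save_file
from torch.distributed.fsdp import (FullOptimStateDictConfig,
                                    FullStateDictConfig, StateDictType)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_full_checkpoint(model, optimizer, rank, save_dir):
    # Gather an unsharded state dict onto rank 0, offloaded to CPU so the GPUs
    # never hold a second full copy of the weights.
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)
    if rank <= 0:
        os.makedirs(save_dir, exist_ok=True)
        save_file(cpu_state,
                  os.path.join(save_dir, "diffusion_pytorch_model.safetensors"))
        torch.save(optim_state, os.path.join(save_dir, "optimizer.pt"))


def resume_full_checkpoint(model, optimizer, save_dir):
    # Mirror of the resume_training hunk above: merge the saved weights into
    # the current full state dict, then reshard the flat optimizer state.
    weights = load_file(
        os.path.join(save_dir, "diffusion_pytorch_model.safetensors"))
    with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        current = model.state_dict()
        current.update(weights)
        model.load_state_dict(current)
        full_optim = torch.load(os.path.join(save_dir, "optimizer.pt"),
                                weights_only=False)
        sharded = FSDP.optim_state_dict_to_load(model=model,
                                                optim=optimizer,
                                                optim_state_dict=full_optim)
    optimizer.load_state_dict(sharded)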
optim_state = FSDP.optim_state_dict_to_load( - model=model, optim=optimizer, optim_state_dict=optimizer_state_dict - ) + model=model, optim=optimizer, optim_state_dict=optimizer_state_dict) optimizer.load_state_dict(optim_state) step = int(checkpoint_dir.split("-")[-1]) return model, optimizer, step @@ -216,12 +228,15 @@ def resume_training(model, optimizer, checkpoint_dir, discriminator=False): def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step): with FSDP.state_dict_type( - transformer, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + transformer, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): full_state_dict = transformer.state_dict() - lora_optim_state = FSDP.optim_state_dict(transformer, optimizer,) + lora_optim_state = FSDP.optim_state_dict( + transformer, + optimizer, + ) if rank <= 0: save_dir = os.path.join(output_dir, f"lora-checkpoint-{step}") @@ -233,8 +248,7 @@ def save_lora_checkpoint(transformer, optimizer, rank, output_dir, step): # save lora weight main_print(f"--> saving LoRA checkpoint at step {step}") transformer_lora_layers = get_peft_model_state_dict( - model=transformer, state_dict=full_state_dict - ) + model=transformer, state_dict=full_state_dict) MochiPipeline.save_lora_weights( save_directory=save_dir, transformer_lora_layers=transformer_lora_layers, @@ -262,8 +276,9 @@ def resume_lora_optimizer(transformer, checkpoint_dir, optimizer): optim_path = os.path.join(checkpoint_dir, "lora_optimizer.pt") optimizer_state_dict = torch.load(optim_path, weights_only=False) optim_state = FSDP.optim_state_dict_to_load( - model=transformer, optim=optimizer, optim_state_dict=optimizer_state_dict - ) + model=transformer, + optim=optimizer, + optim_state_dict=optimizer_state_dict) optimizer.load_state_dict(optim_state) step = config_dict["step"] main_print(f"--> Successfully resuming LoRA optimizer from step {step}") diff --git a/fastvideo/utils/communications.py b/fastvideo/utils/communications.py index f7ff0b7..6d18be6 100644 --- a/fastvideo/utils/communications.py +++ b/fastvideo/utils/communications.py @@ -3,12 +3,13 @@ # DeepSpeed Team +from typing import Any, Tuple + import torch import torch.distributed as dist -from fastvideo.utils.parallel_states import nccl_info -from typing import Any, Tuple from torch import Tensor -from torch.nn import Module + +from fastvideo.utils.parallel_states import nccl_info def broadcast(input_: torch.Tensor): @@ -16,9 +17,10 @@ def broadcast(input_: torch.Tensor): dist.broadcast(input_, src=src, group=nccl_info.group) -def _all_to_all_4D( - input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None -) -> torch.tensor: +def _all_to_all_4D(input: torch.tensor, + scatter_idx: int = 2, + gather_idx: int = 1, + group=None) -> torch.tensor: """ all-to-all for QKV @@ -45,11 +47,8 @@ def _all_to_all_4D( # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
# (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) - input_t = ( - input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs) - .transpose(0, 2) - .contiguous() - ) + input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, + hs).transpose(0, 2).contiguous()) output = torch.empty_like(input_t) # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single @@ -63,7 +62,8 @@ def _all_to_all_4D( output = output.reshape(seqlen, bs, shard_hc, hs) # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) - output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) + output = output.transpose(0, 1).contiguous().reshape( + bs, seqlen, shard_hc, hs) return output @@ -76,13 +76,10 @@ def _all_to_all_4D( # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) - input_t = ( - input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs) - .transpose(0, 3) - .transpose(0, 1) - .contiguous() - .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs) - ) + input_t = (input.reshape( + bs, seq_world_size, shard_seqlen, shard_hc, + hs).transpose(0, 3).transpose(0, 1).contiguous().reshape( + seq_world_size, shard_hc, shard_seqlen, bs, hs)) output = torch.empty_like(input_t) # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single @@ -97,14 +94,17 @@ def _all_to_all_4D( output = output.reshape(hc, shard_seqlen, bs, hs) # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) - output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) + output = output.transpose(0, 2).contiguous().reshape( + bs, shard_seqlen, hc, hs) return output else: - raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") + raise RuntimeError( + "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") class SeqAllToAll4D(torch.autograd.Function): + @staticmethod def forward( ctx: Any, @@ -120,21 +120,24 @@ def forward( return _all_to_all_4D(input, scatter_idx, gather_idx, group=group) @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + def backward(ctx: Any, + *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return ( None, - SeqAllToAll4D.apply( - ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx - ), + SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, + ctx.scatter_idx), None, None, ) def all_to_all_4D( - input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1, + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, ): - return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim) + return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, + gather_dim) def _all_to_all( @@ -145,7 +148,8 @@ def _all_to_all( gather_dim: int, ): input_list = [ - t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim) + t.contiguous() + for t in torch.tensor_split(input_, world_size, scatter_dim) ] output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] dist.all_to_all(output_list, input_list, group=group) @@ -168,9 +172,8 @@ def forward(ctx, input_, process_group, scatter_dim, gather_dim): ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim ctx.world_size = dist.get_world_size(process_group) - 
output = _all_to_all( - input_, ctx.world_size, process_group, scatter_dim, gather_dim - ) + output = _all_to_all(input_, ctx.world_size, process_group, + scatter_dim, gather_dim) return output @staticmethod @@ -191,7 +194,9 @@ def backward(ctx, grad_output): def all_to_all( - input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1, + input_: torch.Tensor, + scatter_dim: int = 2, + gather_dim: int = 1, ): return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim) @@ -248,9 +253,8 @@ def all_gather(input_: torch.Tensor, dim: int = 1): return _AllGather.apply(input_, dim) -def prepare_sequence_parallel_data( - hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask -): +def prepare_sequence_parallel_data(hidden_states, encoder_hidden_states, + attention_mask, encoder_attention_mask): if nccl_info.sp_size == 1: return ( hidden_states, @@ -259,17 +263,18 @@ def prepare_sequence_parallel_data( encoder_attention_mask, ) - def prepare( - hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask - ): + def prepare(hidden_states, encoder_hidden_states, attention_mask, + encoder_attention_mask): hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0) - encoder_hidden_states = all_to_all( - encoder_hidden_states, scatter_dim=1, gather_dim=0 - ) - attention_mask = all_to_all(attention_mask, scatter_dim=1, gather_dim=0) - encoder_attention_mask = all_to_all( - encoder_attention_mask, scatter_dim=1, gather_dim=0 - ) + encoder_hidden_states = all_to_all(encoder_hidden_states, + scatter_dim=1, + gather_dim=0) + attention_mask = all_to_all(attention_mask, + scatter_dim=1, + gather_dim=0) + encoder_attention_mask = all_to_all(encoder_attention_mask, + scatter_dim=1, + gather_dim=0) return ( hidden_states, encoder_hidden_states, @@ -296,9 +301,8 @@ def prepare( return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask -def sp_parallel_dataloader_wrapper( - dataloader, device, train_batch_size, sp_size, train_sp_batch_size -): +def sp_parallel_dataloader_wrapper(dataloader, device, train_batch_size, + sp_size, train_sp_batch_size): while True: for data_item in dataloader: latents, cond, attn_mask, cond_mask = data_item @@ -311,12 +315,12 @@ def sp_parallel_dataloader_wrapper( yield latents, cond, attn_mask, cond_mask else: latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data( - latents, cond, attn_mask, cond_mask - ) + latents, cond, attn_mask, cond_mask) assert ( train_batch_size * sp_size >= train_sp_batch_size ), "train_batch_size * sp_size should be greater than train_sp_batch_size" - for iter in range(train_batch_size * sp_size // train_sp_batch_size): + for iter in range(train_batch_size * sp_size // + train_sp_batch_size): st_idx = iter * train_sp_batch_size ed_idx = (iter + 1) * train_sp_batch_size encoder_hidden_states = cond[st_idx:ed_idx] diff --git a/fastvideo/utils/dataset_utils.py b/fastvideo/utils/dataset_utils.py index ae91c70..40ac9b5 100644 --- a/fastvideo/utils/dataset_utils.py +++ b/fastvideo/utils/dataset_utils.py @@ -1,17 +1,14 @@ import math -from einops import rearrange +import random +from collections import Counter +from typing import List, Optional + import decord -from torch.nn import functional as F import torch -from typing import Optional import torch.utils import torch.utils.data -import torch +from torch.nn import functional as F from torch.utils.data import Sampler -from typing import List -from collections import Counter -import random - IMG_EXTENSIONS = 
[".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG"] @@ -33,17 +30,15 @@ def __call__(self, filename): results (dict): The resulting dict to be modified and passed to the next transform in pipeline. """ - reader = decord.VideoReader( - filename, ctx=self.ctx, num_threads=self.num_threads - ) + reader = decord.VideoReader(filename, + ctx=self.ctx, + num_threads=self.num_threads) return reader def __repr__(self): - repr_str = ( - f"{self.__class__.__name__}(" - f"sr={self.sr}," - f"num_threads={self.num_threads})" - ) + repr_str = (f"{self.__class__.__name__}(" + f"sr={self.sr}," + f"num_threads={self.num_threads})") return repr_str @@ -58,6 +53,7 @@ def pad_to_multiple(number, ds_stride): # TODO class Collate: + def __init__(self, args): self.batch_size = args.train_batch_size self.group_frame = args.group_frame @@ -98,7 +94,8 @@ def __call__(self, batch): self.max_thw, self.ae_stride_thw, ) - assert not torch.any(torch.isnan(pad_batch_tubes)), "after pad_batch_tubes" + assert not torch.any( + torch.isnan(pad_batch_tubes)), "after pad_batch_tubes" return pad_batch_tubes, attention_mask, input_ids, cond_mask def process( @@ -112,18 +109,20 @@ def process( ae_stride_thw, ): # pad to max multiple of ds_stride - batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] + batch_input_size = [i.shape + for i in batch_tubes] # [(c t h w), (c t h w)] assert len(batch_input_size) == self.batch_size if self.group_frame or self.group_resolution or self.batch_size == 1: # len_each_batch = batch_input_size - idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)]) + idx_length_dict = dict( + [*zip(list(range(self.batch_size)), len_each_batch)]) count_dict = Counter(len_each_batch) if len(count_dict) != 1: - sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1]) + sorted_by_value = sorted(count_dict.items(), + key=lambda item: item[1]) pick_length = sorted_by_value[-1][0] # the highest frequency candidate_batch = [ - idx - for idx, length in idx_length_dict.items() + idx for idx, length in idx_length_dict.items() if length == pick_length ] random_select_batch = [ @@ -142,9 +141,8 @@ def process( pick_idx = candidate_batch + random_select_batch batch_tubes = [batch_tubes[i] for i in pick_idx] - batch_input_size = [ - i.shape for i in batch_tubes - ] # [(c t h w), (c t h w)] + batch_input_size = [i.shape for i in batch_tubes + ] # [(c t h w), (c t h w)] input_ids = [input_ids[i] for i in pick_idx] # b [1, l] cond_mask = [cond_mask[i] for i in pick_idx] # b [1, l] @@ -161,10 +159,10 @@ def process( pad_to_multiple(max_w, ds_stride), ) pad_max_t = pad_max_t + 1 - self.ae_stride_t - each_pad_t_h_w = [ - [pad_max_t - i.shape[1], pad_max_h - i.shape[2], pad_max_w - i.shape[3]] - for i in batch_tubes - ] + each_pad_t_h_w = [[ + pad_max_t - i.shape[1], pad_max_h - i.shape[2], + pad_max_w - i.shape[3] + ] for i in batch_tubes] pad_batch_tubes = [ F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0) for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes) @@ -177,14 +175,11 @@ def process( max_tube_size[1] // ae_stride_thw[1], max_tube_size[2] // ae_stride_thw[2], ] - valid_latent_size = [ - [ - int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1, - int(math.ceil(i[2] / ae_stride_thw[1])), - int(math.ceil(i[3] / ae_stride_thw[2])), - ] - for i in batch_input_size - ] + valid_latent_size = [[ + int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1, + int(math.ceil(i[2] / ae_stride_thw[1])), + int(math.ceil(i[3] / ae_stride_thw[2])), + ] for i in batch_input_size] 
attention_mask = [ F.pad( torch.ones(i, dtype=pad_batch_tubes.dtype), @@ -197,8 +192,7 @@ def process( max_latent_size[0] - i[0], ), value=0, - ) - for i in valid_latent_size + ) for i in valid_latent_size ] attention_mask = torch.stack(attention_mask) # b t h w if self.batch_size == 1 or self.group_frame or self.group_resolution: @@ -236,7 +230,8 @@ def split_to_even_chunks(indices, lengths, num_chunks, batch_size): assert batch_size > len(chunk) if len(chunk) != 0: chunk = chunk + [ - random.choice(chunk) for _ in range(batch_size - len(chunk)) + random.choice(chunk) + for _ in range(batch_size - len(chunk)) ] else: chunk = random.choice(pad_chunks) @@ -261,10 +256,12 @@ def megabatch_frame_alignment(megabatches, lengths): # mixed frame length, align megabatch inside if len(count_dict) != 1: - sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1]) + sorted_by_value = sorted(count_dict.items(), + key=lambda item: item[1]) pick_length = sorted_by_value[-1][0] # the highest frequency candidate_batch = [ - idx for idx, length in idx_length_dict.items() if length == pick_length + idx for idx, length in idx_length_dict.items() + if length == pick_length ] random_select_batch = [ random.choice(candidate_batch) @@ -291,8 +288,7 @@ def get_length_grouped_indices( # We need to use torch for the random part as a distributed sampler will set the random seed for torch. if generator is None: generator = torch.Generator().manual_seed( - seed - ) # every rank will generate a fixed order but random index + seed) # every rank will generate a fixed order but random index indices = torch.randperm(len(lengths), generator=generator).tolist() @@ -302,7 +298,8 @@ def get_length_grouped_indices( # chunk dataset to megabatches megabatch_size = world_size * batch_size megabatches = [ - indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size) + indices[i:i + megabatch_size] + for i in range(0, len(lengths), megabatch_size) ] # make sure the length in each magabatch is align with each other @@ -320,7 +317,8 @@ def get_length_grouped_indices( # expand indices and return return [ - i for megabatch in shuffled_megabatches for batch in megabatch for i in batch + i for megabatch in shuffled_megabatches for batch in megabatch + for i in batch ] @@ -368,11 +366,10 @@ def distributed_sampler(lst, rank, batch_size, world_size): result = [] index = rank * batch_size while index < len(lst): - result.extend(lst[index : index + batch_size]) + result.extend(lst[index:index + batch_size]) index += batch_size * world_size return result - indices = distributed_sampler( - indices, self.rank, self.batch_size, self.world_size - ) + indices = distributed_sampler(indices, self.rank, self.batch_size, + self.world_size) return iter(indices) diff --git a/fastvideo/utils/env_utils.py b/fastvideo/utils/env_utils.py index b9c9b76..206c4aa 100644 --- a/fastvideo/utils/env_utils.py +++ b/fastvideo/utils/env_utils.py @@ -6,10 +6,8 @@ import transformers from transformers.utils import is_torch_cuda_available, is_torch_npu_available - VERSION = "1.2.0" - if __name__ == "__main__": info = { "FastVideo version": VERSION, @@ -28,7 +26,7 @@ if is_torch_npu_available(): info["PyTorch version"] += " (NPU)" info["NPU type"] = torch.npu.get_device_name() - info["CANN version"] = torch.version.cann + info["CANN version"] = torch.version.cann # codespell:ignore try: import bitsandbytes @@ -37,4 +35,6 @@ except Exception: pass - print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n") \ No 
newline at end of file + print("\n" + + "\n".join([f"- {key}: {value}" + for key, value in info.items()]) + "\n") diff --git a/fastvideo/utils/fsdp_util.py b/fastvideo/utils/fsdp_util.py index 04fb5f2..cadf33a 100644 --- a/fastvideo/utils/fsdp_util.py +++ b/fastvideo/utils/fsdp_util.py @@ -1,35 +1,20 @@ -from sympy import use -import torch -import os -import torch.distributed as dist -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - CheckpointImpl, - apply_activation_checkpointing, -) -from peft.utils.other import fsdp_auto_wrap_policy - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, # general model non-sharded, non-flattened params - LocalStateDictConfig, # flattened params, usable only by FSDP - # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. -) - -from fastvideo.utils.load import get_no_split_modules -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformerBlock - +# ruff: noqa: E731 +import functools from functools import partial -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - +import torch +from peft.utils.other import fsdp_auto_wrap_policy +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper) from torch.distributed.fsdp import MixedPrecision, ShardingStrategy -import functools +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformerBlock +from fastvideo.utils.load import get_no_split_modules non_reentrant_wrapper = partial( - checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) check_fn = lambda submodule: isinstance(submodule, MochiTransformerBlock) @@ -40,7 +25,7 @@ def apply_fsdp_checkpointing(model, no_split_modules, p=1): """apply activation checkpointing to model returns None as model is updated directly """ - print(f"--> applying fdsp activation checkpointing...") + print("--> applying fdsp activation checkpointing...") block_idx = 0 cut_off = 1 / 2 # when passing p as a fraction number (e.g. 
1/3), it will be interpreted @@ -90,7 +75,8 @@ def get_dit_fsdp_kwargs( auto_wrap_policy = fsdp_auto_wrap_policy else: auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=no_split_modules, + transformer_auto_wrap_policy, + transformer_layer_cls=no_split_modules, ) # we use float32 for fsdp but autocast during training @@ -107,9 +93,8 @@ def get_dit_fsdp_kwargs( sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 device_id = torch.cuda.current_device() - cpu_offload = ( - torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None - ) + cpu_offload = (torch.distributed.fsdp.CPUOffload( + offload_params=True) if cpu_offload else None) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, "mixed_precision": mixed_precision, @@ -121,12 +106,10 @@ def get_dit_fsdp_kwargs( # Add LoRA-specific settings when LoRA is enabled if use_lora: - fsdp_kwargs.update( - { - "use_orig_params": False, # Required for LoRA memory savings - "sync_module_states": True, - } - ) + fsdp_kwargs.update({ + "use_orig_params": False, # Required for LoRA memory savings + "sync_module_states": True, + }) return fsdp_kwargs, no_split_modules diff --git a/fastvideo/utils/load.py b/fastvideo/utils/load.py index d29cb00..8637feb 100644 --- a/fastvideo/utils/load.py +++ b/fastvideo/utils/load.py @@ -1,23 +1,20 @@ -import torch -from fastvideo.models.mochi_hf.modeling_mochi import ( - MochiTransformer3DModel, - MochiTransformerBlock, -) -from fastvideo.models.hunyuan.modules.models import ( - HYVideoDiffusionTransformer, - MMDoubleStreamBlock, - MMSingleStreamBlock, -) -from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D -from diffusers import AutoencoderKLMochi -from transformers import T5EncoderModel, AutoTokenizer import os -from torch import nn - # Path from pathlib import Path + +import torch import torch.nn.functional as F +from diffusers import AutoencoderKLMochi +from torch import nn +from transformers import AutoTokenizer, T5EncoderModel + +from fastvideo.models.hunyuan.modules.models import ( + HYVideoDiffusionTransformer, MMDoubleStreamBlock, MMSingleStreamBlock) from fastvideo.models.hunyuan.text_encoder import TextEncoder +from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import \ + AutoencoderKLCausal3D +from fastvideo.models.mochi_hf.modeling_mochi import (MochiTransformer3DModel, + MochiTransformerBlock) from fastvideo.utils.logging_ import main_print hunyuan_config = { @@ -30,12 +27,10 @@ "guidance_embed": True, } - PROMPT_TEMPLATE_ENCODE = ( "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" -) + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") PROMPT_TEMPLATE_ENCODE_VIDEO = ( "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " "1. The main content and theme of the video." @@ -43,13 +38,15 @@ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." "4. background environment, light, style and atmosphere." "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" -) + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" PROMPT_TEMPLATE = { - "dit-llm-encode": {"template": PROMPT_TEMPLATE_ENCODE, "crop_start": 36,}, + "dit-llm-encode": { + "template": PROMPT_TEMPLATE_ENCODE, + "crop_start": 36, + }, "dit-llm-encode-video": { "template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95, @@ -58,11 +55,13 @@ class HunyuanTextEncoderWrapper(nn.Module): + def __init__(self, pretrained_model_name_or_path, device): super().__init__() text_len = 256 - crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"].get("crop_start", 0) + crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"].get( + "crop_start", 0) max_length = text_len + crop_start @@ -71,7 +70,8 @@ def __init__(self, pretrained_model_name_or_path, device): # prompt_template_video prompt_template_video = PROMPT_TEMPLATE["dit-llm-encode-video"] - text_encoder_path = os.path.join(pretrained_model_name_or_path, "text_encoder") + text_encoder_path = os.path.join(pretrained_model_name_or_path, + "text_encoder") self.text_encoder = TextEncoder( text_encoder_type="llm", text_encoder_path=text_encoder_path, @@ -86,9 +86,8 @@ def __init__(self, pretrained_model_name_or_path, device): logger=None, device=device, ) - text_encoder_path_2 = os.path.join( - pretrained_model_name_or_path, "text_encoder_2" - ) + text_encoder_path_2 = os.path.join(pretrained_model_name_or_path, + "text_encoder_2") self.text_encoder_2 = TextEncoder( text_encoder_type="clipL", text_encoder_path=text_encoder_path_2, @@ -109,9 +108,9 @@ def encode_(self, prompt, text_encoder, clip_skip=None): text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) if clip_skip is None: - prompt_outputs = text_encoder.encode( - text_inputs, data_type="video", device=device - ) + prompt_outputs = text_encoder.encode(text_inputs, + data_type="video", + device=device) prompt_embeds = prompt_outputs.hidden_state else: prompt_outputs = text_encoder.encode( @@ -123,8 +122,7 @@ def encode_(self, prompt, text_encoder, clip_skip=None): prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] prompt_embeds = text_encoder.model.text_model.final_layer_norm( - prompt_embeds - ) + prompt_embeds) attention_mask = prompt_outputs.attention_mask if attention_mask is not None: @@ -132,8 +130,7 @@ def encode_(self, prompt, text_encoder, clip_skip=None): bs_embed, seq_len = attention_mask.shape attention_mask = attention_mask.repeat(1, num_videos_per_prompt) attention_mask = attention_mask.view( - bs_embed * num_videos_per_prompt, seq_len - ) + bs_embed * num_videos_per_prompt, seq_len) if text_encoder is not None: prompt_embeds_dtype = text_encoder.dtype @@ -142,25 +139,27 @@ def encode_(self, prompt, text_encoder, clip_skip=None): else: prompt_embeds_dtype = prompt_embeds.dtype - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, + device=device) if prompt_embeds.ndim == 2: bs_embed, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + prompt_embeds = prompt_embeds.view( + bs_embed * 
num_videos_per_prompt, -1) else: bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view( - bs_embed * num_videos_per_prompt, seq_len, -1 - ) + bs_embed * num_videos_per_prompt, seq_len, -1) return (prompt_embeds, attention_mask) def encode_prompt(self, prompt): prompt_embeds, attention_mask = self.encode_(prompt, self.text_encoder) - prompt_embeds_2, attention_mask_2 = self.encode_(prompt, self.text_encoder_2) + prompt_embeds_2, attention_mask_2 = self.encode_( + prompt, self.text_encoder_2) prompt_embeds_2 = F.pad( prompt_embeds_2, (0, prompt_embeds.shape[2] - prompt_embeds_2.shape[1]), @@ -171,14 +170,14 @@ def encode_prompt(self, prompt): class MochiTextEncoderWrapper(nn.Module): + def __init__(self, pretrained_model_name_or_path, device): super().__init__() self.text_encoder = T5EncoderModel.from_pretrained( - os.path.join(pretrained_model_name_or_path, "text_encoder") - ).to(device) + os.path.join(pretrained_model_name_or_path, + "text_encoder")).to(device) self.tokenizer = AutoTokenizer.from_pretrained( - os.path.join(pretrained_model_name_or_path, "tokenizer") - ) + os.path.join(pretrained_model_name_or_path, "tokenizer")) self.max_sequence_length = 256 def encode_prompt(self, prompt): @@ -200,22 +199,19 @@ def encode_prompt(self, prompt): prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.bool().to(device) - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids + untruncated_ids = self.tokenizer(prompt, + padding="longest", + return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.max_sequence_length - 1 : -1] - ) + untruncated_ids[:, self.max_sequence_length - 1:-1]) main_print( f"Truncated text input: {prompt} to: {removed_text} for model input." ) prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=prompt_attention_mask - )[0] + text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -231,11 +227,12 @@ def load_hunyuan_state_dict(model, dit_model_name_or_path): model_path = dit_model_name_or_path bare_model = "unknown" - state_dict = torch.load( - model_path, map_location=lambda storage, loc: storage, weights_only=True - ) + state_dict = torch.load(model_path, + map_location=lambda storage, loc: storage, + weights_only=True) - if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict): + if bare_model == "unknown" and ("ema" in state_dict + or "module" in state_dict): bare_model = False if bare_model is False: if load_key in state_dict: @@ -243,8 +240,7 @@ def load_hunyuan_state_dict(model, dit_model_name_or_path): else: raise KeyError( f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint " - f"are: {list(state_dict.keys())}." 
- ) + f"are: {list(state_dict.keys())}.") model.load_state_dict(state_dict, strict=True) return model @@ -271,9 +267,13 @@ def load_transformer( ) elif model_type == "hunyuan": transformer = HYVideoDiffusionTransformer( - in_channels=16, out_channels=16, **hunyuan_config, dtype=master_weight_type, + in_channels=16, + out_channels=16, + **hunyuan_config, + dtype=master_weight_type, ) - transformer = load_hunyuan_state_dict(transformer, dit_model_name_or_path) + transformer = load_hunyuan_state_dict(transformer, + dit_model_name_or_path) else: raise ValueError(f"Unsupported model type: {model_type}") return transformer @@ -283,15 +283,15 @@ def load_vae(model_type, pretrained_model_name_or_path): weight_dtype = torch.float32 if model_type == "mochi": vae = AutoencoderKLMochi.from_pretrained( - pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype - ).to("cuda") + pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=weight_dtype).to("cuda") autocast_type = torch.bfloat16 fps = 30 elif model_type == "hunyuan": vae_precision = torch.float32 - vae_path = os.path.join( - pretrained_model_name_or_path, "hunyuan-video-t2v-720p/vae" - ) + vae_path = os.path.join(pretrained_model_name_or_path, + "hunyuan-video-t2v-720p/vae") config = AutoencoderKLCausal3D.load_config(vae_path) vae = AutoencoderKLCausal3D.from_config(config) @@ -305,8 +305,7 @@ def load_vae(model_type, pretrained_model_name_or_path): if any(k.startswith("vae.") for k in ckpt.keys()): ckpt = { k.replace("vae.", ""): v - for k, v in ckpt.items() - if k.startswith("vae.") + for k, v in ckpt.items() if k.startswith("vae.") } vae.load_state_dict(ckpt) vae = vae.to(dtype=vae_precision) @@ -320,9 +319,11 @@ def load_vae(model_type, pretrained_model_name_or_path): def load_text_encoder(model_type, pretrained_model_name_or_path, device): if model_type == "mochi": - text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, device) + text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, + device) elif model_type == "hunyuan": - text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, device) + text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, + device) else: raise ValueError(f"Unsupported model type: {model_type}") return text_encoder @@ -331,7 +332,7 @@ def load_text_encoder(model_type, pretrained_model_name_or_path, device): def get_no_split_modules(transformer): # if of type MochiTransformer3DModel if isinstance(transformer, MochiTransformer3DModel): - return (MochiTransformerBlock,) + return (MochiTransformerBlock, ) elif isinstance(transformer, HYVideoDiffusionTransformer): return (MMDoubleStreamBlock, MMSingleStreamBlock) else: @@ -342,6 +343,7 @@ def get_no_split_modules(transformer): # test encode prompt device = torch.cuda.current_device() pretrained_model_name_or_path = "data/hunyuan" - text_encoder = load_text_encoder("hunyuan", pretrained_model_name_or_path, device) + text_encoder = load_text_encoder("hunyuan", pretrained_model_name_or_path, + device) prompt = "A man on stage claps his hands together while facing the audience. The audience, visible in the foreground, holds up mobile devices to record the event, capturing the moment from various angles. The background features a large banner with text identifying the man on stage. Throughout the sequence, the man's expression remains engaged and directed towards the audience. The camera angle remains constant, focusing on capturing the interaction between the man on stage and the audience." 
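# The load_hunyuan_state_dict helper above unwraps checkpoints that nest the
# real weights under an "ema" or "module" key before calling
# model.load_state_dict(..., strict=True). A minimal self-contained sketch of
# that unwrapping step; the toy module and the default load_key value are
# assumptions for illustration only.
import torch.nn as nn


def unwrap_checkpoint(state_dict, load_key="module"):
    # Wrapped checkpoints (e.g. EMA / DeepSpeed-style dumps) keep the weights
    # one level down under "ema" or "module".
    if "ema" in state_dict or "module" in state_dict:
        if load_key not in state_dict:
            raise KeyError(f"Missing key: `{load_key}` in the checkpoint. "
                           f"Available keys: {list(state_dict.keys())}.")
        state_dict = state_dict[load_key]
    return state_dict


toy = nn.Linear(4, 4)
wrapped = {"module": toy.state_dict()}  # simulated wrapped checkpoint
toy.load_state_dict(unwrap_checkpoint(wrapped), strict=True)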
prompt_embeds, attention_mask = text_encoder.encode_prompt(prompt) diff --git a/fastvideo/utils/logging_.py b/fastvideo/utils/logging_.py index ea2fe86..ee1523f 100644 --- a/fastvideo/utils/logging_.py +++ b/fastvideo/utils/logging_.py @@ -1,6 +1,6 @@ -import sys -import pdb import os +import pdb +import sys def main_print(content): diff --git a/fastvideo/utils/optimizer.py b/fastvideo/utils/optimizer.py index f7a70c7..ffee08a 100644 --- a/fastvideo/utils/optimizer.py +++ b/fastvideo/utils/optimizer.py @@ -1,5 +1,5 @@ -from accelerate.logging import get_logger import torch +from accelerate.logging import get_logger logger = get_logger(__name__) @@ -13,11 +13,11 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): ) args.optimizer = "adamw" - if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]): + if args.use_8bit_adam and not (args.optimizer.lower() + not in ["adam", "adamw"]): logger.warning( f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " - f"set to {args.optimizer.lower()}" - ) + f"set to {args.optimizer.lower()}") if args.use_8bit_adam: try: @@ -28,9 +28,8 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): ) if args.optimizer.lower() == "adamw": - optimizer_class = ( - bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW - ) + optimizer_class = (bnb.optim.AdamW8bit + if args.use_8bit_adam else torch.optim.AdamW) optimizer = optimizer_class( params_to_optimize, diff --git a/fastvideo/utils/parallel_states.py b/fastvideo/utils/parallel_states.py index 1897db2..552a9dc 100644 --- a/fastvideo/utils/parallel_states.py +++ b/fastvideo/utils/parallel_states.py @@ -1,9 +1,10 @@ -import torch -import torch.distributed as dist import os +import torch.distributed as dist + class COMM_INFO: + def __init__(self): self.group = None self.sp_size = 1 @@ -44,13 +45,13 @@ def initialize_sequence_parallel_group(sequence_parallel_size): assert ( world_size % sequence_parallel_size == 0 ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format( - world_size, sequence_parallel_size - ) + world_size, sequence_parallel_size) nccl_info.sp_size = sequence_parallel_size nccl_info.global_rank = rank num_sequence_parallel_groups: int = world_size // sequence_parallel_size for i in range(num_sequence_parallel_groups): - ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + ranks = range(i * sequence_parallel_size, + (i + 1) * sequence_parallel_size) group = dist.new_group(ranks) if rank in ranks: nccl_info.group = group diff --git a/fastvideo/utils/validation.py b/fastvideo/utils/validation.py index aabbc5a..09ec70a 100644 --- a/fastvideo/utils/validation.py +++ b/fastvideo/utils/validation.py @@ -1,27 +1,24 @@ -from typing import Optional, Union, List +import gc +import os +from typing import List, Optional, Union + import numpy as np import torch -from einops import rearrange -from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info -from fastvideo.utils.communications import all_gather +import wandb +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import export_to_video from diffusers.utils.torch_utils import randn_tensor -from fastvideo.models.mochi_hf.pipeline_mochi import ( - linear_quadratic_schedule, - retrieve_timesteps, -) -from tqdm import tqdm from diffusers.video_processor import VideoProcessor -from diffusers import ( - 
FlowMatchEulerDiscreteScheduler, - AutoencoderKLMochi, -) -from fastvideo.utils.logging_ import main_print +from einops import rearrange +from tqdm import tqdm + from fastvideo.distill.solver import PCMFMScheduler -from diffusers.utils import export_to_video -import os -import wandb -import gc +from fastvideo.models.mochi_hf.pipeline_mochi import ( + linear_quadratic_schedule, retrieve_timesteps) +from fastvideo.utils.communications import all_gather from fastvideo.utils.load import load_vae +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) def prepare_latents( @@ -42,7 +39,10 @@ def prepare_latents( shape = (batch_size, num_channels_latents, num_frames, height, width) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, + generator=generator, + device=device, + dtype=dtype) return latents @@ -74,10 +74,10 @@ def sample_validation_video( do_classifier_free_guidance = guidance_scale > 1.0 if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], + dim=0) prompt_attention_mask = torch.cat( - [negative_prompt_attention_mask, prompt_attention_mask], dim=0 - ) + [negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare latent variables # TODO: Remove hardcore @@ -95,9 +95,9 @@ def sample_validation_video( ) world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group if get_sequence_parallel_state(): - latents = rearrange( - latents, "b t (n s) h w -> b t n s h w", n=world_size - ).contiguous() + latents = rearrange(latents, + "b t (n s) h w -> b t n s h w", + n=world_size).contiguous() latents = latents[:, :, rank, :, :, :] # 5. Prepare timestep @@ -107,13 +107,20 @@ def sample_validation_video( sigmas = np.array(sigmas) if scheduler_type == "euler": timesteps, num_inference_steps = retrieve_timesteps( - scheduler, num_inference_steps, device, timesteps, sigmas, + scheduler, + num_inference_steps, + device, + timesteps, + sigmas, ) else: timesteps, num_inference_steps = retrieve_timesteps( - scheduler, num_inference_steps, device, + scheduler, + num_inference_steps, + device, ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * scheduler.order, 0) # 6. 
Denoising loop # with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -121,14 +128,13 @@ def sample_validation_video( # only enable if nccl_info.global_rank == 0 with tqdm( - total=num_inference_steps, - disable=nccl_info.rank_within_group != 0, - desc="Validation sampling...", + total=num_inference_steps, + disable=nccl_info.rank_within_group != 0, + desc="Validation sampling...", ) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) + latent_model_input = (torch.cat([latents] * 2) + if do_classifier_free_guidance else latents) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) with torch.autocast("cuda", dtype=torch.bfloat16): @@ -145,14 +151,14 @@ def sample_validation_video( if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = scheduler.step( - noise_pred, t, latents.to(torch.float32), return_dict=False - )[0] + latents = scheduler.step(noise_pred, + t, + latents.to(torch.float32), + return_dict=False)[0] latents = latents.to(latents_dtype) if latents.dtype != latents_dtype: @@ -160,9 +166,8 @@ def sample_validation_video( # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and + (i + 1) % scheduler.order == 0): progress_bar.update() if get_sequence_parallel_state(): @@ -173,32 +178,26 @@ def sample_validation_video( else: # unscale/denormalize the latents # denormalize with the mean and std if available and not None - has_latents_mean = ( - hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None - ) - has_latents_std = ( - hasattr(vae.config, "latents_std") and vae.config.latents_std is not None - ) + has_latents_mean = (hasattr(vae.config, "latents_mean") + and vae.config.latents_mean is not None) + has_latents_std = (hasattr(vae.config, "latents_std") + and vae.config.latents_std is not None) if has_latents_mean and has_latents_std: - latents_mean = ( - torch.tensor(vae.config.latents_mean) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = ( - torch.tensor(vae.config.latents_std) - .view(1, 12, 1, 1, 1) - .to(latents.device, latents.dtype) - ) + latents_mean = (torch.tensor(vae.config.latents_mean).view( + 1, 12, 1, 1, 1).to(latents.device, latents.dtype)) + latents_std = (torch.tensor(vae.config.latents_std).view( + 1, 12, 1, 1, 1).to(latents.device, latents.dtype)) latents = latents * latents_std / vae.config.scaling_factor + latents_mean else: latents = latents / vae.config.scaling_factor with torch.autocast("cuda", dtype=vae.dtype): video = vae.decode(latents, return_dict=False)[0] - video_processor = VideoProcessor(vae_scale_factor=vae_spatial_scale_factor) - video = video_processor.postprocess_video(video, output_type=output_type) + video_processor = VideoProcessor( + vae_scale_factor=vae_spatial_scale_factor) + video = video_processor.postprocess_video(video, + output_type=output_type) - return (video,) + return (video, ) 
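# The denoising loop above chunks the stacked [uncond | text] prediction and
# applies classifier-free guidance before each scheduler.step call. A minimal
# sketch of that update on toy tensors; the shapes and guidance_scale value
# here are illustrative, not the pipeline's real latent layout.
import torch

guidance_scale = 4.5
noise_pred = torch.randn(2, 12, 4, 8, 8)  # batch dim stacks [uncond, text]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text -
                                               noise_pred_uncond)
# `guided` is what the loop hands to scheduler.step(...) for x_t -> x_{t-1}.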
@torch.no_grad() @@ -217,7 +216,7 @@ def log_validation( ema=False, ): # TODO - print(f"Running validation....\n") + print("Running validation....\n") if args.model_type == "mochi": vae_spatial_scale_factor = 8 vae_temporal_scale_factor = 6 @@ -228,9 +227,8 @@ def log_validation( num_channels_latents = 16 else: raise ValueError(f"Model type {args.model_type} not supported") - vae, autocast_type, fps = load_vae( - args.model_type, args.pretrained_model_name_or_path - ) + vae, autocast_type, fps = load_vae(args.model_type, + args.pretrained_model_name_or_path) vae.enable_tiling() if scheduler_type == "euler": scheduler = FlowMatchEulerDiscreteScheduler() @@ -257,41 +255,37 @@ def log_validation( # prompt_embed are named embed0 to embedN # check how many embeds are there embe_dir = os.path.join(args.validation_prompt_dir, "prompt_embed") - mask_dir = os.path.join(args.validation_prompt_dir, "prompt_attention_mask") + mask_dir = os.path.join(args.validation_prompt_dir, + "prompt_attention_mask") embeds = sorted([f for f in os.listdir(embe_dir)]) masks = sorted([f for f in os.listdir(mask_dir)]) num_embeds = len(embeds) validation_prompt_ids = list(range(num_embeds)) - num_sp_groups = int(os.getenv("WORLD_SIZE", "1")) // nccl_info.sp_size + num_sp_groups = int(os.getenv("WORLD_SIZE", + "1")) // nccl_info.sp_size # pad to multiple of groups if num_embeds % num_sp_groups != 0: - validation_prompt_ids += [0] * ( - num_sp_groups - num_embeds % num_sp_groups - ) + validation_prompt_ids += [0] * (num_sp_groups - + num_embeds % num_sp_groups) num_embeds_per_group = len(validation_prompt_ids) // num_sp_groups - local_prompt_ids = validation_prompt_ids[ - nccl_info.group_id - * num_embeds_per_group : (nccl_info.group_id + 1) - * num_embeds_per_group - ] + local_prompt_ids = validation_prompt_ids[nccl_info.group_id * + num_embeds_per_group: + (nccl_info.group_id + 1) * + num_embeds_per_group] for i in local_prompt_ids: prompt_embed_path = os.path.join(embe_dir, f"{embeds[i]}") prompt_mask_path = os.path.join(mask_dir, f"{masks[i]}") - prompt_embeds = ( - torch.load(prompt_embed_path, map_location="cpu", weights_only=True) - .to(device) - .unsqueeze(0) - ) - prompt_attention_mask = ( - torch.load(prompt_mask_path, map_location="cpu", weights_only=True) - .to(device) - .unsqueeze(0) - ) - negative_prompt_embeds = torch.zeros(256, 4096).to(device).unsqueeze(0) + prompt_embeds = (torch.load( + prompt_embed_path, map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) + prompt_attention_mask = (torch.load( + prompt_mask_path, map_location="cpu", + weights_only=True).to(device).unsqueeze(0)) + negative_prompt_embeds = torch.zeros( + 256, 4096).to(device).unsqueeze(0) negative_prompt_attention_mask = ( - torch.zeros(256).bool().to(device).unsqueeze(0) - ) + torch.zeros(256).bool().to(device).unsqueeze(0)) generator = torch.Generator(device="cuda").manual_seed(12345) video = sample_validation_video( transformer, @@ -308,7 +302,8 @@ def log_validation( prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_embeds=negative_prompt_embeds, - negative_prompt_attention_mask=negative_prompt_attention_mask, + negative_prompt_attention_mask= + negative_prompt_attention_mask, vae_spatial_scale_factor=vae_spatial_scale_factor, vae_temporal_scale_factor=vae_temporal_scale_factor, num_channels_latents=num_channels_latents, @@ -340,7 +335,8 @@ def log_validation( video_filenames.append(filename) logs = { - f"{'ema_' if ema else 
''}validation_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}": [ + f"{'ema_' if ema else ''}validation_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}": + [ wandb.Video(filename) for i, filename in enumerate(video_filenames) ] diff --git a/format.sh b/format.sh new file mode 100644 index 0000000..0f11734 --- /dev/null +++ b/format.sh @@ -0,0 +1,237 @@ +#!/usr/bin/env bash +# YAPF formatter, adapted from fastvideo. +# +# Usage: +# # Do work and commit your work. + +# # Format files that differ from origin/main. +# bash format.sh + +# # Commit changed files with message 'Run yapf and ruff' +# +# +# This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. + +# Cause the script to exit if a single command fails +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +check_command() { + if ! command -v "$1" &> /dev/null; then + echo "❓❓$1 is not installed, please run \`bash env_setup.sh\`" + exit 1 + fi +} + +check_command yapf +check_command ruff +check_command codespell +check_command isort + +YAPF_VERSION=$(yapf --version | awk '{print $2}') +RUFF_VERSION=$(ruff --version | awk '{print $2}') +CODESPELL_VERSION=$(codespell --version) +ISORT_VERSION=$(isort --vn) +SPHINX_LINT_VERSION=$(sphinx-lint --version | awk '{print $2}') + + +# # params: tool name, tool version, required version +tool_version_check() { + expected=$(grep "$1" requirements-lint.txt | cut -d'=' -f3) + if [[ "$2" != "$expected" ]]; then + echo "❓❓Wrong $1 version installed: $expected is required, not $2." + exit 1 + fi +} + +tool_version_check "yapf" "$YAPF_VERSION" +tool_version_check "ruff" "$RUFF_VERSION" +tool_version_check "isort" "$ISORT_VERSION" +tool_version_check "codespell" "$CODESPELL_VERSION" +tool_version_check "sphinx-lint" "$SPHINX_LINT_VERSION" + +YAPF_FLAGS=( + '--recursive' + '--parallel' +) + +YAPF_EXCLUDES=( + '--exclude' 'data/**' +) + +# Format specified files +format() { + yapf --in-place "${YAPF_FLAGS[@]}" "$@" +} + +# Format files that differ from main branch. Ignores dirs that are not slated +# for autoformat yet. +format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause yapf to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ + yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + fi + +} + +# Format all files +format_all() { + yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" . +} + +## This flag formats individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + format "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is formatted. +elif [[ "$1" == '--all' ]]; then + format_all +else + # Format only the files that changed in last commit. 
+ format_changed +fi +echo 'FastVideo yapf: Done' + + +# If git diff returns a file that is in the skip list, the file may be checked anyway: +# https://github.com/codespell-project/codespell/issues/1915 +# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem +CODESPELL_EXCLUDES=( + '--skip' 'data/**, + fastvideo/distill.py, + fastvideo/models/hunyuan/modules/models.py, + fastvideo/models/mochi_hf/modeling_mochi.py, + fastvideo/utils/env_utils.py' +) + +# check spelling of specified files +spell_check() { + codespell "$@" +} + +spell_check_all(){ + codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" +} + +# Spelling check of files that differ from main branch. +spell_check_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + codespell "${CODESPELL_EXCLUDES[@]}" + fi +} + +# Run Codespell +## This flag runs spell check of individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + spell_check "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + spell_check_all +else + # Check spelling only of the files that changed in last commit. + spell_check_changed +fi +echo 'FastVideo codespell: Done' + + +# Lint specified files +lint() { + ruff check "$@" +} + +# Lint files that differ from main branch. Ignores dirs that are not slated +# for autolint yet. +lint_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + ruff check + fi + +} + +# Run Ruff +### This flag lints individual files. --files *must* be the first command line +### arg to use this option. +if [[ "$1" == '--files' ]]; then + lint "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + lint fastvideo scripts +else + # Format only the files that changed in last commit. + lint_changed +fi +echo 'FastVideo ruff: Done' + +# check spelling of specified files +isort_check() { + isort "$@" +} + +isort_check_all(){ + isort . +} + +# Spelling check of files that differ from main branch. +isort_check_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! 
git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + isort + fi +} + +# Run Isort +# This flag runs spell check of individual files. --files *must* be the first command line +# arg to use this option. +if [[ "$1" == '--files' ]]; then + isort_check "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + isort_check_all +else + # Check spelling only of the files that changed in last commit. + isort_check_changed +fi +echo 'FastVideo isort: Done' diff --git a/predict.py b/predict.py index bbbf1c0..2e310e5 100644 --- a/predict.py +++ b/predict.py @@ -1,22 +1,26 @@ # Prediction interface for Cog ⚙️ # https://cog.run/python -from cog import BasePredictor, Input, Path +import argparse import os +import subprocess import time -import torch + import imageio -import argparse -import subprocess -import torchvision import numpy as np +import torch +import torchvision +from cog import BasePredictor, Input, Path from einops import rearrange -MODEL_CACHE = 'FastHunyuan' -os.environ['MODEL_BASE'] = './'+MODEL_CACHE from fastvideo.models.hunyuan.inference import HunyuanVideoSampler + +MODEL_CACHE = 'FastHunyuan' +os.environ['MODEL_BASE'] = './' + MODEL_CACHE + MODEL_URL = "https://weights.replicate.delivery/default/FastVideo/FastHunyuan/model.tar" + def download_weights(url, dest): start = time.time() print("downloading url: ", url) @@ -24,7 +28,9 @@ def download_weights(url, dest): subprocess.check_call(["pget", "-xf", url, dest], close_fds=False) print("downloading took: ", time.time() - start) + class Predictor(BasePredictor): + def setup(self): """Load the model into memory""" print("Model Base: " + os.environ['MODEL_BASE']) @@ -32,7 +38,8 @@ def setup(self): if not os.path.exists(MODEL_CACHE): download_weights(MODEL_URL, MODEL_CACHE) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") args = argparse.Namespace( num_frames=125, height=720, @@ -49,7 +56,8 @@ def setup(self): num_videos=1, load_key='module', use_cpu_offload=False, - dit_weight='FastHunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt', + dit_weight= + 'FastHunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt', reproduce=True, disable_autocast=False, flow_reverse=True, @@ -79,27 +87,56 @@ def setup(self): text_len_2=77, model_path=MODEL_CACHE, ) - self.model = HunyuanVideoSampler.from_pretrained(MODEL_CACHE, args=args) + self.model = HunyuanVideoSampler.from_pretrained(MODEL_CACHE, + args=args) def predict( self, - prompt: str = Input(description="Text prompt for video generation", default="A cat walks on the grass, realistic style."), - negative_prompt: str = Input(description="Text prompt to specify what you don't want in the video.", default=""), - width: int = Input(description="Width of output video", default=1280, ge=256), - height: int = Input(description="Height of output video", default=720, ge=256), - num_frames: int = Input(description="Number of frames to generate", default=125, ge=16), - num_inference_steps: int = Input(description="Number of denoising steps", default=6, ge=1, le=50), - guidance_scale: float = Input(description="Classifier free guidance scale", default=1.0, ge=0.1, le=10.0), - embedded_cfg_scale: float = Input(description="Embedded classifier 
free guidance scale", default=6.0, ge=0.1, le=10.0), - flow_shift: int = Input(description="Flow shift parameter", default=17, ge=1, le=20), - fps: int = Input(description="Frames per second of output video", default=24, ge=1, le=60), - seed: int = Input(description="0 for Random seed. Set for reproducible generation", default=0), + prompt: str = Input( + description="Text prompt for video generation", + default="A cat walks on the grass, realistic style."), + negative_prompt: str = Input( + description= + "Text prompt to specify what you don't want in the video.", + default=""), + width: int = Input(description="Width of output video", + default=1280, + ge=256), + height: int = Input(description="Height of output video", + default=720, + ge=256), + num_frames: int = Input(description="Number of frames to generate", + default=125, + ge=16), + num_inference_steps: int = Input( + description="Number of denoising steps", default=6, ge=1, le=50), + guidance_scale: float = Input( + description="Classifier free guidance scale", + default=1.0, + ge=0.1, + le=10.0), + embedded_cfg_scale: float = Input( + description="Embedded classifier free guidance scale", + default=6.0, + ge=0.1, + le=10.0), + flow_shift: int = Input(description="Flow shift parameter", + default=17, + ge=1, + le=20), + fps: int = Input(description="Frames per second of output video", + default=24, + ge=1, + le=60), + seed: int = Input( + description="0 for Random seed. Set for reproducible generation", + default=0), ) -> Path: """Run video generation""" - if seed <=0: + if seed <= 0: seed = int.from_bytes(os.urandom(2), "big") print(f"Using seed: {seed}") - + outputs = self.model.predict( prompt=prompt, height=height, diff --git a/requirements-lint.txt b/requirements-lint.txt new file mode 100644 index 0000000..6ddaed0 --- /dev/null +++ b/requirements-lint.txt @@ -0,0 +1,14 @@ +# formatting +yapf==0.32.0 +toml==0.10.2 +tomli==2.0.2 +ruff==0.6.5 +codespell==2.3.0 +isort==5.13.2 +sphinx-lint==1.0.0 + +# type checking +mypy==1.11.1 +types-PyYAML +types-requests +types-setuptools diff --git a/scripts/huggingface/download_hf.py b/scripts/huggingface/download_hf.py index cef005c..fd26ee7 100644 --- a/scripts/huggingface/download_hf.py +++ b/scripts/huggingface/download_hf.py @@ -1,15 +1,15 @@ -from huggingface_hub import snapshot_download, hf_hub_download import argparse +from huggingface_hub import hf_hub_download, snapshot_download + # set args for repo_id, local_dir, repo_type, if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Download a dataset or model from the Hugging Face Hub" - ) - parser.add_argument( - "--repo_id", type=str, help="The ID of the repository to download" - ) + description="Download a dataset or model from the Hugging Face Hub") + parser.add_argument("--repo_id", + type=str, + help="The ID of the repository to download") parser.add_argument( "--local_dir", type=str, @@ -20,7 +20,9 @@ type=str, help="The type of repository to download (dataset or model)", ) - parser.add_argument("--file_name", type=str, help="The file name to download") + parser.add_argument("--file_name", + type=str, + help="The file name to download") args = parser.parse_args() if args.file_name: hf_hub_download(