From ba08a898940c80a6551111fdd77b53c6d3a019ac Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Fri, 4 Oct 2024 20:35:16 +0900
Subject: [PATCH 1/2] call optimizer eval/train for sample_at_first, also set
 train after resuming

closes #1667
---
 flux_train.py    | 2 ++
 train_network.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/flux_train.py b/flux_train.py
index 022467ea7..81c13e4cc 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -706,7 +706,9 @@ def optimizer_hook(parameter: torch.Tensor):
         accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
 
     # For --sample_at_first
+    optimizer_eval_fn()
     flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
+    optimizer_train_fn()
     if len(accelerator.trackers) > 0:
         # log empty object to commit the sample images to wandb
         accelerator.log({}, step=0)
diff --git a/train_network.py b/train_network.py
index 7b2b76a1b..f0d397b9e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1042,7 +1042,9 @@ def remove_model(old_ckpt_name):
             text_encoder = None
 
         # For --sample_at_first
+        optimizer_eval_fn()
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+        optimizer_train_fn()
         if len(accelerator.trackers) > 0:
             # log empty object to commit the sample images to wandb
             accelerator.log({}, step=0)

From 83e3048cb089bf6726751609da26da751b8383ae Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 6 Oct 2024 21:32:21 +0900
Subject: [PATCH 2/2] load Diffusers format, check schnell/dev

---
 README.md                          |   4 +
 flux_minimal_inference.py          |  15 +--
 flux_train.py                      |  15 ++-
 flux_train_network.py              |  17 ++-
 library/flux_utils.py              | 178 +++++++++++++++++++++++++++--
 tools/convert_diffusers_to_flux.py |  78 +------------
 6 files changed, 196 insertions(+), 111 deletions(-)

diff --git a/README.md b/README.md
index 789fe514a..c567758a5 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Oct 6, 2024:
+- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is now automatically detected as dev or schnell, so schnell models are loaded correctly. Note that LoRA training and fine-tuning with schnell models are unverified.
+- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format, in addition to BFL format (a single *.safetensors file). Specify the full path to the parent directory of `transformer` or to `diffusion_pytorch_model-00001-of-00003.safetensors`. Diffusers-format CLIP/T5XXL is not supported, and saving is supported only in BFL format.
+
 Sep 26, 2024:
 The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed by about 10% (depending on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated; they still work as before but will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py
index 2f1b9a377..7ab224f1b 100644
--- a/flux_minimal_inference.py
+++ b/flux_minimal_inference.py
@@ -419,9 +419,6 @@ def encode(prpt: str):
     steps = args.steps
     guidance_scale = args.guidance
 
-    name = "schnell" if "schnell" in args.ckpt_path else "dev"  # TODO change this to a more robust way
-    is_schnell = name == "schnell"
-
     def is_fp8(dt):
         return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
 
@@ -455,12 +452,8 @@ def is_fp8(dt):
     # if is_fp8(t5xxl_dtype):
     #     t5xxl = accelerator.prepare(t5xxl)
 
-    t5xxl_max_length = 256 if is_schnell else 512
-    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
-    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
-
     # DiT
-    model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
+    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
     model.eval()
     logger.info(f"Casting model to {flux_dtype}")
     model.to(flux_dtype)  # make sure model is dtype
@@ -469,8 +462,12 @@ def is_fp8(dt):
     # if args.offload:
     #     model = model.to("cpu")
 
+    t5xxl_max_length = 256 if is_schnell else 512
+    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
+    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
+
     # AE
-    ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
+    ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
     ae.eval()
     # if is_fp8(ae_dtype):
     #     ae = accelerator.prepare(ae)
diff --git a/flux_train.py b/flux_train.py
index 81c13e4cc..ecc87c0a8 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -137,6 +137,7 @@ def train(args):
 
     train_dataset_group.verify_bucket_reso_steps(16)  # TODO confirm this is OK
 
+    _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
     if args.debug_dataset:
         if args.cache_text_encoder_outputs:
             strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                 strategy_flux.FluxTextEncoderOutputsCachingStrategy(
@@ -144,9 +145,8 @@ def train(args):
                     args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
                 )
             )
-        name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
         t5xxl_max_token_length = (
-            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512)
+            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
         )
         strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
@@ -177,12 +177,11 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
 
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
@@ -196,7 +195,7 @@ def train(args):
 
     # prepare tokenize strategy
     if args.t5xxl_max_token_length is None:
-        if name == "schnell":
+        if is_schnell:
             t5xxl_max_token_length = 256
         else:
             t5xxl_max_token_length = 512
@@ -258,8 +257,8 @@ def train(args):
     clean_memory_on_device(accelerator.device)
 
     # load FLUX
-    flux = flux_utils.load_flow_model(
-        name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    _, flux = flux_utils.load_flow_model(
+        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
     )
 
     if args.gradient_checkpointing:
@@ -294,7 +293,7 @@ def train(args):
 
     if not cache_latents:
         # load VAE here if not cached
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
         ae.requires_grad_(False)
         ae.eval()
         ae.to(accelerator.device, dtype=weight_dtype)
diff --git a/flux_train_network.py b/flux_train_network.py
index 65b121e7c..5d14bd28e 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -2,7 +2,7 @@
 import copy
 import math
 import random
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from accelerate import Accelerator
@@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
+        self.is_schnell: Optional[bool] = None
 
     def assert_extra_args(self, args, train_dataset_group):
         super().assert_extra_args(args, train_dataset_group)
@@ -57,19 +58,15 @@ def assert_extra_args(self, args, train_dataset_group):
 
         train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
 
-    def get_flux_model_name(self, args):
-        return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
-
     def load_target_model(self, args, weight_dtype, accelerator):
         # currently offload to cpu for some models
-        name = self.get_flux_model_name(args)
 
         # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
         loading_dtype = None if args.fp8_base else weight_dtype
 
         # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
-        model = flux_utils.load_flow_model(
-            name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
+        self.is_schnell, model = flux_utils.load_flow_model(
+            args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
         )
         if args.fp8_base:
             # check dtype of model
@@ -100,7 +97,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
             elif t5xxl.dtype == torch.float8_e4m3fn:
                 logger.info("Loaded fp8 T5XXL model")
 
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
 
         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
 
@@ -142,10 +139,10 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
         return flux_lower
 
     def get_tokenize_strategy(self, args):
-        name = self.get_flux_model_name(args)
+        _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
 
         if args.t5xxl_max_token_length is None:
-            if name == "schnell":
+            if is_schnell:
                 t5xxl_max_token_length = 256
             else:
                 t5xxl_max_token_length = 512
diff --git a/library/flux_utils.py b/library/flux_utils.py
index 7b0a41a8a..713814e28 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -1,9 +1,11 @@
 import json
-from typing import Optional, Union
+import os
+from typing import List, Optional, Tuple, Union
 
 import einops
 import torch
 from safetensors.torch import load_file
+from safetensors import safe_open
 from accelerate import init_empty_weights
 from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
 
@@ -17,6 +19,8 @@
 logger = logging.getLogger(__name__)
 
 MODEL_VERSION_FLUX_V1 = "flux1"
+MODEL_NAME_DEV = "dev"
+MODEL_NAME_SCHNELL = "schnell"
 
 
 # temporary copy from sd3_utils TODO refactor
@@ -39,10 +43,35 @@ def load_safetensors(
     return load_file(path)  # prevent device invalid Error
 
 
+def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
+    # check the state dict: Diffusers or BFL, dev or schnell
+    logger.info("Checking the state dict: Diffusers or BFL, dev or schnell")
+
+    if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
+        ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
+
+    if "00001-of-00003" in ckpt_path:
+        ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
+    else:
+        ckpt_paths = [ckpt_path]
+
+    keys = []
+    for ckpt_path in ckpt_paths:
+        with safe_open(ckpt_path, framework="pt") as f:
+            keys.extend(f.keys())
+
+    is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
+    is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
+    return is_diffusers, is_schnell, ckpt_paths
+
+
 def load_flow_model(
-    name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
-) -> flux_models.Flux:
-    logger.info(f"Building Flux model {name}")
+    ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
+) -> Tuple[bool, flux_models.Flux]:
+    is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
+    name = MODEL_NAME_SCHNELL if is_schnell else MODEL_NAME_DEV
+
+    # build model
+    logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
     with torch.device("meta"):
         model = flux_models.Flux(flux_models.configs[name].params)
         if dtype is not None:
@@ -50,18 +79,28 @@ def load_flow_model(
 
     # load_sft doesn't support torch.device
     logger.info(f"Loading state dict from {ckpt_path}")
-    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
+    sd = {}
+    for ckpt_path in ckpt_paths:
+        sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
+
+    # convert Diffusers to BFL
+    if is_diffusers:
+        logger.info("Converting Diffusers to BFL")
+        sd = convert_diffusers_sd_to_bfl(sd)
+        logger.info("Converted Diffusers to BFL")
+
     info = model.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded Flux: {info}")
-    return model
+    return is_schnell, model
 
 
 def load_ae(
-    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
+    ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
 ) -> flux_models.AutoEncoder:
     logger.info("Building AutoEncoder")
     with torch.device("meta"):
-        ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)
+        # dev and schnell share the same AE params
+        ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
 
     logger.info(f"Loading state dict from {ckpt_path}")
     sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
@@ -246,3 +285,126 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
     """
     x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     return x
+
+
+# region Diffusers
+
+NUM_DOUBLE_BLOCKS = 19
+NUM_SINGLE_BLOCKS = 38
+
+BFL_TO_DIFFUSERS_MAP = {
+    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
+    "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + 
"final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map() + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py index 9d8f7c74b..65ba7321a 100644 --- a/tools/convert_diffusers_to_flux.py +++ b/tools/convert_diffusers_to_flux.py @@ -29,6 +29,7 @@ import torch from tqdm import tqdm +from library import flux_utils from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() @@ -36,65 +37,6 @@ logger = logging.getLogger(__name__) -NUM_DOUBLE_BLOCKS = 19 -NUM_SINGLE_BLOCKS = 38 - -BFL_TO_DIFFUSERS_MAP = { - "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], - "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], - "time_in.out_layer.weight": 
["time_text_embed.timestep_embedder.linear_2.weight"], - "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], - "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], - "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], - "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], - "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], - "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], - "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], - "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], - "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], - "txt_in.weight": ["context_embedder.weight"], - "txt_in.bias": ["context_embedder.bias"], - "img_in.weight": ["x_embedder.weight"], - "img_in.bias": ["x_embedder.bias"], - "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], - "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], - "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], - "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], - "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], - "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], - "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], - "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], - "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], - "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], - "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], - "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], - "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], - "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], - "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], - "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], - "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], - "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], - "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], - "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], - "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], - "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], - "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], - "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], - "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], - "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], - "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], - "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], - "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().linear2.bias": ["proj_out.bias"], - "final_layer.linear.weight": ["proj_out.weight"], - "final_layer.linear.bias": ["proj_out.bias"], - 
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], - "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], -} - def convert(args): # if diffusers_path is folder, get safetensors file @@ -114,23 +56,7 @@ def convert(args): save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None # make reverse map from diffusers map - diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("double_blocks."): - block_prefix = f"transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("single_blocks."): - block_prefix = f"single_transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): - for i, weight in enumerate(weights): - diffusers_to_bfl_map[weight] = (i, key) + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map() # iterate over three safetensors files to reduce memory usage flux_sd = {}