From ccfaa001e74f80798e528b4b3ea6ef811017c07b Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 20:21:28 +0900 Subject: [PATCH 01/11] add flux controlnet base module --- flux_train_control_net.py | 573 ++++++++++++++++++++++++++++++++++++++ flux_train_network.py | 5 +- library/flux_models.py | 257 ++++++++++++++++- library/flux_utils.py | 8 + 4 files changed, 841 insertions(+), 2 deletions(-) create mode 100644 flux_train_control_net.py diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 000000000..704c4d32e --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,573 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." 
+ " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = 
strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = 
tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + + """ + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def prepare_block_swap_before_forward(self): + pass + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) + """ + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def 
encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + 
# unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states 
= hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--split_mode", + action="store_true", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." 
+ " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32e..0feb9b011 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,7 +125,10 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + controlnet = flux_utils.load_controlnet() + controlnet.train() + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) diff --git a/library/flux_models.py b/library/flux_models.py index fa3c7ad2b..a3bd19743 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1013,6 +1013,8 @@ def forward( txt_ids: Tensor, timesteps: Tensor, y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: @@ -1031,18 +1033,29 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) if not self.blocks_to_swap: - for block in self.double_blocks: + for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1052,6 +1065,8 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1066,6 +1081,246 @@ def forward( return img +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. 
+ """ + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(0) # TMP + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks_for_double = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. 
CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block_idx, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, 
txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples + + """ class FluxUpper(nn.Module): "" diff --git a/library/flux_utils.py b/library/flux_utils.py index f3093615d..678efbc8a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,6 +153,14 @@ def load_ae( return ae +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + + def load_clip_l( ckpt_path: Optional[str], dtype: torch.dtype, From 42f6edf3a886287b99770bc7a8c0bafd3fa03f39 Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 23:48:51 +0900 Subject: [PATCH 02/11] fix for adding controlnet --- flux_train_control_net.py | 1270 +++++++++++++++++++++-------------- flux_train_network.py | 3 - library/flux_train_utils.py | 32 +- library/flux_utils.py | 11 +- 4 files changed, 820 insertions(+), 496 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 704c4d32e..8a7be75f2 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -1,563 +1,860 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. 
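# In broad strokes, the swap loop looks like this (a simplified sketch of the pattern used in
# library.flux_models with library.custom_offloading_utils.ModelOffloader, not the exact code):
#
#   for block_idx, block in enumerate(blocks):
#       offloader.wait_for_block(block_idx)               # ensure this block's weights are on the GPU
#       out = block(out, ...)                             # run the block
#       offloader.submit_move_blocks(blocks, block_idx)   # hand it back to the offloader (swap to CPU)
#
# so only a subset of the transformer blocks has to be resident on the GPU at any one time.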
+ +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math -import random -from typing import Any, Optional +import os +from multiprocessing import Value +import time +from typing import List, Optional, Tuple, Union +import toml + +from tqdm import tqdm import torch -from accelerate import Accelerator +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util -import train_network -from library.utils import setup_logging +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. 
remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True -class FluxNetworkTrainer(train_network.NetworkTrainer): - def __init__(self): - super().__init__() - self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None - self.is_swapping_blocks: bool = False + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - # sdxl_train_util.verify_sdxl_training_args(args) + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) - if args.fp8_base_unet: - args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None - if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" - ) - args.cache_text_encoder_outputs = True + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
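    # The strategies act as process-wide registries (a simplified sketch; the classes live in
    # library.strategy_base / library.strategy_flux): an instance is registered once with
    # set_strategy(), and later code, including the dataset, fetches that same instance with
    # get_strategy() instead of having it passed in, e.g.
    #
    #   strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
    #   cache_strategy = strategy_base.LatentsCachingStrategy.get_strategy()  # same instance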
+ if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # prepare CLIP-L/T5XXL training flags - self.train_clip_l = not args.network_train_unet_only - self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return - if args.max_token_length is not None: - logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if args.cache_text_encoder_outputs: assert ( - args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # deprecated split_mode option - if args.split_mode: - if args.blocks_to_swap is not None: - logger.warning( - "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." - " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" - ) - else: - logger.warning( - "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." - " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" - ) - args.blocks_to_swap = 18 # 18 is safe for most cases + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) - train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - def load_target_model(self, args, weight_dtype, accelerator): - # currently offload to cpu for some models + # モデルを読み込む - # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) - loading_dtype = None if args.fp8_base else weight_dtype + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() - # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + 
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask ) - if args.fp8_base: - # check dtype of model - if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") - elif model.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 FLUX model") - else: - logger.info( - "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." - " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" - ) - model.to(torch.float8_e4m3fn) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) - # if args.split_mode: - # model = self.prepare_split_model(model, weight_dtype, accelerator) + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + flux.requires_grad_(False) - self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 - if self.is_swapping_blocks: - # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. 
- logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - model.enable_block_swap(args.blocks_to_swap, accelerator.device) + # load controlnet + controlnet = flux_utils.load_controlnet() + controlnet.requires_grad_(True) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - clip_l.eval() + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) - # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) - if args.fp8_base and not args.fp8_base_unet: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # block swap - # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - t5xxl.eval() - if args.fp8_base and not args.fp8_base_unet: - # check dtype of model - if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") - elif t5xxl.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 T5XXL model") + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. 
+ # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) - def get_tokenize_strategy(self, args): - _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - if args.t5xxl_max_token_length is None: - if is_schnell: - t5xxl_max_token_length = 256 - else: - t5xxl_max_token_length = 512 - else: - t5xxl_max_token_length = args.t5xxl_max_token_length + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] - logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") - return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. 
+ train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler 
= FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) - def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): - return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() - def get_latents_caching_strategy(self, args): - latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) - return latents_caching_strategy + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) - def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 - def post_process_network(self, args, accelerator, network, text_encoders, unet): - # check t5xxl is trained or not - self.train_t5xxl = network.train_t5xxl + for m in training_models: + m.train() - if self.train_t5xxl and args.cache_text_encoder_outputs: - raise ValueError( - "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" - ) + for step, batch in enumerate(train_dataloader): + current_step.value = global_step - def get_models_for_text_encoding(self, args, accelerator, text_encoders): - if args.cache_text_encoder_outputs: - if self.train_clip_l and not self.train_t5xxl: - return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached - else: - return None # no text encoders are needed for encoding because both are cached - else: - return text_encoders # both CLIP-L and T5XXL are needed for encoding + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - def get_text_encoders_train_flags(self, args, text_encoders): - return [self.train_clip_l, self.train_t5xxl] + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. 
images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - def get_text_encoder_outputs_caching_strategy(self, args): - if args.cache_text_encoder_outputs: - # if the text encoders is trained, we need tokenization, so is_partial is True - return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - is_partial=self.train_clip_l or self.train_t5xxl, - apply_t5_attn_mask=args.apply_t5_attn_mask, - ) - else: - return None + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - def cache_text_encoder_outputs_if_needed( - self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype - ): - if args.cache_text_encoder_outputs: - if not args.lowram: - # メモリ消費を減らす - logger.info("move vae and unet to cpu to save memory") - org_vae_device = vae.device - org_unet_device = unet.device - vae.to("cpu") - unet.to("cpu") - clean_memory_on_device(accelerator.device) - - # When TE is not be trained, it will not be prepared so we need to use explicit autocast - logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 - text_encoders[1].to(accelerator.device) - - if text_encoders[1].dtype == torch.float8_e4m3fn: - # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) - else: - # otherwise, we need to convert it to target dtype - text_encoders[1].to(weight_dtype) - - with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) - - # cache sample prompts - if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - - prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs - with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: - if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, text_encoders, tokens_and_masks, 
args.apply_t5_attn_mask - ) - self.sample_prompts_te_outputs = sample_prompts_te_outputs - - accelerator.wait_for_everyone() - - # move back to cpu - if not self.is_train_text_encoder(args): - logger.info("move CLIP-L back to cpu") - text_encoders[0].to("cpu") - logger.info("move t5XXL back to cpu") - text_encoders[1].to("cpu") - clean_memory_on_device(accelerator.device) - - if not args.lowram: - logger.info("move vae and unet back to original device") - vae.to(org_vae_device) - unet.to(org_unet_device) - else: - # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] - # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) - # # get size embeddings - # orig_size = batch["original_sizes_hw"] - # crop_size = batch["crop_top_lefts"] - # target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # # concat embeddings - # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - # return noise_pred + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None - def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): - text_encoders = text_encoder # for compatibility - text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["control_image"].to(accelerator.device), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + 
txt_attention_mask=t5_attn_mask, + ) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - # return - - """ - class FluxUpperLowerWrapper(torch.nn.Module): - def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): - super().__init__() - self.flux_upper = flux_upper - self.flux_lower = flux_lower - self.target_device = device - - def prepare_block_swap_before_forward(self): - pass - - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): - self.flux_lower.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) - self.flux_upper.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe, txt_attention_mask) - - wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) - clean_memory_on_device(accelerator.device) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs - ) - clean_memory_on_device(accelerator.device) - """ - - def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) - return noise_scheduler - - def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images) - - def shift_scale_latents(self, args, latents): - return latents - - def get_noise_pred_and_target( - self, - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet: flux_models.Flux, - network, - weight_dtype, - train_unet, - ): - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype - ) + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance - # ensure guidance_scale in args is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - noisy_model_input.requires_grad_(True) - for t in text_encoder_conds: - if t is not None and t.dtype.is_floating_point: - t.requires_grad_(True) - img_ids.requires_grad_(True) - guidance_vec.requires_grad_(True) - - # Predict the noise residual - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - 
t5_attn_mask = None - - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), ) + optimizer_train_fn() - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) 
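The training step above packs the latents into 2x2 patch tokens, mixes them with noise according to the sampled sigma, and supervises the model with the flow-matching velocity target noise - latents before averaging the weighted loss. A minimal sketch of that math, with illustrative names and shapes (16 latent channels are assumed for FLUX, and the dummy model prediction stands in for the transformer forward):

import torch
from einops import rearrange


def pack_latents(x: torch.Tensor) -> torch.Tensor:
    # (b, c, h*2, w*2) -> (b, h*w, c*4): each 2x2 patch becomes one token
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


def unpack_latents(x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h, w=w, ph=2, pw=2)


def flow_matching_example() -> torch.Tensor:
    latents = torch.randn(2, 16, 64, 64)      # assumed FLUX latent shape
    sigmas = torch.rand(2, 1, 1, 1)           # one timestep per sample, broadcastable
    noise = torch.randn_like(latents)

    noisy = sigmas * noise + (1.0 - sigmas) * latents  # linear noising path
    target = noise - latents                           # velocity target (differs from SD3)

    packed = pack_latents(noisy)              # model input: (2, 32*32, 64)
    model_pred = torch.zeros_like(packed)     # placeholder for the transformer output
    pred = unpack_latents(model_pred, 32, 32)

    loss = ((pred.float() - target.float()) ** 2).mean(dim=[1, 2, 3])  # per-sample MSE
    return loss.mean()
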
- """ - - return model_pred - - model_pred = call_dit( - img=packed_noisy_model_input, - img_ids=img_ids, - t5_out=t5_out, - txt_ids=txt_ids, - l_pooled=l_pooled, - timesteps=timesteps, - guidance_vec=guidance_vec, - t5_attn_mask=t5_attn_mask, - ) + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(): - model_pred_prior = call_dit( - img=packed_noisy_model_input[diff_output_pr_indices], - img_ids=img_ids[diff_output_pr_indices], - t5_out=t5_out[diff_output_pr_indices], - txt_ids=txt_ids[diff_output_pr_indices], - l_pooled=l_pooled[diff_output_pr_indices], - timesteps=timesteps[diff_output_pr_indices], - guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, - t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + accelerator.log(logs, step=global_step) - model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) - model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( args, - model_pred_prior, - noisy_model_input[diff_output_pr_indices], - sigmas[diff_output_pr_indices] if sigmas is not None else None, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), ) - target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - - return model_pred, target, timesteps, None, weighting - - def post_process_loss(self, loss, args, timesteps, noise_scheduler): - return loss - - def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") - - def update_metadata(self, metadata, args): - metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask - metadata["ss_weighting_scheme"] = args.weighting_scheme - metadata["ss_logit_mean"] = args.logit_mean - metadata["ss_logit_std"] = args.logit_std - metadata["ss_mode_scale"] = args.mode_scale - 
metadata["ss_guidance_scale"] = args.guidance_scale - metadata["ss_timestep_sampling"] = args.timestep_sampling - metadata["ss_sigmoid_scale"] = args.sigmoid_scale - metadata["ss_model_prediction_type"] = args.model_prediction_type - metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift - - def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) - - def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): - if index == 0: # CLIP-L - return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) - else: # T5XXL - text_encoder.encoder.embed_tokens.requires_grad_(True) - - def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): - if index == 0: # CLIP-L - logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") - text_encoder.to(te_weight_dtype) # fp8 - text_encoder.text_model.embeddings.to(dtype=weight_dtype) - else: # T5XXL - - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: - logger.info(f"T5XXL already prepared for fp8") - else: - logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") - text_encoder.to(te_weight_dtype) # fp8 - prepare_fp8(text_encoder, weight_dtype) - def prepare_unet_with_accelerator( - self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module - ) -> torch.nn.Module: - if not self.is_swapping_blocks: - return super().prepare_unet_with_accelerator(args, accelerator, unet) + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) - # if we doesn't swap blocks, we can move the model to device - flux: flux_models.Flux = unet - flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) - accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す - return flux + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: - parser = train_network.setup_parser() + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + 
train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( - "--split_mode", + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", action="store_true", - # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." 
- " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser @@ -569,5 +866,4 @@ def setup_parser() -> argparse.ArgumentParser: train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) - trainer = FluxNetworkTrainer() - trainer.train(args) + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 0feb9b011..6668012e4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,9 +125,6 @@ def load_target_model(self, args, weight_dtype, accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - controlnet = flux_utils.load_controlnet() - controlnet.train() - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d90644a25..cc3bcb0ec 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,6 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, + controlnet=None ): if steps == 0: if not args.sample_at_first: @@ -67,6 +68,8 @@ def sample_images( flux = accelerator.unwrap_model(flux) if text_encoders is not None: text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -98,6 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -121,6 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) torch.set_rng_state(rng_state) @@ -142,6 +147,7 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ): assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") @@ -150,7 +156,7 @@ def sample_image_inference( height = prompt_dict.get("height", 512) scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") - # controlnet_image = prompt_dict.get("controlnet_image") + controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) @@ -169,6 +175,9 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 @@ -224,7 +233,7 @@ def sample_image_inference( t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, 
controlnet=controlnet, controlnet_img=controlnet_image) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -301,18 +310,37 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: + block_samples, block_single_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_img, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + else: + block_samples = None + block_single_samples = None pred = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, diff --git a/library/flux_utils.py b/library/flux_utils.py index 678efbc8a..7b538d133 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,11 +153,14 @@ def load_ae( return ae -def load_controlnet(name, device, transformer=None): - with torch.device(device): +def load_controlnet(): + # TODO + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - if transformer is not None: - controlnet.load_state_dict(transformer.state_dict(), strict=False) + # if transformer is not None: + # controlnet.load_state_dict(transformer.state_dict(), strict=False) return controlnet From e358b118afbc93f63dbb5ab6d2412ec553ea9cd7 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 16 Nov 2024 14:49:29 +0900 Subject: [PATCH 03/11] fix dataloader --- flux_train_control_net.py | 84 ++++++++++++++++++++------------------- library/flux_models.py | 17 ++++---- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 8a7be75f2..ee4d0ebf3 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -11,31 +11,36 @@ # - Per-block fused optimizer instances import argparse -from concurrent.futures import ThreadPoolExecutor import copy import math import os -from multiprocessing import Value import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value from typing import List, Optional, Tuple, Union -import toml - -from tqdm import tqdm +import toml import torch import torch.nn as nn +from tqdm import tqdm + from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util - -from library.utils import setup_logging, add_logging_arguments +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from 
library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging setup_logging() import logging @@ -46,10 +51,10 @@ # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( - ConfigSanitizer, BlueprintGenerator, + ConfigSanitizer, ) -from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss def train(args): @@ -85,7 +90,6 @@ def train(args): ) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -103,7 +107,7 @@ def train(args): if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "conditioing_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -111,31 +115,17 @@ def train(args): ) ) else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension + ) + } + ] + } blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) @@ -648,12 +638,12 @@ def grad_hook(parameter: torch.Tensor): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_cond=batch["control_image"].to(accelerator.device), + controlnet_img=batch["conditioing_image"].to(accelerator.device), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index a3bd19743..b52ea6f0b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,15 +2,15 @@ # license: Apache-2.0 License -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass import math import os import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import Dict, List, Optional, Union from library import utils -from library.device_utils import init_ipex, 
clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -18,6 +18,7 @@ from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint + from library import custom_offloading_utils # USE_REENTRANT = True @@ -1251,7 +1252,7 @@ def forward( self, img: Tensor, img_ids: Tensor, - controlnet_cond: Tensor, + controlnet_img: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1264,10 +1265,10 @@ def forward( # running on sequences img img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_cond = self.pos_embed_input(controlnet_cond) - img = img + controlnet_cond + controlnet_img = self.input_hint_block(controlnet_img) + controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_img = self.pos_embed_input(controlnet_img) + img = img + controlnet_img vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: From b2660bbe7410d7ffa40906a7a09f84a17139cb46 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 10:24:57 +0000 Subject: [PATCH 04/11] train run --- flux_train_control_net.py | 39 ++++++++++++++++++++++--------------- library/flux_models.py | 30 ++++++++++++++-------------- library/flux_train_utils.py | 2 +- library/flux_utils.py | 2 +- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index ee4d0ebf3..205ff6b6a 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -103,11 +103,11 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioing_data_dir"] + ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -263,10 +263,11 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) + flux.to(accelerator.device) # load controlnet controlnet = flux_utils.load_controlnet() - controlnet.requires_grad_(True) + controlnet.train() if args.gradient_checkpointing: controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) @@ -443,7 +444,8 @@ def train(args): clean_memory_on_device(accelerator.device) - if args.deepspeed: + # if args.deepspeed: + if True: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. 
# pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -612,8 +614,10 @@ def grad_hook(parameter: torch.Tensor): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # if args.full_fp16: + # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # TODO: check + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -629,10 +633,10 @@ def grad_hook(parameter: torch.Tensor): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds @@ -640,10 +644,11 @@ def grad_hook(parameter: torch.Tensor): t5_attn_mask = None with accelerator.autocast(): + print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_img=batch["conditioing_image"].to(accelerator.device), + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -651,6 +656,8 @@ def grad_hook(parameter: torch.Tensor): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) + print("control end") + print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -796,7 +803,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this - train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) @@ -852,12 +859,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) + # parser.add_argument( + # "--conditioning_data_dir", + # type=str, + # default=None, + # help="conditioning data directory / 条件付けデータのディレクトリ", + # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index b52ea6f0b..2fc21db9d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1042,20 +1042,20 @@ def forward( if not self.blocks_to_swap: for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, 
txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] img = torch.cat((txt, img), 1) - for block in self.single_blocks: + for block_idx, block in enumerate(self.single_blocks): img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1066,7 +1066,7 @@ def forward( self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1121,14 +1121,14 @@ def __init__(self, params: FluxParams, controlnet_depth=2): mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, ) - for _ in range(params.depth) + for _ in range(controlnet_depth) ] ) self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TMP + for _ in range(0) # TODO ] ) @@ -1148,7 +1148,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_double.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(controlnet_depth): + for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) @@ -1252,7 +1252,7 @@ def forward( self, img: Tensor, img_ids: Tensor, - controlnet_img: Tensor, + controlnet_cond: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1265,10 +1265,10 @@ def forward( # running on sequences img img = self.img_in(img) - controlnet_img = self.input_hint_block(controlnet_img) - controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_img = self.pos_embed_input(controlnet_img) - img = img + controlnet_img + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: @@ -1283,7 +1283,7 @@ def forward( block_samples = () block_single_samples = () if not self.blocks_to_swap: - for block_idx, block in enumerate(self.double_blocks): + for block in self.double_blocks: img, txt = 
block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) block_samples = block_samples + (img,) @@ -1315,7 +1315,7 @@ def forward( for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) - for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): block_sample = controlnet_block(block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index cc3bcb0ec..d82bde91c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -460,7 +460,7 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents - return noisy_model_input, timesteps, sigmas + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b538d133..4a3817fdb 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -157,7 +157,7 @@ def load_controlnet(): # TODO is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("meta"): + with torch.device("cuda:0"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) # if transformer is not None: # controlnet.load_state_dict(transformer.state_dict(), strict=False) From 35778f021897796410372aed8540547ba317c2a3 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 11:09:05 +0000 Subject: [PATCH 05/11] fix sample_images type --- flux_train_control_net.py | 31 ++++++++++++++----------------- library/flux_train_utils.py | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 205ff6b6a..791900d17 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -444,8 +444,7 @@ def train(args): clean_memory_on_device(accelerator.device) - # if args.deepspeed: - if True: + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. 
# pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -644,7 +643,6 @@ def grad_hook(parameter: torch.Tensor): t5_attn_mask = None with accelerator.autocast(): - print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, @@ -656,8 +654,6 @@ def grad_hook(parameter: torch.Tensor): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - print("control end") - print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -763,18 +759,19 @@ def grad_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() optimizer_eval_fn() - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(flux), - ) + # TODO: save cn models + # if args.save_every_n_epochs is not None: + # if accelerator.is_main_process: + # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + # args, + # True, + # accelerator, + # save_dtype, + # epoch, + # num_train_epochs, + # global_step, + # accelerator.unwrap_model(flux), + # ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d82bde91c..de2ee030a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -235,7 +235,7 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - x = x.float() + # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 4dd4cd6ec8c55fa94b53217181ed9c95e59eed56 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 12:47:01 +0000 Subject: [PATCH 06/11] work cn load and validation --- flux_train_control_net.py | 20 ++++---------------- library/flux_models.py | 6 +++--- library/flux_train_utils.py | 18 ++++++++++++++---- library/flux_utils.py | 25 ++++++++++++++++--------- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 791900d17..cbfac418f 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet() + controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -568,7 +568,7 @@ def grad_hook(parameter: torch.Tensor): # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -718,7 +718,7 @@ def grad_hook(parameter: 
torch.Tensor): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) # 指定ステップごとにモデルを保存 @@ -774,7 +774,7 @@ def grad_hook(parameter: torch.Tensor): # ) flux_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) optimizer_train_fn() @@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - # parser.add_argument( - # "--conditioning_data_dir", - # type=str, - # default=None, - # help="conditioning data directory / 条件付けデータのディレクトリ", - # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index 2fc21db9d..4123b40e5 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1142,11 +1142,11 @@ def __init__(self, params: FluxParams, controlnet_depth=2): self.num_single_blocks = len(self.single_blocks) # add ControlNet blocks - self.controlnet_blocks_for_double = nn.ModuleList([]) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(controlnet_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) @@ -1312,7 +1312,7 @@ def forward( controlnet_block_samples = () controlnet_single_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2ee030a..dbbaba734 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -175,10 +175,6 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") @@ -232,6 +228,12 @@ def sample_image_inference( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = 
torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) @@ -315,6 +317,8 @@ def denoise( ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() @@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet", + type=str, + default=None, + help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + ) parser.add_argument( "--t5xxl_max_token_length", type=int, diff --git a/library/flux_utils.py b/library/flux_utils.py index 4a3817fdb..fb7a30749 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,15 +153,22 @@ def load_ae( return ae -def load_controlnet(): - # TODO - is_schnell = False - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("cuda:0"): - controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - # if transformer is not None: - # controlnet.load_state_dict(transformer.state_dict(), strict=False) - return controlnet +def load_controlnet( + ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet def load_clip_l( From 31ca899b6b5425466c814d0d9e2e4e8bfbf93001 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 13:03:28 +0000 Subject: [PATCH 07/11] fix depth value --- library/flux_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 4123b40e5..328ad481d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module): Transformer model for flow matching on sequences. 
""" - def __init__(self, params: FluxParams, controlnet_depth=2): + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): super().__init__() self.params = params @@ -1128,7 +1128,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TODO + for _ in range(controlnet_single_depth) ] ) @@ -1148,7 +1148,7 @@ def __init__(self, params: FluxParams, controlnet_depth=2): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(0): # TODO + for _ in range(controlnet_single_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) From 0b5229a9550cb921b83d22472c4785a15c42ba90 Mon Sep 17 00:00:00 2001 From: minux302 Date: Thu, 21 Nov 2024 15:55:27 +0000 Subject: [PATCH 08/11] save cn --- flux_train_control_net.py | 34 +++++++++++++++------------------- library/flux_train_utils.py | 1 - 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index cbfac418f..0f38b7094 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -613,9 +613,6 @@ def grad_hook(parameter: torch.Tensor): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - # if args.full_fp16: - # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # TODO: check text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -733,7 +730,7 @@ def grad_hook(parameter: torch.Tensor): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(flux), + accelerator.unwrap_model(controlnet), ) optimizer_train_fn() @@ -759,19 +756,18 @@ def grad_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() optimizer_eval_fn() - # TODO: save cn models - # if args.save_every_n_epochs is not None: - # if accelerator.is_main_process: - # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - # args, - # True, - # accelerator, - # save_dtype, - # epoch, - # num_train_epochs, - # global_step, - # accelerator.unwrap_model(flux), - # ) + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet @@ -791,7 +787,7 @@ def grad_hook(parameter: torch.Tensor): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + 
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) logger.info("model saved.") diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index dbbaba734..5e25c7feb 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -237,7 +237,6 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 575f583fd9cbaf7f7b644a31437ed9094810b99a Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 29 Nov 2024 23:55:52 +0900 Subject: [PATCH 09/11] add README --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index f9c85e3ac..2b1ca3f8c 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Nov 14, 2024: - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) +- [FLUX.1 ControlNet training](#flux1-controlnet-training) - [FLUX.1 OFT training](#flux1-oft-training) - [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) @@ -245,6 +246,22 @@ example: If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. +### FLUX.1 ControlNet training +We have added a new training script for ControlNet training. The script is flux_train_control_net.py. See --help for options. + +Sample command is below. It will work with 80GB VRAM GPUs. +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors +--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 +--optimizer_type adamw8bit --learning_rate 2e-5 +--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml +--output_dir /path/to/output/dir --output_name flux-cn +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed +``` + + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. 
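For reference, the sampling path added in this series preprocesses the ControlNet conditioning image before it is handed to `denoise()` (see the hunk in `library/flux_train_utils.py` above: load, LANCZOS resize, scale to [-1, 1], then HWC to NCHW). Below is a minimal, self-contained sketch of that preprocessing; the standalone helper name `prepare_controlnet_image` and its default dtype/device arguments are illustrative only and are not part of the patch.

```python
# Hedged sketch of the conditioning-image preparation used for sample images.
# Helper name and defaults are assumptions; the actual code lives inline in
# sample_image_inference() in library/flux_train_utils.py.
import numpy as np
import torch
from PIL import Image


def prepare_controlnet_image(path: str, width: int, height: int,
                             dtype: torch.dtype = torch.bfloat16,
                             device: str = "cuda") -> torch.Tensor:
    # load and resize to the (already 16-aligned) sample resolution
    image = Image.open(path).convert("RGB")
    image = image.resize((width, height), Image.LANCZOS)
    # scale uint8 [0, 255] to float [-1, 1], matching (x / 127.5) - 1 in the diff
    tensor = torch.from_numpy((np.array(image) / 127.5) - 1)
    # HWC -> CHW, add a batch dimension: (1, 3, H, W), then cast and move
    return tensor.permute(2, 0, 1).unsqueeze(0).to(dtype).to(device)
```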
From be5860f8e266c5562f123fe9e0cb3febef615290 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:08:21 +0900 Subject: [PATCH 10/11] add schnell option to load_cn --- flux_train_control_net.py | 4 ++-- library/flux_utils.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index a17c811e3..bb27c35ed 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -259,14 +259,14 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: diff --git a/library/flux_utils.py b/library/flux_utils.py index f2759c375..8be1d63ee 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,14 +1,14 @@ -from dataclasses import replace import json import os +from dataclasses import replace from typing import List, Optional, Tuple, Union + import einops import torch - -from safetensors.torch import load_file -from safetensors import safe_open from accelerate import init_empty_weights -from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel from library.utils import setup_logging @@ -154,11 +154,9 @@ def load_ae( def load_controlnet( - ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ): logger.info("Building ControlNet") - # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL with torch.device(device): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) From f40632bac6704886a7640c327d64820f8f017df8 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:15:47 +0900 Subject: [PATCH 11/11] rm abundant arg --- flux_train_network.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 314335366..fa3810e34 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -6,12 +6,21 @@ import torch from accelerate import Accelerator -from library.device_utils import init_ipex, clean_memory_on_device + +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network +from library import ( + flux_models, + flux_train_utils, + flux_utils, + sd3_train_utils, + strategy_base, + strategy_flux, + train_util, +) from library.utils import setup_logging setup_logging() @@ -125,7 +134,7 @@ def load_target_model(self, args, weight_dtype, 
accelerator): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
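To close, a compact sketch of the zero-initialization idea behind `controlnet_blocks` and `controlnet_blocks_for_single` in the `ControlNetFlux` changes above: each controlled block output is passed through an `nn.Linear` whose parameters start at zero, so the ControlNet injects nothing into FLUX until training moves those weights. The class and argument names below are simplified stand-ins under that assumption, not the code from `library/flux_models.py`.

```python
# Hedged sketch of the zero-initialized projection pattern used by ControlNetFlux.
import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    # zero all parameters so the residual added to FLUX starts as a no-op
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class ControlNetProjections(nn.Module):
    def __init__(self, hidden_size: int, double_depth: int = 2, single_depth: int = 0):
        super().__init__()
        # one zero-initialized linear per controlled double/single block
        self.for_double = nn.ModuleList(
            [zero_module(nn.Linear(hidden_size, hidden_size)) for _ in range(double_depth)]
        )
        self.for_single = nn.ModuleList(
            [zero_module(nn.Linear(hidden_size, hidden_size)) for _ in range(single_depth)]
        )

    def forward(self, block_samples):
        # project each collected block sample, as in ControlNetFlux.forward
        return tuple(proj(sample) for sample, proj in zip(block_samples, self.for_double))
```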