diff --git a/README.md b/README.md
index c567758a5..37fc911f6 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,27 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Oct 12, 2024:
+
+- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
+  - It should work with all training scripts, but this has not been verified yet.
+  - Set up multi-GPU training with `accelerate config`.
+  - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly.
+    ```
+    accelerate launch --rdzv_backend=c10d sdxl_train_network.py ...
+    ```
+  - Multi-GPU training does not pool the memory of the GPUs: even with two 12GB VRAM GPUs, you cannot train a model that requires 24GB of VRAM. Training that fits in 12GB of VRAM simply runs at up to twice the speed.
+
+Oct 11, 2024:
+- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`.
+  - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file).
+  - The learning rate for the copied U-Net part is specified by `--learning_rate`, and the learning rate for the modules added by ControlNet is specified by `--control_net_lr`. The optimal values are still unknown; try around `1e-5` for the U-Net and `1e-4` for ControlNet.
+  - To generate sample images, specify the control image with `--cn path/to/control/image`.
+  - The trained weights are automatically converted and saved in Diffusers format, so they should be usable in ComfyUI.
+- Prompt (caption) weighting during SDXL training is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). Enable it by specifying `--weighted_captions`.
+  - The default is `False`; as before, parentheses are then treated as normal text.
+  - If `--weighted_captions` is specified, use `\` to escape parentheses that should be treated as normal text, e.g. `\(some text:1.4\)`.
+
 Oct 6, 2024:
 - In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified.
 - FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format.
diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md
index a05f87f5f..1bd8e4ae1 100644
--- a/docs/train_lllite_README.md
+++ b/docs/train_lllite_README.md
@@ -185,7 +185,7 @@ for img_file in img_files:
 
 ### Creating a dataset configuration file
 
-You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
+You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
```toml [general] diff --git a/fine_tune.py b/fine_tune.py index 62a545a13..fd63385b3 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -366,22 +366,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - # TODO move to strategy_sd.py - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified diff --git a/gen_img.py b/gen_img.py index 59bcd5b09..421d5c0b9 100644 --- a/gen_img.py +++ b/gen_img.py @@ -43,8 +43,8 @@ ) from einops import rearrange from tqdm import tqdm -from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +from accelerate import init_empty_weights import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -58,6 +58,7 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL @@ -352,8 +353,8 @@ def __init__( self.token_replacements_list.append({}) # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] + self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない self.gradual_latent: GradualLatent = None @@ -542,7 +543,7 @@ def __call__( else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - if self.control_net_lllites: + if self.control_net_lllites or (self.control_nets and self.is_sdxl): # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] @@ -731,7 +732,12 @@ def __call__( num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if not self.is_sdxl: + guided_hints = original_control_net.get_guided_hints( + self.control_nets, num_latent_input, batch_size, clip_guide_images + ) + 
else: + clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1] each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) if self.control_net_lllites: @@ -793,7 +799,7 @@ def __call__( latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + # disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo if self.control_net_lllites: for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): if not enabled or ratio >= 1.0: @@ -802,9 +808,16 @@ def __call__( logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False + if self.control_nets and self.is_sdxl: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + each_control_net_enabled[j] = False # predict the noise residual - if self.control_nets and self.control_net_enabled: + if self.control_nets and self.control_net_enabled and not self.is_sdxl: if regional_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -823,6 +836,31 @@ def __call__( text_embeddings, text_emb_last, ).sample + elif self.control_nets: + input_resi_add_list = [] + mid_add_list = [] + for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled): + if not enbld: + continue + input_resi_add, mid_add = control_net( + latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images + ) + input_resi_add_list.append(input_resi_add) + mid_add_list.append(mid_add) + if len(input_resi_add_list) == 0: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + if len(input_resi_add_list) > 1: + # get mean of input_resi_add_list and mid_add_list + input_resi_add_mean = [] + for i in range(len(input_resi_add_list[0])): + input_resi_add_mean.append( + torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0)) + ) + input_resi_add = input_resi_add_mean + mid_add = torch.mean(torch.stack(mid_add_list), dim=0) + + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) else: @@ -1827,16 +1865,37 @@ def __getattr__(self, item): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] + control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + if not is_sdxl: + for i, model in enumerate(args.control_net_models): + prep_type = 
None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + else: + for i, model_file in enumerate(args.control_net_models): + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + logger.info(f"loading SDXL ControlNet: {model_file}") + from safetensors.torch import load_file + + state_dict = load_file(model_file) - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}") + with init_empty_weights(): + control_net = SdxlControlNet(multiplier=multiplier) + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_nets.append((control_net, ratio)) control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 03b182566..9196eb0f2 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -13,12 +13,20 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util, train_util +from library import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) try: @@ -537,7 +545,7 @@ def __init__( vae: AutoencoderKL, text_encoder: List[CLIPTextModel], tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], scheduler: SchedulerMixin, # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, @@ -594,74 +602,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -792,7 +732,7 @@ def __call__( max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, - controlnet=None, + controlnet: sdxl_original_control_net.SdxlControlNet = None, controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -896,32 +836,24 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] - - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, - ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights + ) + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None unet_dtype = self.unet.dtype dtype = unet_dtype @@ -970,23 +902,23 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) crop_size = torch.zeros_like(orig_size) target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) # make conditionings + text_pool = text_pool.to(device, dtype) if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) # 8. 
Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): @@ -994,22 +926,14 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # FIXME SD1 ControlNet is not working # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 4fad78a1c..0466c1fa5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 000000000..3af45f4db --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier 
= multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. 
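+    Its forward() additionally accepts the residuals produced by SdxlControlNet:
+    input_resi_add is added to the skip connections consumed by the output blocks,
+    and mid_add is added to the output of the middle block.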
+ """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + 
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a89..0aa07d0d6 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,7 @@ from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging @@ -1156,9 +1156,9 @@ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_ti self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_ratio = ds_ratio - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. """ _self = self.delegate @@ -1209,6 +1209,8 @@ def call_module(module, h, emb, context): hs.append(h) h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add for module in _self.output_blocks: # Deep Shrink @@ -1217,7 +1219,11 @@ def call_module(module, h, emb, context): # print("upsample", h.shape, hs[-1].shape) h = resize_like(h, hs[-1]) - h = torch.cat([h, hs.pop()], dim=1) + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) h = call_module(module, h, emb, context) # Deep Shrink: in case of depth 0 diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f009b5779..dc3887c34 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,6 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging setup_logging() @@ -364,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # ) # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: @@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin def sample_images(*args, **kwargs): + from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97ef..2bff4178a 100644 --- a/library/strategy_base.py +++ 
b/library/strategy_base.py @@ -1,6 +1,7 @@ # base class for platform strategies. this file defines the interface for strategies import os +import re from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -22,6 +23,24 @@ class TokenizeStrategy: _strategy = None # strategy instance: actual strategy class + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + @classmethod def set_strategy(cls, strategy): if cls._strategy is not None: @@ -54,7 +73,154 @@ def _load_tokenizer( def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: raise NotImplementedError - def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical 
weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: """ for SD1.5/2.0/SDXL TODO support batch input @@ -62,7 +228,10 @@ def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Option if max_length is None: max_length = tokenizer.model_max_length - 2 - input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids if max_length > tokenizer.model_max_length: input_ids = input_ids.squeeze(0) @@ -101,6 +270,17 @@ def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Option iids_list.append(ids_chunk) input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights return input_ids @@ -127,17 +307,34 @@ def encode_tokens( """ raise NotImplementedError + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. 
+ :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + class TextEncoderOutputsCachingStrategy: _strategy = None # strategy instance: actual strategy class def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check self._is_partial = is_partial + self._is_weighted = is_weighted @classmethod def set_strategy(cls, strategy): @@ -161,6 +358,10 @@ def batch_size(self): def is_partial(self): return self._is_partial + @property + def is_weighted(self): + return self._is_weighted + def get_outputs_npz_path(self, image_abs_path: str) -> str: raise NotImplementedError diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31b..4e7931fdb 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,6 +40,16 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: @@ -58,6 +68,8 @@ def encode_tokens( model_max_length = sd_tokenize_strategy.tokenizer.model_max_length tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + tokens = tokens.to(text_encoder.device) + if self.clip_skip is None: encoder_hidden_states = text_encoder(tokens)[0] else: @@ -93,6 +105,30 @@ def encode_tokens( return [encoder_hidden_states] + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. 
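As a minimal illustration of the weighting applied by `SdTextEncodingStrategy.encode_tokens_with_weights` above (a standalone sketch, not code from this patch): each token's hidden state is scaled by the weight parsed from the caption, with weights of shape `(b, 1, 77)` broadcast against hidden states of shape `(b, 77, 768)`, as noted in the comments.

```python
# Illustrative only: how per-token caption weights scale the CLIP hidden states.
import torch

batch, seq_len, dim = 2, 77, 768
hidden_states = torch.randn(batch, seq_len, dim)  # text encoder output
weights = torch.ones(batch, 1, seq_len)           # 1.0 everywhere = unweighted caption
weights[:, 0, 5:10] = 1.1                         # tokens inside "(...)" get a 1.1 multiplier

# (b, 1, 77) -> (b, 77) -> (b, 77, 1), then broadcast over the embedding dimension
weighted = hidden_states * weights.squeeze(1).unsqueeze(2)
print(weighted.shape)  # torch.Size([2, 77, 768])
```

When `--max_token_length` extends the prompt beyond a single 77-token window, the same scaling is applied per 75-token chunk while skipping the BOS/EOS positions, as in the `else` branch above.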
diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 3eb0ab6f6..6b3e2afa6 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -37,6 +37,22 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), ) + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ] + class SdxlTextEncodingStrategy(TextEncodingStrategy): def __init__(self) -> None: @@ -98,7 +114,10 @@ def _get_hidden_states_sdxl( ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] - max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 input_ids1 = input_ids1.to(text_encoder1.device) @@ -155,7 +174,8 @@ def encode_tokens( """ Args: tokenize_strategy: TokenizeStrategy - models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. 
+ If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required tokens: List of tokens, for text_encoder1 and text_encoder2 """ if len(models) == 2: @@ -172,14 +192,45 @@ def encode_tokens( ) return [hidden_states1, hidden_states2, pool2] + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] + + # apply weights + if weights_list[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [hidden_states1, hidden_states2, pool2] + class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -215,11 +266,19 @@ def cache_batch_outputs( sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy captions = [info.caption for info in infos] - tokens1, tokens2 = tokenize_strategy.tokenize(captions) - with torch.no_grad(): - hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [tokens1, tokens2] - ) + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: hidden_state1 = hidden_state1.float() if hidden_state2.dtype == torch.bfloat16: diff --git a/library/train_util.py b/library/train_util.py index c75d407f0..5f582b643 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,8 +31,10 @@ import subprocess from io import BytesIO import toml +# from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm +from packaging.version import Version import torch from library.device_utils 
import init_ipex, clean_memory_on_device @@ -74,6 +76,7 @@ import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec @@ -911,6 +914,23 @@ def make_buckets(self): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + # # run in parallel + # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) + # with ThreadPoolExecutor(max_workers) as executor: + # futures = [] + # for info in tqdm(self.image_data.values(), desc="loading image sizes"): + # if info.image_size is None: + # def get_and_set_image_size(info): + # info.image_size = self.get_image_size(info.absolute_path) + # futures.append(executor.submit(get_and_set_image_size, info)) + # # consume futures to reduce memory usage and prevent Ctrl-C hang + # if len(futures) >= max_workers: + # for future in futures: + # future.result() + # futures = [] + # for future in futures: + # future.result() + if self.enable_bucket: logger.info("make buckets") else: @@ -1825,7 +1845,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] - for img_path in img_paths: + for img_path in tqdm(img_paths, desc="read caption"): cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( @@ -3587,7 +3607,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + choices=[ + "eager", + "aot_eager", + "inductor", + "aot_ts_nvfuser", + "nvprims_nvfuser", + "cudagraphs", + "ofi", + "fx2trt", + "onnxrt", + "tensort", + "ipex", + "tvm", + ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") @@ -5051,17 +5084,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - ( - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph - ) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None - ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + kwargs_handlers = [ + InitProcessGroupKwargs( + backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None + ) if torch.cuda.device_count() > 1 else None, + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, + 
static_graph=args.ddp_static_graph + ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( @@ -5856,8 +5890,8 @@ def sample_images_common( pipe_class, accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, + epoch: int, + steps: int, device, vae, tokenizer, @@ -5916,11 +5950,7 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) pipeline = pipe_class( text_encoder=text_encoder, @@ -5981,21 +6011,18 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - torch.set_rng_state(rng_state) if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - pipeline, + pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline], save_dir, prompt_dict, epoch, diff --git a/sdxl_train.py b/sdxl_train.py index 7291ddd2f..aeff9c469 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -104,8 +104,8 @@ def train(args): setup_logging(args, reset=True) assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + not args.weighted_captions or not args.cache_text_encoder_outputs + ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -321,7 +321,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) @@ -660,22 +660,24 @@ def optimizer_hook(parameter: torch.Tensor): input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = 
input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] - ) + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + input_ids_list, + weights_list, + ) + ) + else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + [input_ids1, input_ids2], + ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py new file mode 100644 index 000000000..67c8d52c8 --- /dev/null +++ b/sdxl_train_control_net.py @@ -0,0 +1,722 @@ +import argparse +import math +import os +import random +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from accelerate import init_empty_weights +from diffusers import DDPMScheduler +from diffusers.utils.torch_utils import is_compiled_module +from safetensors.torch import load_file +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + 
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + unet.to(accelerator.device) # reduce main memory usage + + # convert U-Net to Controlled U-Net + logger.info("convert U-Net to Controlled U-Net") + unet_sd = unet.state_dict() + with init_empty_weights(): + unet = SdxlControlledUNet() + unet.load_state_dict(unet_sd, strict=True, assign=True) + del unet_sd + + # make control net + logger.info("make ControlNet") + if args.controlnet_model_path: + with init_empty_weights(): + control_net = SdxlControlNet() + + logger.info(f"load ControlNet from {args.controlnet_model_path}") + filename = args.controlnet_model_path + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + info = control_net.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"ControlNet loaded from {filename}: {info}") + else: + control_net = SdxlControlNet() + + logger.info("initialize ControlNet from U-Net") + info = control_net.init_from_unet(unet) + logger.info(f"ControlNet initialized from U-Net: {info}") + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with 
accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + + accelerator.wait_for_everyone() + + # モデルに xformers とか memory efficient attention を組み込む + # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if args.xformers: + unet.set_use_memory_efficient_attention(True, False) + control_net.set_use_memory_efficient_attention(True, False) + elif args.sdpa: + unet.set_use_sdpa(True) + control_net.set_use_sdpa(True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + control_net.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = [] + ctrlnet_params = [] + unet_params = [] + for name, param in control_net.named_parameters(): + if name.startswith("controlnet_"): + ctrlnet_params.append(param) + else: + unet_params.append(param) + trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) + trainable_params.append({"params": unet_params, "lr": args.learning_rate}) + all_params = ctrlnet_params + unet_params + + logger.info(f"trainable params count: {len(all_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + control_net.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + control_net.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + control_net, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + unet.eval() + control_net.train() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + ("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name), + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + state_dict = model.state_dict() + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = 
v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file, sai_metadata) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + control_net.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(control_net): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] + with torch.no_grad(): + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2] + ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + controlnet_image = 
batch["conditioning_images"].to(dtype=weight_dtype) + + # '-1 to +1' to '0 to 1' + controlnet_image = (controlnet_image + 1) / 2 + + with accelerator.autocast(): + input_resi_add, mid_add = control_net( + noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image + ) + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = control_net.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, unwrap_model(control_net)) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if len(accelerator.trackers) > 0: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, unwrap_model(control_net)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # end of epoch + + if is_main_process: + control_net = unwrap_model(control_net) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + save_model(ckpt_name, control_net, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + # train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + # train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--controlnet_model_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--control_net_lr", + type=float, + default=1e-4, + help="learning rate for controlnet modules / controlnetモジュールの学習率", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4d6e3f184..20e32155c 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -79,7 +79,9 @@ def get_models_for_text_encoding(self, args, accelerator, text_encoders): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + ) else: return None diff --git a/train_controlnet.py b/train_controlnet.py index c2945b083..8c7882c8f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -254,6 +254,7 @@ def __contains__(self, name): accelerator.wait_for_everyone() if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() controlnet.enable_gradient_checkpointing() # 学習に必要なクラスを準備する @@ -304,6 +305,20 @@ def __contains__(self, name): controlnet, optimizer, train_dataloader, lr_scheduler ) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + unet.requires_grad_(False) 
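+    # NOTE: with --fused_backward_pass, the per-parameter hooks registered above run as soon
+    # as each gradient finishes accumulating: they clip that single tensor (when max_grad_norm
+    # is set), apply the patched Adafactor update via optimizer.step_param(), and then set the
+    # grad to None so it is freed immediately instead of being held for a whole-model
+    # optimizer.step(). With this flag enabled, the training loop below only calls
+    # lr_scheduler.step(); optimizer.step()/zero_grad() happen inside the hook.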
text_encoder.requires_grad_(False) unet.to(accelerator.device) @@ -497,13 +512,17 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/train_db.py b/train_db.py index a5d520b12..e49a7e70f 100644 --- a/train_db.py +++ b/train_db.py @@ -356,21 +356,17 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified diff --git a/train_network.py b/train_network.py index f0d397b9e..e48e6a070 100644 --- a/train_network.py +++ b/train_network.py @@ -1123,14 +1123,21 @@ def remove_model(old_ckpt_name): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # SD only - encoded_text_encoder_conds = get_weighted_text_embeddings( - tokenizers[0], - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + # # SD only + # encoded_text_encoder_conds = get_weighted_text_embeddings( + # tokenizers[0], + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] @@ -1139,8 +1146,8 @@ def remove_model(old_ckpt_name): self.get_models_for_text_encoding(args, accelerator, 
text_encoders), input_ids, ) - if args.full_fp16: - encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0:
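For reference, here is a minimal sketch of how the new weighted-caption strategy calls shown above fit together outside the training scripts, for the SD1.x case. The `tokenize_with_weights()` / `encode_tokens_with_weights()` calls mirror the patched code in `train_db.py` and `train_network.py`; the strategy constructor arguments and the model id are assumptions and may not match the real signatures exactly.

```python
# Illustrative sketch only: constructor arguments of SdTokenizeStrategy /
# SdTextEncodingStrategy are assumed here; the strategy method calls follow the patch above.
import torch
from transformers import CLIPTextModel
from library import strategy_sd

device = "cuda" if torch.cuda.is_available() else "cpu"

# SD1.x text encoder (example model id)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()

tokenize_strategy = strategy_sd.SdTokenizeStrategy(v2=False, max_length=225)  # assumed args
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(clip_skip=None)   # assumed args

captions = ["a photo of a cat, (sharp focus:1.4), [blurry background]"]

# tokenize_with_weights() parses the emphasis syntax in each caption and returns
# token id chunks together with matching per-token weights.
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)

# encode_tokens_with_weights() encodes the tokens and applies the weights to the
# text encoder hidden states, replacing the old get_weighted_text_embeddings() helper.
with torch.no_grad():
    encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
        tokenize_strategy, [text_encoder], input_ids_list, weights_list
    )[0]

print(encoder_hidden_states.shape)
```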