From e5268286bf90ddcc53ad1deb31aba857cfa967d5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 15 Jun 2024 22:20:24 +0900 Subject: [PATCH 001/163] add sd3 models and inference script --- library/sd3_models.py | 1796 ++++++++++++++++++++++++++++++++++++++ library/sd3_utils.py | 113 +++ sd3_minimal_inference.py | 347 ++++++++ 3 files changed, 2256 insertions(+) create mode 100644 library/sd3_models.py create mode 100644 library/sd3_utils.py create mode 100644 sd3_minimal_inference.py diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 000000000..294a69b06 --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,1796 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from functools import partial +import math +from typing import Dict, Optional +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTokenizer, T5TokenizerFast + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region tokenizer +class SDTokenizer: + def __init__( + self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None + ): + """ + サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 + Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. + """ + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer("")["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + def tokenize_with_weights(self, text: str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. + The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + """ + ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 + 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ + """ + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(" ") + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self, t5xxl=True): + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() if t5xxl else None + + def tokenize_with_weights(self, text: str): + return ( + self.clip_l.tokenize_with_weights(text), + self.clip_g.tokenize_with_weights(text), + self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, + ) + + +# endregion + +# region mmdit + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: str = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +class SelfAttention(AttentionLinears): + def __init__(self, dim, num_heads=8, mode="xformers"): + super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) + assert mode in MEMORY_LAYOUTS + self.head_dim = dim // num_heads + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x): + q, k, v = self.pre_attention(x) + attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) + return self.post_attention(attn_score) + + +class TransformerBlock(nn.Module): + def __init__(self, context_size, mode="xformers"): + super().__init__() + self.context_size = context_size + self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention(context_size, mode=mode) + self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.mlp = MLP( + in_features=context_size, + hidden_features=context_size * 4, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, context_size, num_layers, mode="xformers"): + super().__init__() + self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return self.norm(x) + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + pre_only=pre_only, + qk_norm=qk_norm, + ) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + ( + scale_msa, + gate_msa, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, ( + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) + else: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + ) = self.adaLN_modulation( + c + ).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + context_processor_layers=None, + context_size=4096, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_processor_layers is not None: + self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + else: + self.context_processor = None + + self.context_embedder = nn.Linear(context_size, self.hidden_size) + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def cropped_pos_embed(self, h, w, device=None): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + + if self.context_processor is not None: + context = self.context_processor(context) + + B, C, H, W = x.shape + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + ( + einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), + default(context, torch.Tensor([]).type_as(x)), + ), + 1, + ) + + for block in self.joint_blocks: + context, x = block(context, x, c) + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_mmdit_sd3_medium_configs(attn_mode: str): + # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, + # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': + # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=192, + patch_size=2, + in_channels=16, + adm_in_channels=2048, + depth=24, + mlp_ratio=4, + qk_norm=None, + num_patches=36864, + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit + + +# endregion + +# region VAE + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + +# endregion + + +# region Text Encoder +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask, mode=self.attn_mode) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), + # "gelu": torch.nn.functional.gelu, + "gelu": lambda: nn.GELU(), +} + + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + # self.mlp = Mlp( + # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device + # ) + self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) + self.mlp.to(device=device, dtype=dtype) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList( + [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] + ) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + + if x.dtype == torch.bfloat16: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = causal_mask.to(dtype=x.dtype) + else: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[ + torch.arange(x.shape[0], device=x.device), + input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), + ] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + out, pooled = self([tokens]) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + device="cpu", + max_length=77, + layer="last", + layer_idx=None, + textmodel_json_config=None, + dtype=None, + model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, + layer_norm_hidden_state=True, + return_projected_pooled=True, + ): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_attn_mode(self, mode): + raise NotImplementedError("This model does not support setting the attention mode") + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer( + tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state + ) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer = "hidden" + layer_idx = -2 + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, + layer_norm_hidden_state=False, + ) + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"end": 1, "pad": 0}, + model_class=T5, + ) + + def set_attn_mode(self, mode): + t5: T5 = self.transformer + for t5block in t5.encoder.block: + t5block: T5Block + t5layer: T5LayerSelfAttention = t5block.layer[0] + t5SaSa: T5Attention = t5layer.SelfAttention + t5SaSa.set_attn_mode(mode) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + self.attn_mode = "xformers" # TODO 何とかする + + def set_attn_mode(self, mode): + self.attn_mode = mode + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList( + [ + T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) + for i in range(num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + # print(i, x.mean(), x.std()) + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + # print(x.mean(), x.std()) + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + # print(x.mean(), x.std()) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack( + self.num_layers, + config_dict["d_model"], + config_dict["d_model"], + config_dict["d_ff"], + config_dict["num_heads"], + config_dict["vocab_size"], + dtype, + device, + ) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + +def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + } + with torch.no_grad(): + clip_l = SDClipModel( + layer="hidden", + layer_idx=-2, + device=device, + dtype=dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=CLIPL_CONFIG, + ) + if state_dict is not None: + # update state_dict if provided to include logit_scale and text_projection.weight avoid errors + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_l.logit_scale + if "transformer.text_projection.weight" not in state_dict: + state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight + return clip_l + + +def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + } + with torch.no_grad(): + clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_g.logit_scale + return clip_g + + +def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: + T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} + with torch.no_grad(): + t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = t5.logit_scale + if "transformer.shared.weight" in state_dict: + state_dict.pop("transformer.shared.weight") + return t5 + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 000000000..6f8c361fd --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,113 @@ +import math +from typing import Dict +import torch + +from library import sd3_models + + +def get_cond( + prompt: str, + tokenizer: sd3_models.SD3Tokenizer, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: sd3_models.T5XXLModel, +): + l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + l_out, l_pooled = clip_l.encode_token_weights(l_tokens) + g_out, g_pooled = clip_g.encode_token_weights(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + + if t5_tokens is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device) + else: + t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + t5_out = t5_out.to(lg_out.dtype) + + return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + + +# used if other sd3 models is available +r""" +def get_sd3_configs(state_dict: Dict): + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + # prefix = "model.diffusion_model." + prefix = "" + + patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] + depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[prefix + "pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[prefix + "context_embedder.weight"].shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, + } + return { + "patch_size": patch_size, + "depth": depth, + "num_patches": num_patches, + "pos_embed_max_size": pos_embed_max_size, + "adm_in_channels": adm_in_channels, + "context_embedder": context_embedder_config, + } + + +def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): + "" + Doesn't load state dict. + "" + sd3_configs = get_sd3_configs(state_dict) + + mmdit = sd3_models.MMDiT( + input_size=None, + pos_embed_max_size=sd3_configs["pos_embed_max_size"], + patch_size=sd3_configs["patch_size"], + in_channels=16, + adm_in_channels=sd3_configs["adm_in_channels"], + depth=sd3_configs["depth"], + mlp_ratio=4, + qk_norm=None, + num_patches=sd3_configs["num_patches"], + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit +""" + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py new file mode 100644 index 000000000..e14f784d4 --- /dev/null +++ b/sd3_minimal_inference.py @@ -0,0 +1,347 @@ +# Minimum Inference Code for SD3 + +import argparse +import datetime +import math +import os +import random +from typing import Optional, Tuple +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image + +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils + + +def get_noise(seed, latent): + generator = torch.manual_seed(seed) + return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) + + +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + initial_latent: Optional[torch.Tensor], + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + if initial_latent is None: + latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + else: + latent = initial_latent + + latent = latent.to(dtype).to(device) + + noise = get_noise(seed, latent).to(device) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + + sigmas = get_sigmas(model_sampling, steps).to(device) + # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i + + # conditioning = fix_cond(conditioning) + # neg_cond = fix_cond(neg_cond) + # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale} + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + latent = x + scale_factor = 1.5305 + shift_factor = 0.0609 + # def process_out(self, latent): + # return (latent / self.scale_factor) + self.shift_factor + latent = (latent / scale_factor) + shift_factor + return latent + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--do_not_use_t5xxl", action="store_true") + parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + # parser.add_argument( + # "--lora_weights", + # type=str, + # nargs="*", + # default=[], + # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", + # ) + # parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + # TODO test with separated safetenors files for each model + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + state_dict = load_file(args.ckpt_path) + + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_g from {args.clip_g}...") + clip_g_sd = load_file(args.clip_g) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_l from {args.clip_l}...") + clip_l_sd = load_file(args.clip_l) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + if not args.do_not_use_t5xxl: + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info("but not used") + for key in list(state_dict.keys()): + if key.startswith("text_encoders.t5xxl."): + state_dict.pop(key) + t5xxl_sd = None + elif args.t5xxl: + assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" + logger.info(f"Lodaing t5xxl from {args.t5xxl}...") + t5xxl_sd = load_file(args.t5xxl) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + logger.info("t5xxl is not used") + t5xxl_sd = None + + use_t5xxl = t5xxl_sd is not None + + # MMDiT and VAE + vae_sd = {} + vae_prefix = "first_stage_model." + mmdit_prefix = "model.diffusion_model." + for k, v in list(state_dict.items()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + elif k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load tokenizers + logger.info("Loading tokenizers...") + tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + + # load models + # logger.info("Create MMDiT from SD3 checkpoint...") + # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) + logger.info("Create MMDiT") + mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(state_dict) + logger.info(f"Loaded MMDiT: {info}") + + logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") + mmdit.to(device, dtype=sd3_dtype) + mmdit.eval() + + # load VAE + logger.info("Create VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + logger.info(f"Move VAE to {device} and {sd3_dtype}...") + vae.to(device, dtype=sd3_dtype) + vae.eval() + + # load text encoders + logger.info("Create clip_l") + clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) + + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded clip_l: {info}") + + logger.info(f"Move clip_l to {device} and {sd3_dtype}...") + clip_l.to(device, dtype=sd3_dtype) + clip_l.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_l.set_attn_mode(args.attn_mode) + + logger.info("Create clip_g") + clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) + + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded clip_g: {info}") + + logger.info(f"Move clip_g to {device} and {sd3_dtype}...") + clip_g.to(device, dtype=sd3_dtype) + clip_g.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_g.set_attn_mode(args.attn_mode) + + if use_t5xxl: + logger.info("Create t5xxl") + t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded t5xxl: {info}") + + logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") + t5xxl.to(device, dtype=sd3_dtype) + # t5xxl.to("cpu", dtype=torch.float32) # run on CPU + t5xxl.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + t5xxl.set_attn_mode(args.attn_mode) + else: + t5xxl = None + + # prepare embeddings + logger.info("Encoding prompts...") + # embeds, pooled_embed + cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + + # generate image + logger.info("Generating image...") + latent_sampled = do_sample( + target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device + ) + + # latent to image + with torch.no_grad(): + image = vae.decode(latent_sampled) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") From d53ea22b2a8366e6bc9f14aaeec057cd817f60d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jun 2024 23:38:20 +0900 Subject: [PATCH 002/163] sd3 training --- README.md | 25 + library/sai_model_spec.py | 20 +- library/sd3_models.py | 102 ++++- library/sd3_train_utils.py | 544 ++++++++++++++++++++++ library/sd3_utils.py | 211 ++++++++- library/train_util.py | 137 +++++- sd3_minimal_inference.py | 7 +- sd3_train.py | 907 +++++++++++++++++++++++++++++++++++++ 8 files changed, 1909 insertions(+), 44 deletions(-) create mode 100644 library/sd3_train_utils.py create mode 100644 sd3_train.py diff --git a/README.md b/README.md index 946df58f3..34aa2bb2f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,30 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## SD3 training + +SD3 training is done with `sd3_train.py`. + +`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. + +`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. + +t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. + +There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. + +```toml +learning_rate = 1e-5 # seems to be too high +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +cache_text_encoder_outputs = true +cache_text_encoder_outputs_to_disk = true +vae_batch_size = 1 +cache_latents = true +cache_latents_to_disk = true +``` + +--- + [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82ec..f7bf644d7 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -6,8 +6,10 @@ from typing import List, Optional, Tuple, Union import safetensors from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) r""" @@ -55,11 +57,14 @@ ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_SD3_M = "stable-diffusion-3-medium" +ARCH_SD3_UNKNOWN = "stable-diffusion-3" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" PRED_TYPE_EPSILON = "epsilon" @@ -113,7 +118,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + sd3: str = None, ): + """ + sd3: only supports "m" + """ # if state_dict is None, hash is not calculated metadata = {} @@ -126,6 +135,11 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE + elif sd3 is not None: + if sd3 == "m": + arch = ARCH_SD3_M + else: + arch = ARCH_SD3_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -142,7 +156,7 @@ def build_metadata( metadata["modelspec.architecture"] = arch if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA @@ -236,7 +250,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): logger.error(f"Internal error: some metadata values are None: {metadata}") - + return metadata @@ -250,7 +264,7 @@ def get_title(metadata: dict) -> Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/sd3_models.py b/library/sd3_models.py index 294a69b06..a4fe400e3 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1,11 +1,13 @@ -# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref # the original code is licensed under the MIT License # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! +from ast import Tuple from functools import partial import math -from typing import Dict, Optional +from types import SimpleNamespace +from typing import Dict, List, Optional, Union import einops import numpy as np import torch @@ -106,6 +108,8 @@ def __init__(self, t5xxl=True): self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) self.clip_g = SDXLClipGTokenizer(clip_tokenizer) self.t5xxl = T5XXLTokenizer() if t5xxl else None + # t5xxl has 99999999 max length, clip has 77 + self.model_max_length = self.clip_l.max_length # 77 def tokenize_with_weights(self, text: str): return ( @@ -870,6 +874,10 @@ def __init__( self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) # self.initialize_weights() + @property + def model_type(self): + return "m" # only support medium + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: @@ -1013,6 +1021,10 @@ def create_mmdit_sd3_medium_configs(attn_mode: str): # endregion # region VAE +# TODO support xformers + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): @@ -1222,6 +1234,14 @@ def __init__(self, dtype=torch.float32, device=None): self.encoder = VAEEncoder(dtype=dtype, device=device) self.decoder = VAEDecoder(dtype=dtype, device=device) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) @@ -1234,6 +1254,43 @@ def encode(self, image): std = torch.exp(0.5 * logvar) return mean + std * torch.randn_like(mean) + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +class VAEOutput: + def __init__(self, latent): + self.latent = latent + + @property + def latent_dist(self): + return self + + def sample(self): + return self.latent + + +class VAEWrapper: + def __init__(self, vae): + self.vae = vae + + @property + def device(self): + return self.vae.device + + @property + def dtype(self): + return self.vae.dtype + + # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + def encode(self, image): + return VAEOutput(self.vae.encode(image)) + # endregion @@ -1370,15 +1427,39 @@ def forward(self, *args, **kwargs): class ClipTokenWeightEncoder: - def encode_token_weights(self, token_weight_pairs): - tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - out, pooled = self([tokens]) - if pooled is not None: - first_pooled = pooled[0:1].cpu() + # def encode_token_weights(self, token_weight_pairs): + # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + # out, pooled = self([tokens]) + # if pooled is not None: + # first_pooled = pooled[0:1] + # else: + # first_pooled = pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + # fix to support batched inputs + # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] + def encode_token_weights(self, list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + out, pooled = self(list_of_tokens) + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): @@ -1694,6 +1775,7 @@ def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermed x = self.embed_tokens(input_ids) past_bias = None for i, l in enumerate(self.block): + # uncomment to debug layerwise output: fp16 may cause issues # print(i, x.mean(), x.std()) x, past_bias = l(x, past_bias) if i == intermediate_output: diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 000000000..4e45871f4 --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,544 @@ +import argparse +import math +import os +from typing import Optional, Tuple + +import torch +from safetensors.torch import save_file + +from library import sd3_models, sd3_utils, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate import init_empty_weights +from tqdm import tqdm + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .sdxl_train_util import match_mixed_precision + + +def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ + sd3_models.MMDiT, + Optional[sd3_models.SDClipModel], + Optional[sd3_models.SDXLClipG], + Optional[sd3_models.T5XXLModel], + sd3_models.SDVAE, +]: + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models( + args.pretrained_model_name_or_path, + args.clip_l, + args.clip_g, + args.t5xxl, + args.vae, + attn_mode, + accelerator.device if args.lowram else "cpu", + weight_dtype, + args.disable_mmap_load_safetensors, + t5xxl_device, + t5xxl_dtype, + ) + + # work on low-ram device + if args.lowram: + if clip_l is not None: + clip_l.to(accelerator.device) + if clip_g is not None: + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + vae.to(accelerator.device) + mmdit.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() + + return mmdit, clip_l, clip_g, t5xxl, vae + + +def save_models( + ckpt_path: str, + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + if clip_l is not None: + update_sd("text_encoders.clip_l.", clip_l.state_dict()) + if clip_g is not None: + update_sd("text_encoders.clip_g.", clip_g.state_dict()) + if t5xxl is not None: + update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + ) + parser.add_argument( + "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +def sample_images(*args, **kwargs): + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 6f8c361fd..c2c914123 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -1,30 +1,226 @@ import math -from typing import Dict +from typing import Dict, Optional, Union import torch +import safetensors +from safetensors.torch import load_file +from accelerate import init_empty_weights + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) from library import sd3_models +# TODO move some of functions to model_util.py +from library import sdxl_model_util + +# region models + + +def load_models( + ckpt_path: str, + clip_l_path: str, + clip_g_path: str, + t5xxl_path: str, + vae_path: str, + attn_mode: str, + device: Union[str, torch.device], + weight_dtype: torch.dtype, + disable_mmap: bool = False, + t5xxl_device: Optional[str] = None, + t5xxl_dtype: Optional[str] = None, +): + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + t5xxl_device = t5xxl_device or device + + logger.info(f"Loading SD3 models from {ckpt_path}...") + state_dict = load_state_dict(ckpt_path) + + # load clip_l + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_state_dict(clip_l_path) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load clip_g + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_state_dict(clip_g_path) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load t5xxl + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + # MMDiT and VAE + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_state_dict(vae_path) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + else: + state_dict.pop(k) # remove other keys + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + logger.info(f"Loaded MMDiT: {info}") + + # load ClipG and ClipL + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + + # load T5XXL + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + + # load VAE + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + return mmdit, clip_l, clip_g, t5xxl, vae + + +# endregion +# region utils + def get_cond( prompt: str, tokenizer: sd3_models.SD3Tokenizer, clip_l: sd3_models.SDClipModel, clip_g: sd3_models.SDXLClipG, - t5xxl: sd3_models.T5XXLModel, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) + + +def get_cond_from_tokens( + l_tokens, + g_tokens, + t5_tokens, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): l_out, l_pooled = clip_l.encode_token_weights(l_tokens) g_out, g_pooled = clip_g.encode_token_weights(g_tokens) lg_out = torch.cat([l_out, g_out], dim=-1) lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if device is not None: + lg_out = lg_out.to(device=device) + l_pooled = l_pooled.to(device=device) + g_pooled = g_pooled.to(device=device) + if dtype is not None: + lg_out = lg_out.to(dtype=dtype) + l_pooled = l_pooled.to(dtype=dtype) + g_pooled = g_pooled.to(dtype=dtype) + # t5xxl may be in another device (eg. cpu) if t5_tokens is None: - t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device) + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) else: - t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None - t5_out = t5_out.to(lg_out.dtype) + t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + if device is not None: + t5_out = t5_out.to(device=device) + if dtype is not None: + t5_out = t5_out.to(dtype=dtype) - return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) # used if other sd3 models is available @@ -111,3 +307,6 @@ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image + + +# endregion diff --git a/library/train_util.py b/library/train_util.py index 4736ff4ff..c67e8737c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -58,7 +58,7 @@ KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions +from library import custom_train_functions, sd3_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -135,6 +135,7 @@ ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" class ImageInfo: @@ -985,7 +986,7 @@ def is_text_encoder_output_cacheable(self): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1006,7 +1007,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix if not is_main_process: # store to info only continue @@ -1040,14 +1041,43 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True ): assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + ) + + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") @@ -1058,13 +1088,14 @@ def cache_text_encoder_outputs( for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only continue if os.path.exists(te_out_npz): + # TODO check varidity of cache here continue image_infos_to_cache.append(info) @@ -1073,18 +1104,23 @@ def cache_text_encoder_outputs( return # prepare tokenizers and text encoders - for text_encoder in text_encoders: + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) # create batch + is_sd3 = len(tokenizers) == 1 batch = [] batches = [] for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= self.batch_size: batches.append(batch) @@ -1095,13 +1131,32 @@ def cache_text_encoder_outputs( # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) def get_image_size(self, image_path): return imagesize.get(image_path) @@ -1332,6 +1387,7 @@ def __getitem__(self, index): captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future + # TODO get_input_ids must support SD3 if self.XTI_layers: token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) else: @@ -2140,10 +2196,10 @@ def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2152,6 +2208,15 @@ def cache_text_encoder_outputs( logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + ) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2585,6 +2650,30 @@ def cache_batch_text_encoder_outputs( info.text_encoder_pool2 = pool2 +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, @@ -2907,6 +2996,7 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, ): timestamp = time.time() @@ -2940,6 +3030,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + sd3=sd3, ) return metadata diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e14f784d4..96e9da4ac 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -320,8 +320,11 @@ def do_sample( # prepare embeddings logger.info("Encoding prompts...") # embeds, pooled_embed - cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + cond = torch.cat([lg_out, t5_out], dim=-2), pooled + + lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py new file mode 100644 index 000000000..0721b2ae4 --- /dev/null +++ b/sd3_train.py @@ -0,0 +1,907 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils + +# , sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions + +# from library.custom_train_functions import ( +# apply_snr_weight, +# prepare_scheduler_for_custom_training, +# scale_v_prediction_loss_like_noise_prediction, +# add_v_prediction_like_loss, +# apply_debiased_estimation, +# apply_masked_loss, +# ) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.train_text_encoder or not args.cache_text_encoder_outputs + ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # if args.block_lr: + # block_lrs = [float(lr) for lr in args.block_lr.split(",")] + # assert ( + # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + # else: + # block_lrs = None + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # load tokenizer + sd3_tokenizer = sd3_models.SD3Tokenizer() + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 + + t5xxl_dtype = weight_dtype + if args.t5xxl_dtype is not None: + if args.t5xxl_dtype == "fp16": + t5xxl_dtype = torch.float16 + elif args.t5xxl_dtype == "bf16": + t5xxl_dtype = torch.bfloat16 + elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + t5xxl_dtype = torch.float32 + else: + raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + + # モデルを読み込む + attn_mode = "xformers" if args.xformers else "torch" + + assert ( + attn_mode == "torch" + ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + ) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + train_mmdit = args.learning_rate != 0 + train_clip_l = False + train_clip_g = False + train_t5xxl = False + + # if args.train_text_encoder: + # # TODO each option for two text encoders? + # accelerator.print("enable text encoder training") + # if args.gradient_checkpointing: + # text_encoder1.gradient_checkpointing_enable() + # text_encoder2.gradient_checkpointing_enable() + # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + # train_clip_l = lr_te1 != 0 + # train_clip_g = lr_te2 != 0 + + # # caching one text encoder output is not supported + # if not train_clip_l: + # text_encoder1.to(weight_dtype) + # if not train_clip_g: + # text_encoder2.to(weight_dtype) + # text_encoder1.requires_grad_(train_clip_l) + # text_encoder2.requires_grad_(train_clip_g) + # text_encoder1.train(train_clip_l) + # text_encoder2.train(train_clip_g) + # else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() + if t5xxl is not None: + t5xxl.to(t5xxl_dtype) + t5xxl.requires_grad_(False) + t5xxl.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs_sd3( + sd3_tokenizer, + (clip_l, clip_g, t5xxl), + (accelerator.device, accelerator.device, t5xxl_device), + None, + (None, None, None), + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + training_models = [] + params_to_optimize = [] + # if train_unet: + training_models.append(mmdit) + # if block_lrs is None: + params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + # else: + # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) + + # if train_clip_l: + # training_models.append(text_encoder1) + # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # if train_clip_g: + # training_models.append(text_encoder2) + # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + + # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g + # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + # if train_clip_l: + # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + mmdit=mmdit, + # mmdie=mmdit if train_mmdit else None, + # text_encoder1=text_encoder1 if train_clip_l else None, + # text_encoder2=text_encoder2 if train_clip_g else None, + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_mmdit: + mmdit = accelerator.prepare(mmdit) + # if train_clip_l: + # text_encoder1 = accelerator.prepare(text_encoder1) + # if train_clip_g: + # text_encoder2 = accelerator.prepare(text_encoder2) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # noise_scheduler = DDPMScheduler( + # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + # ) + + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + # if args.zero_terminal_snr: + # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # # For --sample_at_first + # sd3_train_utils.sample_images( + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # ) + + # following function will be moved to sd3_train_utils + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = sd3_models.SDVAE.process_in(latents) + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + # not cached, get text encoder outputs + # XXX This does not work yet + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + with torch.set_grad_enabled(args.train_text_encoder): + # TODO support weighted captions + # TODO support length > 75 + input_ids_clip_l = input_ids_clip_l.to(accelerator.device) + input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) + + # get text encoder outputs: outputs are concatenated + context, pool = sd3_utils.get_cond_from_tokens( + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + ) + else: + # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + # TODO this reuses SDXL keys, it should be fixed + lg_out = batch["text_encoder_outputs1_list"] + t5_out = batch["text_encoder_outputs2_list"] + pool = batch["text_encoder_pool2_list"] + context = torch.cat([lg_out, t5_out], dim=-2) + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # call model + with accelerator.autocast(): + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # Compute regular loss. TODO simplify this + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # epoch + 1, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + mmdit = accelerator.unwrap_model(mmdit) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sd3_train_utils.save_sd3_model_on_train_end( + args, + save_dtype, + epoch, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + + # TE training is disabled temporarily + + # parser.add_argument( + # "--learning_rate_te1", + # type=float, + # default=None, + # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + # ) + # parser.add_argument( + # "--learning_rate_te2", + # type=float, + # default=None, + # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + # ) + + # parser.add_argument( + # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + # ) + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # parser.add_argument( + # "--no_half_vae", + # action="store_true", + # help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + # ) + # parser.add_argument( + # "--block_lr", + # type=str, + # default=None, + # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + # ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) From 0fe4eafac996fa5139a311aadc86aca28ddc6930 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 24 Jun 2024 23:12:48 +0900 Subject: [PATCH 003/163] fix to use zero for initial latent --- sd3_minimal_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 96e9da4ac..7f5f28cea 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,8 @@ def do_sample( device: str, ): if initial_latent is None: - latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent From 4802e4aaec74429f733fae289e41c5618ebb0e92 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 24 Jun 2024 23:13:14 +0900 Subject: [PATCH 004/163] workaround for long caption ref #1382 --- library/sd3_models.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index a4fe400e3..c19aec6aa 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -56,7 +56,7 @@ def __init__( self.inv_vocab = {v: k for k, v in vocab.items()} self.max_word_length = 8 - def tokenize_with_weights(self, text: str): + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" """ @@ -79,6 +79,14 @@ def tokenize_with_weights(self, text: str): batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + + # truncate to max_length + # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + if truncate_to_max_length and len(batch) > self.max_length: + batch = batch[: self.max_length] + if truncate_length is not None and len(batch) > truncate_length: + batch = batch[:truncate_length] + return [batch] @@ -112,10 +120,15 @@ def __init__(self, t5xxl=True): self.model_max_length = self.clip_l.max_length # 77 def tokenize_with_weights(self, text: str): + # temporary truncate to max_length even for t5xxl return ( self.clip_l.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text), - self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, + ( + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + if self.t5xxl is not None + else None + ), ) From 8f2ba27869e4c5b9225a309aeed275a47d8eed6a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:36:22 +0900 Subject: [PATCH 005/163] support text_encoder_batch_size for caching --- library/sd3_train_utils.py | 7 +++++++ library/train_util.py | 14 ++++++++++---- sd3_train.py | 1 + 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 4e45871f4..70c83c0ba 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -173,6 +173,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", diff --git a/library/train_util.py b/library/train_util.py index c67e8737c..96d32e3bc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1054,7 +1054,7 @@ def cache_text_encoder_outputs( # same as above, but for SD3 def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): return self.cache_text_encoder_outputs_common( [tokenizer], @@ -1065,6 +1065,7 @@ def cache_text_encoder_outputs_sd3( cache_to_disk, is_main_process, TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, ) def cache_text_encoder_outputs_common( @@ -1077,10 +1078,15 @@ def cache_text_encoder_outputs_common( cache_to_disk=False, is_main_process=True, file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1122,7 +1128,7 @@ def cache_text_encoder_outputs_common( l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -2209,12 +2215,12 @@ def cache_text_encoder_outputs( dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs_sd3( - tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) def set_caching_mode(self, caching_mode): diff --git a/sd3_train.py b/sd3_train.py index 0721b2ae4..8216a62b3 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -254,6 +254,7 @@ def train(args): (None, None, None), args.cache_text_encoder_outputs_to_disk, accelerator.is_main_process, + args.text_encoder_batch_size, ) accelerator.wait_for_everyone() From 828a581e2968935c00d22e7e03ca32c1281aa5dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:43:31 +0900 Subject: [PATCH 006/163] fix assertion for experimental impl ref #1389 --- sd3_train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 8216a62b3..ea9a11049 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,9 +60,19 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # training text encoder is not supported + assert ( + not args.train_text_encoder + ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + + # training without text encoder cache is not supported assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + args.cache_text_encoder_outputs + ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] From 381598c8bbd3d4e50ec4327fa27d5d0072ec2a67 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 21:15:02 +0900 Subject: [PATCH 007/163] fix resolution in metadata for sd3 --- library/sai_model_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index f7bf644d7..af073677e 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -216,7 +216,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl: + if sdxl or sd3 is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 From 66cf43547972647389fbd2addb53cff2ab478660 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 27 Jun 2024 13:14:09 +0900 Subject: [PATCH 008/163] re-fix assertion ref #1389 --- sd3_train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index ea9a11049..b6c932c4c 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -64,10 +64,10 @@ def train(args): # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - # training text encoder is not supported - assert ( - not args.train_text_encoder - ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + # # training text encoder is not supported + # assert ( + # not args.train_text_encoder + # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" # training without text encoder cache is not supported assert ( From 19086465e8040c01c38d38eec5c53f966f0dad8b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 29 Jun 2024 17:21:25 +0900 Subject: [PATCH 009/163] Fix fp16 mixed precision, model is in bf16 without full_bf16 --- README.md | 11 +++++++-- library/sd3_train_utils.py | 10 +++++---- library/sd3_utils.py | 46 +++++++++++++++++++++++++++++++++----- sd3_train.py | 9 +++++--- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 34aa2bb2f..3eed636c5 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. +__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). + +`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. + `optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. +t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. +`text_encoder_batch_size` is added experimentally for caching faster. + ```toml -learning_rate = 1e-5 # seems to be too high +learning_rate = 1e-6 # seems to depend on the batch size optimizer_type = "adafactor" optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] cache_text_encoder_outputs = true cache_text_encoder_outputs_to_disk = true vae_batch_size = 1 +text_encoder_batch_size = 4 cache_latents = true cache_latents_to_disk = true ``` diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 70c83c0ba..c8d52e1c8 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -28,14 +28,14 @@ from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ +def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: @@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, args.vae, attn_mode, accelerator.device if args.lowram else "cpu", - weight_dtype, + model_dtype, args.disable_mmap_load_safetensors, + clip_dtype, t5xxl_device, t5xxl_dtype, + vae_dtype, ) - # work on low-ram device + # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: if clip_l is not None: clip_l.to(accelerator.device) diff --git a/library/sd3_utils.py b/library/sd3_utils.py index c2c914123..45b49b04b 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,11 +28,41 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - weight_dtype: torch.dtype, + default_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, - t5xxl_device: Optional[str] = None, - t5xxl_dtype: Optional[str] = None, + clip_dtype: Optional[Union[str, torch.dtype]] = None, + t5xxl_device: Optional[Union[str, torch.device]] = None, + t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, + vae_dtype: Optional[Union[str, torch.dtype]] = None, ): + """ + Load SD3 models from checkpoint files. + + Args: + ckpt_path: Path to the SD3 checkpoint file. + clip_l_path: Path to the clip_l checkpoint file. + clip_g_path: Path to the clip_g checkpoint file. + t5xxl_path: Path to the t5xxl checkpoint file. + vae_path: Path to the VAE checkpoint file. + attn_mode: Attention mode for MMDiT model. + device: Device for MMDiT model. + default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + disable_mmap: Disable memory mapping when loading state dict. + clip_dtype: Dtype for Clip models, or None to use default dtype. + t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. + t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. + vae_dtype: Dtype for VAE model, or None to use default dtype. + + Returns: + Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. + """ + + # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. + # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. + # Therefore, we need clip_dtype and t5xxl_dtype. + + # default_dtype is used for full_fp16/full_bf16 training. + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -43,6 +73,9 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device + clip_dtype = clip_dtype or default_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 + vae_dtype = vae_dtype or default_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -124,7 +157,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL @@ -132,7 +165,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): clip_l = None else: logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) logger.info("Loading state dict...") info = clip_l.load_state_dict(clip_l_sd) logger.info(f"Loaded ClipL: {info}") @@ -142,7 +175,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): clip_g = None else: logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) logger.info("Loading state dict...") info = clip_g.load_state_dict(clip_g_sd) logger.info(f"Loaded ClipG: {info}") @@ -165,6 +198,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) return mmdit, clip_l, clip_g, t5xxl, vae diff --git a/sd3_train.py b/sd3_train.py index b6c932c4c..bd30cdc72 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,6 +182,8 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + clip_dtype = weight_dtype # if not args.train_text_encoder else None + # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -189,8 +191,9 @@ def train(args): attn_mode == "torch" ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -868,8 +871,9 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # TE training is disabled temporarily + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # TE training is disabled temporarily # parser.add_argument( # "--learning_rate_te1", # type=float, @@ -886,7 +890,6 @@ def setup_parser() -> argparse.ArgumentParser: # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" # ) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") # parser.add_argument( # "--no_half_vae", # action="store_true", From ea18d5ba6d856995d5c44be4b449b63ac66fe5db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 29 Jun 2024 17:45:50 +0900 Subject: [PATCH 010/163] Fix to work full_bf16 and full_fp16. --- library/sd3_models.py | 8 ++++++++ library/sd3_utils.py | 14 ++++++-------- sd3_train.py | 20 ++++++++++---------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index c19aec6aa..7041420cb 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -891,6 +891,14 @@ def __init__( def model_type(self): return "m" # only support medium + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 45b49b04b..9dc9e7967 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,7 +28,7 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - default_dtype: Optional[Union[str, torch.dtype]] = None, + weight_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, clip_dtype: Optional[Union[str, torch.dtype]] = None, t5xxl_device: Optional[Union[str, torch.device]] = None, @@ -46,7 +46,7 @@ def load_models( vae_path: Path to the VAE checkpoint file. attn_mode: Attention mode for MMDiT model. device: Device for MMDiT model. - default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. disable_mmap: Disable memory mapping when loading state dict. clip_dtype: Dtype for Clip models, or None to use default dtype. t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. @@ -61,8 +61,6 @@ def load_models( # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. # Therefore, we need clip_dtype and t5xxl_dtype. - # default_dtype is used for full_fp16/full_bf16 training. - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -73,9 +71,9 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or default_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 - vae_dtype = vae_dtype or default_dtype or torch.float32 + clip_dtype = clip_dtype or weight_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 + vae_dtype = vae_dtype or weight_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -157,7 +155,7 @@ def load_state_dict(path: str, dvc: Union[str, torch.device] = device): mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL diff --git a/sd3_train.py b/sd3_train.py index bd30cdc72..de763ac6d 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,7 +182,7 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - clip_dtype = weight_dtype # if not args.train_text_encoder else None + clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -193,7 +193,7 @@ def train(args): # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -769,10 +769,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) @@ -807,10 +807,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) From 50e3d6247459c9f59facaef42e03b34cd8d6287d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 19:46:23 +0900 Subject: [PATCH 011/163] fix to work T5XXL with fp16 --- library/sd3_models.py | 144 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 7041420cb..e4c0790d9 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1124,7 +1124,12 @@ def __init__(self, in_channels, dtype=torch.float32, device=None): self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) x = self.conv(x) return x @@ -1263,11 +1268,11 @@ def device(self): def dtype(self): return next(self.parameters()).dtype - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def encode(self, image): hidden = self.encoder(image) mean, logvar = torch.chunk(hidden, 2, dim=1) @@ -1630,10 +1635,25 @@ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) self.variance_epsilon = eps - def forward(self, x): - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - return self.weight.to(device=x.device, dtype=x.dtype) * x + # def forward(self, x): + # variance = x.pow(2).mean(-1, keepdim=True) + # x = x * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight.to(device=x.device, dtype=x.dtype) * x + + # copy from transformers' T5LayerNorm + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states class T5DenseGatedActDense(torch.nn.Module): @@ -1775,7 +1795,27 @@ def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_b def forward(self, x, past_bias=None): x, past_bias = self.layer[0](x, past_bias) + + # copy from transformers' T5Block + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + x = self.layer[-1](x) + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + return x, past_bias @@ -1896,4 +1936,96 @@ def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[st return t5 +""" + # snippet for using the T5 model from transformers + + from transformers import T5EncoderModel, T5Config + import accelerate + import json + + T5_CONFIG_JSON = "" +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +"" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + + # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") + # print(model.config) + # # model(**load_model.config) + + # with accelerate.init_empty_weights(): + model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) + for key in list(state_dict.keys()): + if key.startswith("transformer."): + new_key = key[len("transformer.") :] + state_dict[new_key] = state_dict.pop(key) + + info = model.load_state_dict(state_dict) + print(info) + model.set_attn_mode = lambda x: None + # model.to("cpu") + + _self = model + + def enc(list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + list_of_tokens = np.array(list_of_tokens) + list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) + out = _self(list_of_tokens) + pooled = None + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + return out[0], first_pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + model.encode_token_weights = enc + + return model +""" + # endregion From c9de7c4e9a3d02ab6f18f105c880a9ba88b667ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 19:48:28 +0900 Subject: [PATCH 012/163] WIP: new latents caching --- library/sd3_train_utils.py | 94 +++++++++++++++++++++++- library/train_util.py | 147 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 37 +++++++++- 3 files changed, 270 insertions(+), 8 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c8d52e1c8..9309ee30c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,7 +1,7 @@ import argparse import math import os -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from safetensors.torch import save_file @@ -283,6 +283,98 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = vae + + def get_latents_npz_path(self, absolute_path: str): + return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) + + with torch.no_grad(): + latents = self.vae.encode(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = self.vae.encode(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents) + + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): + if self.cache_to_disk: + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + kwargs = {} + if flipped_latent is not None: + kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + info.latents_npz, + latents=latents.float().cpu().numpy(), + original_size=np.array(original_sizes), + crop_ltrb=np.array(crop_ltrbs), + **kwargs, + ) + else: + info.latents = latent + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + if not train_util.HIGH_VRAM: + clean_memory_on_device(self.vae.device) + + # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index 96d32e3bc..8444827df 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -359,6 +359,30 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra return self.color_aug if use_color_aug else None +class LatentsCachingStrategy: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_latents_npz_path(self, absolute_path: str): + raise NotImplementedError + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + raise NotImplementedError + + def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + class BaseSubset: def __init__( self, @@ -986,6 +1010,69 @@ def is_text_encoder_output_cacheable(self): ] ) + def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + if not is_main_process: # prepare for multi-gpu, only store to info + continue + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # if cache to disk, don't cache latents in non-main process, set to info only + if caching_strategy.cache_to_disk and not is_main_process: + return + + if len(batches) == 0: + logger.info("no latents to cache") + return + + # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded + logger.info("caching latents...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1086,7 +1173,7 @@ def cache_text_encoder_outputs_common( if batch_size is None: batch_size = self.batch_size - + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -2207,6 +2294,11 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(is_main_process, strategy) + def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): @@ -2550,6 +2642,51 @@ def trim_and_resize_if_required( return image, original_size, crop_ltrb +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + def cache_batch_latents( vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: @@ -2661,7 +2798,7 @@ def cache_batch_text_encoder_outputs_sd3( ): # make input_ids for each text encoder l_tokens, g_tokens, t5_tokens = input_ids - + clip_l, clip_g, t5xxl = text_encoders with torch.no_grad(): b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( @@ -2670,8 +2807,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index de763ac6d..c073ec0e2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -204,11 +204,22 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" + + if not args.new_caching: + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + file_suffix="_sd3.npz", + ) + else: + strategy = sd3_train_utils.Sd3LatensCachingStrategy( + vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) + train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -699,6 +710,17 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # debug: NaN check for all inputs + if torch.any(torch.isnan(noisy_model_input)): + accelerator.print("NaN found in noisy_model_input, replacing with zeros") + noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + if torch.any(torch.isnan(context)): + accelerator.print("NaN found in context, replacing with zeros") + context = torch.nan_to_num(context, 0, out=context) + if torch.any(torch.isnan(pool)): + accelerator.print("NaN found in pool, replacing with zeros") + pool = torch.nan_to_num(pool, 0, out=pool) + # call model with accelerator.autocast(): model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) @@ -908,6 +930,13 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) + + parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) return parser From 3ea4fce5e0f3d1a9c2718d77f49c3b304d25e565 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 22:04:43 +0900 Subject: [PATCH 013/163] load models one by one --- library/sd3_train_utils.py | 56 ++++++------ library/sd3_utils.py | 169 +++++++++++++++++++++++++++++++++++++ sd3_train.py | 58 +++++++++---- 3 files changed, 236 insertions(+), 47 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 9309ee30c..98ee66bf8 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,19 +1,17 @@ import argparse import math import os -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from safetensors.torch import save_file +from accelerate import Accelerator from library import sd3_models, sd3_utils, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from accelerate import init_empty_weights -from tqdm import tqdm - # from transformers import CLIPTokenizer # from library import model_util # , sdxl_model_util, train_util, sdxl_original_unet @@ -28,50 +26,48 @@ from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ +def load_target_model( + model_type: str, + args: argparse.Namespace, + state_dict: dict, + accelerator: Accelerator, + attn_mode: str, + model_dtype: Optional[torch.dtype], + device: Optional[torch.device], +) -> Union[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 + loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models( - args.pretrained_model_name_or_path, - args.clip_l, - args.clip_g, - args.t5xxl, - args.vae, - attn_mode, - accelerator.device if args.lowram else "cpu", - model_dtype, - args.disable_mmap_load_safetensors, - clip_dtype, - t5xxl_device, - t5xxl_dtype, - vae_dtype, - ) + if model_type == "mmdit": + model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) + elif model_type == "clip_l": + model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) + elif model_type == "clip_g": + model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) + elif model_type == "t5xxl": + model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) + elif model_type == "vae": + model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) + else: + raise ValueError(f"Unknown model type: {model_type}") # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: - if clip_l is not None: - clip_l.to(accelerator.device) - if clip_g is not None: - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(accelerator.device) - vae.to(accelerator.device) - mmdit.to(accelerator.device) + model = model.to(accelerator.device) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() - return mmdit, clip_l, clip_g, t5xxl, vae + return model def save_models( diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 9dc9e7967..16f80c60d 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -20,6 +20,175 @@ # region models +def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + +def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + state_dict: Dict, + clip_l_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + return clip_l + + +def load_clip_g( + state_dict: Dict, + clip_g_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + return clip_g + + +def load_t5xxl( + state_dict: Dict, + t5xxl_path: Optional[str], + attn_mode: str, + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + + # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device + t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) + t5xxl.to(dtype=dtype) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + return t5xxl + + +def load_vae( + state_dict: Dict, + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) + return vae + + def load_models( ckpt_path: str, clip_l_path: str, diff --git a/sd3_train.py b/sd3_train.py index c073ec0e2..10cc5d57f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -13,12 +13,12 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device - init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -189,18 +189,19 @@ def train(args): assert ( attn_mode == "torch" - ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. - mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. + logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") + device_to_load = accelerator.device if args.lowram else "cpu" + sd3_state_dict = sd3_utils.load_safetensors( + args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors ) - assert clip_l is not None, "clip_l is required / clip_lは必須です" - assert clip_g is not None, "clip_g is required / clip_gは必須です" - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - # 学習を準備する + # load VAE for caching latents + vae: sd3_models.SDVAE = None if cache_latents: + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() @@ -220,15 +221,25 @@ def train(args): vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) - vae.to("cpu") + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. + # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # ) + clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + + t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # 学習を準備する:モデルを適切な状態にする - if args.gradient_checkpointing: - mmdit.enable_gradient_checkpointing() - train_mmdit = args.learning_rate != 0 train_clip_l = False train_clip_g = False train_t5xxl = False @@ -280,17 +291,30 @@ def train(args): accelerator.is_main_process, args.text_encoder_batch_size, ) + + # TODO we can delete text encoders after caching accelerator.wait_for_everyone() + # load MMDIT + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + + train_mmdit = args.learning_rate != 0 + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + if not cache_latents: + # load VAE here if not cached + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) vae.requires_grad_(False) vae.eval() vae.to(accelerator.device, dtype=vae_dtype) - mmdit.requires_grad_(train_mmdit) - if not train_mmdit: - mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared - training_models = [] params_to_optimize = [] # if train_unet: From 9dc7997803d70c718969526352e88908e827f091 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 9 Jul 2024 20:37:00 +0900 Subject: [PATCH 014/163] fix typo --- library/sd3_models.py | 2 +- library/sd3_train_utils.py | 2 +- sd3_train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index e4c0790d9..a1ff1e75a 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1643,7 +1643,7 @@ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): # copy from transformers' T5LayerNorm def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 98ee66bf8..660342108 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -279,7 +279,7 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy): +class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: diff --git a/sd3_train.py b/sd3_train.py index 10cc5d57f..30d994c78 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -217,7 +217,7 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatensCachingStrategy( + strategy = sd3_train_utils.Sd3LatentsCachingStrategy( vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) From 3d402927efb2d396f8f33fe6a1747e43f7a5f0f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 9 Jul 2024 23:15:38 +0900 Subject: [PATCH 015/163] WIP: update new latents caching --- library/sd3_train_utils.py | 49 +++++++++++++++++++++++++------------- library/train_util.py | 39 ++++++++++++++++++++++++++---- sd3_train.py | 15 ++++++++---- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 660342108..245912199 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,4 +1,5 @@ import argparse +import glob import math import os from typing import List, Optional, Tuple, Union @@ -282,12 +283,26 @@ def sample_images(*args, **kwargs): class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = None + + def set_vae(self, vae: sd3_models.SDVAE): self.vae = vae - def get_latents_npz_path(self, absolute_path: str): - return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): if not self.cache_to_disk: @@ -331,24 +346,24 @@ def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) with torch.no_grad(): - latents = self.vae.encode(img_tensor).to("cpu") + latents_tensors = self.vae.encode(img_tensor).to("cpu") if flip_aug: img_tensor = torch.flip(img_tensor, dims=[3]) with torch.no_grad(): flipped_latents = self.vae.encode(img_tensor).to("cpu") else: - flipped_latents = [None] * len(latents) + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] - for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): if self.cache_to_disk: - # save_latents_to_disk( - # info.latents_npz, - # latent, - # info.latents_original_size, - # info.latents_crop_ltrb, - # flipped_latent, - # alpha_mask, - # ) kwargs = {} if flipped_latent is not None: kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() @@ -357,12 +372,12 @@ def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: np.savez( info.latents_npz, latents=latents.float().cpu().numpy(), - original_size=np.array(original_sizes), - crop_ltrb=np.array(crop_ltrbs), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), **kwargs, ) else: - info.latents = latent + info.latents = latents if flip_aug: info.latents_flipped = flipped_latent info.alpha_mask = alpha_mask diff --git a/library/train_util.py b/library/train_util.py index 8444827df..9db226ea8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -360,11 +360,23 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra class LatentsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + @property def cache_to_disk(self): return self._cache_to_disk @@ -373,10 +385,15 @@ def cache_to_disk(self): def batch_size(self): return self._batch_size - def get_latents_npz_path(self, absolute_path: str): + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: raise NotImplementedError - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: raise NotImplementedError def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -1034,7 +1051,7 @@ def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCach # check disk cache exists and size of latents if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) if not is_main_process: # prepare for multi-gpu, only store to info continue @@ -1730,6 +1747,18 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths = glob_images(subset.image_dir, "*") sizes = [None] * len(img_paths) + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from cache files") + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + w, h = strategy.get_image_size_from_image_absolute_path(img_path) + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2807,12 +2836,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): # debug: NaN check if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") - + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index 30d994c78..e2f622e47 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -91,6 +91,15 @@ def train(args): # load tokenizer sd3_tokenizer = sd3_models.SD3Tokenizer() + # prepare caching strategy + if args.new_caching: + latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + else: + latents_caching_strategy = None + train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -217,10 +226,8 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatentsCachingStrategy( - vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check - ) - train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) + latents_caching_strategy.set_vae(vae) + train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) From 6f0e235f2cb9a9829bc12280c29e12c0ae66c88f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:00:45 +0900 Subject: [PATCH 016/163] Fix shift value in SD3 inference. --- sd3_minimal_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 7f5f28cea..ffa0d46de 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,7 @@ def do_sample( device: str, ): if initial_latent is None: - # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent @@ -73,7 +73,7 @@ def do_sample( noise = get_noise(seed, latent).to(device) - model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 sigmas = get_sigmas(model_sampling, steps).to(device) # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i From b8896aad400222c8c4441b217fda0f9bb0807ffd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:01:23 +0900 Subject: [PATCH 017/163] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3eed636c5..5d4f9621d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. -__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). +__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! + +Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). `fp16` and `bf16` are available for mixed precision training. We are not sure which is better. @@ -12,7 +14,7 @@ __Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. +~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. From 082f13658bdbaed872ede6c0a7a75ab1a5f3712d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 12 Jul 2024 21:28:01 +0900 Subject: [PATCH 018/163] reduce peak GPU memory usage before training --- library/sd3_models.py | 2 +- library/train_util.py | 1 + sd3_train.py | 44 +++++++++++++++++++++---------------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index a1ff1e75a..ec8e1bbdd 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -471,7 +471,7 @@ def __init__( num_heads: int = 8, qkv_bias: bool = False, pre_only: bool = False, - qk_norm: str = None, + qk_norm: Optional[str] = None, ): super().__init__() self.num_heads = num_heads diff --git a/library/train_util.py b/library/train_util.py index 9db226ea8..7af0070e1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2410,6 +2410,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) +# TODO update to use CachingStrategy def load_latents_from_disk( npz_path, ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: diff --git a/sd3_train.py b/sd3_train.py index e2f622e47..f34e47124 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -458,6 +458,28 @@ def train(args): # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( args, @@ -482,28 +504,6 @@ def train(args): # text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - clip_l.to("cpu", dtype=torch.float32) - clip_g.to("cpu", dtype=torch.float32) - if t5xxl is not None: - t5xxl.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - # TODO support CPU for text encoders - clip_l.to(accelerator.device) - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(accelerator.device) - - # TODO cache sample prompt's embeddings to free text encoder's memory - if args.cache_text_encoder_outputs: - if not args.save_t5xxl: - t5xxl = None # free memory - clean_memory_on_device(accelerator.device) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. From 87526942a67fd71bb775bc479b0a7449df516dd8 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Fri, 12 Jul 2024 22:56:38 +0800 Subject: [PATCH 019/163] judge image size for using diff interpolation --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..74720fec6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) image_height, image_width = image.shape[0:2] From 2e67978ee243a20f169ce76d7644bb1f9dec9bad Mon Sep 17 00:00:00 2001 From: Millie Date: Thu, 18 Jul 2024 11:52:58 -0700 Subject: [PATCH 020/163] Generate sample images without having CUDA (such as on Macs) --- library/train_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..9b0397d7d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5229,7 +5229,7 @@ def sample_images_common( clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -5263,11 +5263,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5302,8 +5304,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] From 1f16b80e88b1c4f05d49b4fc328d3b9b105ebcbe Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:35:24 +0800 Subject: [PATCH 021/163] Revert "judge image size for using diff interpolation" This reverts commit 87526942a67fd71bb775bc479b0a7449df516dd8. --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 74720fec6..15c23f3cc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ image_height, image_width = image.shape[0:2] From 9ca7a5b6cc99e25820a1aa6d02a779004d73bca0 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:59:11 +0800 Subject: [PATCH 022/163] instead cv2 LANCZOS4 resize to pil resize --- finetune/tag_images_by_wd14_tagger.py | 8 +++++--- library/train_util.py | 11 ++++++----- library/utils.py | 14 +++++++++++++- tools/detect_face_rotate.py | 7 +++++-- tools/resize_images_to_resolution.py | 11 +++++++---- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd61..6f5bdd36b 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..160e3b44b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -71,7 +71,7 @@ import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -2028,9 +2028,7 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2362,7 +2360,10 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = image.shape[0:2] diff --git a/library/utils.py b/library/utils.py index 3037c055d..a219f6cb7 100644 --- a/library/utils.py +++ b/library/utils.py @@ -7,7 +7,9 @@ from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - +import cv2 +from PIL import Image +import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +def pil_resize(image, size, interpolation=Image.LANCZOS): + + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # use Pillow resize + resized_pil = pil_image.resize(size, interpolation) + + # return cv2 image + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + return resized_cv2 # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643edc..d2a4d9cfb 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1d..0f9e00b1e 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2] From 41dee60383a3b88859b80929a2c0d94b12c42068 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jul 2024 13:50:05 +0900 Subject: [PATCH 023/163] Refactor caching mechanism for latents and text encoder outputs, etc. --- README.md | 21 +- fine_tune.py | 54 +++- library/config_util.py | 2 - library/sd3_models.py | 47 +++- library/sd3_train_utils.py | 105 ------- library/sd3_utils.py | 1 + library/sdxl_train_util.py | 2 +- library/strategy_base.py | 328 ++++++++++++++++++++++ library/strategy_sd.py | 139 ++++++++++ library/strategy_sd3.py | 229 ++++++++++++++++ library/strategy_sdxl.py | 247 +++++++++++++++++ library/train_util.py | 451 +++++++++++++++---------------- sd3_minimal_inference.py | 22 +- sd3_train.py | 272 +++++++++++-------- sdxl_train.py | 108 ++++---- sdxl_train_control_net_lllite.py | 99 ++++--- sdxl_train_network.py | 48 +++- sdxl_train_textual_inversion.py | 49 ++-- train_db.py | 67 +++-- train_network.py | 122 ++++++--- train_textual_inversion.py | 118 ++++---- 21 files changed, 1792 insertions(+), 739 deletions(-) create mode 100644 library/strategy_base.py create mode 100644 library/strategy_sd.py create mode 100644 library/strategy_sd3.py create mode 100644 library/strategy_sdxl.py diff --git a/README.md b/README.md index 5d4f9621d..d406fecde 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,16 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. -__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! +__Jul 27, 2024__: +- Latents and text encoder outputs caching mechanism is refactored significantly. + - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. + - With this change, dataset initialization is significantly faster, especially for large datasets. -Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). +- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures. + +- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training. + +--- `fp16` and `bf16` are available for mixed precision training. We are not sure which is better. @@ -14,7 +21,7 @@ Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. +t5xxl works with `fp16` now. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. @@ -32,6 +39,14 @@ cache_latents = true cache_latents_to_disk = true ``` +__2024/7/27:__ + +Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。 + +データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。 + +SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。 + --- [__Change History__](#change-history) is moved to the bottom of the page. diff --git a/fine_tune.py b/fine_tune.py index d865cd2de..c9102f6c0 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -39,6 +39,7 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +import library.strategy_sd as strategy_sd def train(args): @@ -52,7 +53,15 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -81,10 +90,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -165,8 +174,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -192,6 +202,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: text_encoder.eval() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -214,7 +227,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -317,7 +334,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -342,8 +361,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: + # TODO move to strategy_sd.py encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -351,10 +371,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -409,7 +431,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -472,7 +494,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/library/config_util.py b/library/config_util.py index 10b2457f3..f8cdfe60a 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False diff --git a/library/sd3_models.py b/library/sd3_models.py index ec8e1bbdd..28378c73b 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -38,7 +38,7 @@ def __init__( サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. """ - self.tokenizer = tokenizer + self.tokenizer: CLIPTokenizer = tokenizer self.max_length = max_length self.min_length = min_length empty = self.tokenizer("")["input_ids"] @@ -56,6 +56,19 @@ def __init__( self.inv_vocab = {v: k for k, v in vocab.items()} self.max_word_length = 8 + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + """ + Tokenize the text without weights. + """ + if type(text) == str: + text = [text] + batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") + # return tokens["input_ids"] + + pad_token = self.end_token if self.pad_with_end else 0 + for tokens in batch_tokens["input_ids"]: + assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}" + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" @@ -75,13 +88,14 @@ def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate for word in to_tokenize: batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) batch.append((self.end_token, 1.0)) + print(len(batch), self.max_length, self.min_length) if self.pad_to_max_length: batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) # truncate to max_length - # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}") if truncate_to_max_length and len(batch) > self.max_length: batch = batch[: self.max_length] if truncate_length is not None and len(batch) > truncate_length: @@ -110,27 +124,38 @@ def __init__(self, tokenizer): class SD3Tokenizer: - def __init__(self, t5xxl=True): + def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256): + if t5xxl_max_length is None: + t5xxl_max_length = 256 + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") self.t5xxl = T5XXLTokenizer() if t5xxl else None # t5xxl has 99999999 max length, clip has 77 - self.model_max_length = self.clip_l.max_length # 77 + self.t5xxl_max_length = t5xxl_max_length def tokenize_with_weights(self, text: str): - # temporary truncate to max_length even for t5xxl return ( self.clip_l.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text), ( - self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length) if self.t5xxl is not None else None ), ) + def tokenize(self, text: str): + return ( + self.clip_l.tokenize(text), + self.clip_g.tokenize(text), + (self.t5xxl.tokenize(text) if self.t5xxl is not None else None), + ) + # endregion @@ -1474,7 +1499,10 @@ def encode_token_weights(self, list_of_token_weight_pairs): tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] list_of_tokens.append(tokens) else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + if isinstance(list_of_token_weight_pairs[0], torch.Tensor): + list_of_tokens = [list(list_of_token_weight_pairs[0])] + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] out, pooled = self(list_of_tokens) if has_batch: @@ -1614,9 +1642,9 @@ def set_attn_mode(self, mode): ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl ################################################################################################# - +""" class T5XXLTokenizer(SDTokenizer): - """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"" def __init__(self): super().__init__( @@ -1627,6 +1655,7 @@ def __init__(self): max_length=99999999, min_length=77, ) +""" class T5LayerNorm(torch.nn.Module): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 245912199..8f99d9474 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -280,111 +280,6 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): - SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - self.vae = None - - def set_vae(self, vae: sd3_models.SDVAE): - self.vae = vae - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX - ) - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) - - try: - npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop - ) - img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) - - with torch.no_grad(): - latents_tensors = self.vae.encode(img_tensor).to("cpu") - if flip_aug: - img_tensor = torch.flip(img_tensor, dims=[3]) - with torch.no_grad(): - flipped_latents = self.vae.encode(img_tensor).to("cpu") - else: - flipped_latents = [None] * len(latents_tensors) - - # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): - for i in range(len(image_infos)): - info = image_infos[i] - latents = latents_tensors[i] - flipped_latent = flipped_latents[i] - alpha_mask = alpha_masks[i] - original_size = original_sizes[i] - crop_ltrb = crop_ltrbs[i] - - if self.cache_to_disk: - kwargs = {} - if flipped_latent is not None: - kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - info.latents_npz, - latents=latents.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) - else: - info.latents = latents - if flip_aug: - info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - - if not train_util.HIGH_VRAM: - clean_memory_on_device(self.vae.device) - # region Diffusers diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 16f80c60d..5849518fb 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -384,6 +384,7 @@ def get_cond( dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + print(t5_tokens) return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91a..f009b5779 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,7 +327,7 @@ def diffusers_saver(out_dir): ) -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" ) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 000000000..594cca5eb --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,328 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _defualt_is_disk_cached_latents_expected( + self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + if self.cache_to_disk: + self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + npz = np.load(npz_path) + if "latents" not in npz: + raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + + latents = npz["latents"] + original_size = npz["original_size"].tolist() + crop_ltrb = npz["crop_ltrb"].tolist() + flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + ): + kwargs = {} + if flipped_latents_tensor is not None: + kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + npz_path, + latents=latents_tensor.float().cpu().numpy(), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), + **kwargs, + ) diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 000000000..105816145 --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,139 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from library import train_util +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir + ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. + + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + # does not include old npz + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 000000000..42630ab22 --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,229 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + clip_l, clip_g, t5xxl = models + + l_tokens, g_tokens, t5_tokens = tokens + if l_tokens is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_out, l_pooled = clip_l(l_tokens) + g_out, g_pooled = clip_g(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is not None and t5_tokens is not None: + t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + else: + t5_out = None + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + return [lg_out, t5_out, lg_pooled] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "clip_l" not in npz or "clip_g" not in npz: + return False + if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + return False + # t5xxl is optional + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] if "t5_out" in data else None + return [lg_out, t5_out, lg_pooled] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + captions = [info.caption for info in infos] + + clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out is not None and t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + if t5_out is not None: + t5_out = t5_out.cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] if t5_out is not None else None + lg_pooled_i = lg_pooled[i] + + if self.cache_to_disk: + kwargs = {} + if t5_out is not None: + kwargs["t5_out"] = t5_out_i + np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + else: + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for Sd3TokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = Sd3TokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 000000000..a4513336d --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,247 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index + ] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + def _get_hidden_states_sdxl( + self, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: Union[CLIPTextModel, torch.nn.Module], + text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module], + unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None, + ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + input_ids1 = input_ids1.to(text_encoder1.device) + input_ids2 = input_ids2.to(text_encoder2.device) + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2 + pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + return hidden_states1, hidden_states2, pool2 + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Args: + tokenize_strategy: TokenizeStrategy + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + tokens: List of tokens, for text_encoder1 and text_encoder2 + """ + if len(models) == 2: + text_encoder1, text_encoder2 = models + unwrapped_text_encoder2 = None + else: + text_encoder1, text_encoder2, unwrapped_text_encoder2 = models + tokens1, tokens2 = tokens + sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy + tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2 + + hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl( + tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2 + ) + return [hidden_states1, hidden_states2, pool2] + + +class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state1 = data["hidden_state1"] + hidden_state2 = data["hidden_state2"] + pool2 = data["pool2"] + return [hidden_state1, hidden_state2, pool2] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy + captions = [info.caption for info in infos] + + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index 7af0070e1..a747e0478 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import shutil import time from typing import ( + Any, Dict, List, NamedTuple, @@ -34,6 +35,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -81,10 +83,6 @@ # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - HIGH_VRAM = False # checkpointファイル名 @@ -148,18 +146,24 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime @@ -359,47 +363,6 @@ def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarra return self.color_aug if use_color_aug else None -class LatentsCachingStrategy: - _strategy = None # strategy instance: actual strategy class - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - self._cache_to_disk = cache_to_disk - self._batch_size = batch_size - self.skip_disk_cache_validity_check = skip_disk_cache_validity_check - - @classmethod - def set_strategy(cls, strategy): - if cls._strategy is not None: - raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") - cls._strategy = strategy - - @classmethod - def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: - return cls._strategy - - @property - def cache_to_disk(self): - return self._cache_to_disk - - @property - def batch_size(self): - return self._batch_size - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - raise NotImplementedError - - def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: - raise NotImplementedError - - def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool - ) -> bool: - raise NotImplementedError - - def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - raise NotImplementedError - - class BaseSubset: def __init__( self, @@ -639,17 +602,12 @@ def __eq__(self, other) -> bool: class BaseDataset(torch.utils.data.Dataset): def __init__( self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier @@ -670,8 +628,6 @@ def __init__( self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_step: int = 0 @@ -690,6 +646,15 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' + + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() def set_seed(self, seed): self.seed = seed @@ -979,22 +944,6 @@ def make_buckets(self): for batch_index in range(batch_count): self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - self.shuffle_buckets() self._length = len(self.buckets_indices) @@ -1027,12 +976,13 @@ def is_text_encoder_output_cacheable(self): ] ) - def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. """ logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() image_infos = list(self.image_data.values()) # sort by resolution @@ -1088,7 +1038,7 @@ def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCach logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと @@ -1145,6 +1095,56 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. + """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + # if cache to disk, don't cache TE outputs in non-main process + if caching_strategy.cache_to_disk and not is_main_process: + return + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + info.text_encoder_outputs_npz = te_out_npz + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset # to support SD1/2, it needs a flag for v2, but it is postponed @@ -1188,6 +1188,8 @@ def cache_text_encoder_outputs_common( # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + tokenize_strategy = TokenizeStrategy.get_strategy() + if batch_size is None: batch_size = self.batch_size @@ -1229,7 +1231,7 @@ def cache_text_encoder_outputs_common( input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) batch.append((info, input_ids1, input_ids2)) else: - l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= batch_size: @@ -1347,7 +1349,6 @@ def __getitem__(self, index): loss_weights = [] captions = [] input_ids_list = [] - input_ids2_list = [] latents_list = [] alpha_mask_list = [] images = [] @@ -1355,16 +1356,14 @@ def __getitem__(self, index): crop_top_lefts = [] target_sizes_hw = [] flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + text_encoder_outputs_list = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1381,7 +1380,9 @@ def __getitem__(self, index): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + ) if flipped: latents = flipped_latents alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem @@ -1470,75 +1471,67 @@ def __getitem__(self, index): # captionとtext encoder outputを処理する caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) else: - caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) - if not self.token_padding_disabled: # this option might be omitted in future - # TODO get_input_ids must support SD3 - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + if tokenization_required: + caption = self.process_caption(subset, image_info.caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) + + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) + + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + captions.append(caption) - if len(self.tokenizers) > 1: - if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) + def none_or_stack_elements(tensors_list, converter): + # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] + if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: + return None + return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) - - if len(text_encoder_outputs1_list) == 0: - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids - if len(self.tokenizers) > 1: - example["input_ids2"] = self.tokenizer[1]( - captions, padding=True, truncation=True, return_tensors="pt" - ).input_ids - else: - example["input_ids2"] = None - else: - example["input_ids"] = torch.stack(input_ids_list) - example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None - example["text_encoder_outputs1_list"] = None - example["text_encoder_outputs2_list"] = None - example["text_encoder_pool2_list"] = None - else: - example["input_ids"] = None - example["input_ids2"] = None - # # for assertion - # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) - # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) - example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) - example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) + example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) # if one of alpha_masks is not None, we need to replace None with ones none_or_not = [x is None for x in alpha_mask_list] @@ -1652,8 +1645,6 @@ def __init__( self, subsets: Sequence[DreamBoothSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1664,7 +1655,7 @@ def __init__( prior_loss_weight: float, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1750,10 +1741,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() if strategy is not None: - logger.info("get image size from cache files") + logger.info("get image size from name of cache files") size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): - w, h = strategy.get_image_size_from_image_absolute_path(img_path) + w, h = strategy.get_image_size_from_disk_cache_path(img_path) if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1 @@ -1886,8 +1877,6 @@ def __init__( self, subsets: Sequence[FineTuningSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1897,7 +1886,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -2111,8 +2100,6 @@ def __init__( self, subsets: Sequence[ControlNetSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -2122,7 +2109,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2160,8 +2147,6 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, batch_size, - tokenizer, - max_token_length, resolution, network_multiplier, enable_bucket, @@ -2221,6 +2206,9 @@ def __init__( self.conditioning_image_transforms = IMAGE_TRANSFORMS + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2229,6 +2217,12 @@ def make_buckets(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def new_cache_latents(self, model: Any, is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -2314,6 +2308,13 @@ def add_replacement(self, str_from, str_to): # for dataset in self.datasets: # dataset.make_buckets() + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. + """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) @@ -2323,10 +2324,10 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(is_main_process, strategy) + dataset.new_cache_latents(model, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2344,6 +2345,11 @@ def cache_text_encoder_outputs_sd3( tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, is_main_process) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2358,6 +2364,10 @@ def is_latent_cacheable(self) -> bool: def is_text_encoder_output_cacheable(self) -> bool: return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) # TODO update to use CachingStrategy -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask - - -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) def debug_dataset(train_dataset, show_input_ids=False): @@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], - example["input_ids"], + # example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], @@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False): if "network_multipliers" in example: print(f"network multiplier: {example['network_multipliers'][j]}") - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] logger.info(f"image size: {im.size()}") @@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2773,14 +2783,15 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk( - info.latents_npz, - latent, - info.latents_original_size, - info.latents_crop_ltrb, - flipped_latent, - alpha_mask, - ) + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass else: info.latents = latent if flip_aug: @@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): ) -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - def prepare_accelerator(args: argparse.Namespace): """ this function also prepares deepspeed plugin @@ -5550,6 +5534,7 @@ def sample_images_common( ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here """ if steps == 0: diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index ffa0d46de..e9e61af1b 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -from library import sd3_models, sd3_utils +from library import sd3_models, sd3_utils, strategy_sd3 def get_noise(seed, latent): @@ -145,6 +145,7 @@ def do_sample( parser.add_argument("--clip_g", type=str, required=False) parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -247,7 +248,7 @@ def do_sample( # load tokenizers logger.info("Loading tokenizers...") - tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) # load models # logger.info("Create MMDiT from SD3 checkpoint...") @@ -320,12 +321,19 @@ def do_sample( # prepare embeddings logger.info("Encoding prompts...") - # embeds, pooled_embed - lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - cond = torch.cat([lg_out, t5_out], dim=-2), pooled + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) - neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py index f34e47124..617e30271 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -17,7 +17,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -69,10 +69,22 @@ def train(args): # not args.train_text_encoder # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - # training without text encoder cache is not supported - assert ( - args.cache_text_encoder_outputs - ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + # # training without text encoder cache is not supported: because T5XXL must be cached + # assert ( + # args.cache_text_encoder_outputs + # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( + "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" + ) + + if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs: + logger.warning( + "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled." + + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります" + ) + args.cache_text_encoder_outputs = True # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] @@ -88,17 +100,17 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # load tokenizer - sd3_tokenizer = sd3_models.SD3Tokenizer() - - # prepare caching strategy - if args.new_caching: - latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) - else: - latents_caching_strategy = None - train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # load tokenizer and prepare tokenize strategy + sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # データセットを準備する if args.dataset_class is None: @@ -153,6 +165,16 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: @@ -215,19 +237,8 @@ def train(args): vae.requires_grad_(False) vae.eval() - if not args.new_caching: - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - file_suffix="_sd3.npz", - ) - else: - latents_caching_strategy.set_vae(vae) - train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) @@ -246,60 +257,70 @@ def train(args): t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # should be deleted after caching text encoder outputs when not training text encoder + # this strategy should not be used other than this process + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_clip_l = False train_clip_g = False train_t5xxl = False - # if args.train_text_encoder: - # # TODO each option for two text encoders? - # accelerator.print("enable text encoder training") - # if args.gradient_checkpointing: - # text_encoder1.gradient_checkpointing_enable() - # text_encoder2.gradient_checkpointing_enable() - # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - # train_clip_l = lr_te1 != 0 - # train_clip_g = lr_te2 != 0 - - # # caching one text encoder output is not supported - # if not train_clip_l: - # text_encoder1.to(weight_dtype) - # if not train_clip_g: - # text_encoder2.to(weight_dtype) - # text_encoder1.requires_grad_(train_clip_l) - # text_encoder2.requires_grad_(train_clip_g) - # text_encoder1.train(train_clip_l) - # text_encoder2.train(train_clip_g) - # else: - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) - clip_l.requires_grad_(False) - clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + train_clip_l = lr_te1 != 0 + train_clip_g = lr_te2 != 0 + + if not train_clip_l: + clip_l.to(weight_dtype) + if not train_clip_g: + clip_g.to(weight_dtype) + clip_l.requires_grad_(train_clip_l) + clip_g.requires_grad_(train_clip_g) + clip_l.train(train_clip_l) + clip_g.train(train_clip_g) + else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() + if t5xxl is not None: t5xxl.to(t5xxl_dtype) t5xxl.requires_grad_(False) t5xxl.eval() - # TextEncoderの出力をキャッシュする + # cache text encoder outputs if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad - - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs_sd3( - sd3_tokenizer, - (clip_l, clip_g, t5xxl), - (accelerator.device, accelerator.device, t5xxl_device), - None, - (None, None, None), - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - args.text_encoder_batch_size, - ) + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(t5xxl_device) + + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + clip_l.to(accelerator.device, dtype=weight_dtype) + clip_g.to(accelerator.device, dtype=weight_dtype) + if t5xxl is not None: + t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) - # TODO we can delete text encoders after caching + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) accelerator.wait_for_everyone() # load MMDIT @@ -332,11 +353,11 @@ def train(args): # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) # if train_clip_l: - # training_models.append(text_encoder1) - # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # training_models.append(clip_l) + # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) # if train_clip_g: - # training_models.append(text_encoder2) - # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + # training_models.append(clip_g) + # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) # calculate number of trainable parameters n_params = 0 @@ -344,7 +365,7 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") @@ -398,7 +419,11 @@ def train(args): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -455,8 +480,8 @@ def train(args): # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer # if train_clip_l: - # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # clip_l.text_model.encoder.layers[-1].requires_grad_(False) + # clip_l.text_model.final_layer_norm.requires_grad_(False) # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する if args.cache_text_encoder_outputs: @@ -484,9 +509,8 @@ def train(args): ds_model = deepspeed_utils.prepare_deepspeed_model( args, mmdit=mmdit, - # mmdie=mmdit if train_mmdit else None, - # text_encoder1=text_encoder1 if train_clip_l else None, - # text_encoder2=text_encoder2 if train_clip_g else None, + clip_l=clip_l if train_clip_l else None, + clip_g=clip_g if train_clip_g else None, ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -498,10 +522,10 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: mmdit = accelerator.prepare(mmdit) - # if train_clip_l: - # text_encoder1 = accelerator.prepare(text_encoder1) - # if train_clip_g: - # text_encoder2 = accelerator.prepare(text_encoder2) + if train_clip_l: + clip_l = accelerator.prepare(clip_l) + if train_clip_g: + clip_g = accelerator.prepare(clip_g) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -613,7 +637,7 @@ def optimizer_hook(parameter: torch.Tensor): # # For --sample_at_first # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit # ) # following function will be moved to sd3_train_utils @@ -666,6 +690,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -687,37 +712,45 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # encode images to latents. images are [-1, 1] latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR latents = sd3_models.SDVAE.process_in(latents) - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - # not cached, get text encoder outputs - # XXX This does not work yet - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + lg_out, t5_out, lg_pooled = text_encoder_outputs_list + if args.use_t5xxl_cache_only: + lg_out = None + lg_pooled = None + else: + lg_out = None + t5_out = None + lg_pooled = None + + if lg_out is None or (train_clip_l or train_clip_g): + # not cached or training, so get from text encoders + input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - # TODO support length > 75 input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) + lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] + ) - # get text encoder outputs: outputs are concatenated - context, pool = sd3_utils.get_cond_from_tokens( - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + if t5_out is None: + _, _, input_ids_t5xxl = batch["input_ids_list"] + with torch.no_grad(): + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + _, t5_out, _ = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] ) - else: - # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - # TODO this reuses SDXL keys, it should be fixed - lg_out = batch["text_encoder_outputs1_list"] - t5_out = batch["text_encoder_outputs2_list"] - pool = batch["text_encoder_pool2_list"] - context = torch.cat([lg_out, t5_out], dim=-2) + + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -748,13 +781,13 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if torch.any(torch.isnan(context)): accelerator.print("NaN found in context, replacing with zeros") context = torch.nan_to_num(context, 0, out=context) - if torch.any(torch.isnan(pool)): + if torch.any(torch.isnan(lg_pooled)): accelerator.print("NaN found in pool, replacing with zeros") - pool = torch.nan_to_num(pool, 0, out=pool) + lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled) # call model with accelerator.autocast(): - model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. @@ -806,7 +839,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -875,7 +908,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -924,7 +957,19 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" + ) + # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument( + "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" + ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", + ) # TE training is disabled temporarily # parser.add_argument( @@ -962,7 +1007,6 @@ def setup_parser() -> argparse.ArgumentParser: help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") parser.add_argument( "--skip_latents_validity_check", action="store_true", diff --git a/sdxl_train.py b/sdxl_train.py index ae92d6a3d..b6d4afd6a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl import library.train_util as train_util @@ -124,7 +124,16 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -166,10 +175,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -262,8 +271,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -276,6 +286,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_text_encoder1 = False train_text_encoder2 = False + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if args.train_text_encoder: # TODO each option for two text encoders? accelerator.print("enable text encoder training") @@ -307,16 +320,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + + accelerator.wait_for_everyone() if not cache_latents: vae.requires_grad_(False) @@ -403,7 +417,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -597,7 +615,7 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) loss_recorder = train_util.LossRecorder() @@ -628,9 +646,15 @@ def optimizer_hook(parameter: torch.Tensor): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning # TODO support weighted captions @@ -646,39 +670,13 @@ def optimizer_hook(parameter: torch.Tensor): # else: input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -765,7 +763,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -847,7 +845,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 5ff060a9f..0eaec29b8 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -23,7 +23,16 @@ import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) import library.model_util as model_util import library.train_util as train_util @@ -79,7 +88,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,7 +122,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -164,30 +180,30 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # prepare ControlNet-LLLite @@ -242,7 +258,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -290,7 +310,7 @@ def train(args): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) if isinstance(unet, DDP): - unet._set_static_graph() # avoid error for multiple use of the parameter + unet._set_static_graph() # avoid error for multiple use of the parameter if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる @@ -357,7 +377,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -409,27 +431,26 @@ def remove_model(old_ckpt_name): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1d..67ccae62c 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,16 +1,21 @@ import argparse import torch +from accelerate import Accelerator from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util import train_network from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() @@ -49,15 +54,32 @@ def load_target_model(self, args, weight_dtype, accelerator): return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy - def is_text_encoder_outputs_cached(self, args): - return args.cache_text_encoder_outputs + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + else: + return None def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -70,15 +92,13 @@ def cache_text_encoder_outputs_if_needed( clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.cache_text_encoder_outputs( - tokenizers, - text_encoders, - accelerator.device, - weight_dtype, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, + dataset.new_cache_text_encoder_outputs( + text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process ) + accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 5df739e28..cbfcef554 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -5,10 +5,10 @@ import torch from library.device_utils import init_ipex -init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +init_ipex() +from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util import train_textual_inversion @@ -41,28 +41,20 @@ def load_target_model(self, args, weight_dtype, accelerator): return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - return encoder_hidden_states1, encoder_hidden_states2, pool2 + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -81,9 +73,11 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -122,8 +116,7 @@ def load_weights(self, file): def setup_parser() -> argparse.ArgumentParser: parser = train_textual_inversion.setup_parser() - # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching - # sdxl_train_util.add_sdxl_training_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) return parser diff --git a/train_db.py b/train_db.py index 39d8ea6ed..7caee6647 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,7 @@ from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device @@ -38,6 +38,7 @@ apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd setup_logging() import logging @@ -58,7 +59,14 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -80,10 +88,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -145,13 +153,17 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 @@ -184,8 +196,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -290,10 +305,16 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -331,7 +352,7 @@ def train(args): with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -339,14 +360,18 @@ def train(args): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) # Predict the noise residual with accelerator.autocast(): @@ -358,7 +383,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -393,7 +420,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -457,7 +484,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/train_network.py b/train_network.py index 7ba073855..3828fed19 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import time import json from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -18,7 +19,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -101,19 +102,31 @@ def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return False + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + return not args.network_train_unet_only - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) @@ -123,7 +136,7 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei return encoder_hidden_states def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): @@ -131,8 +144,8 @@ def all_reduce_network(self, accelerator, network): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) def train(self, args): session_id = random.randint(0, 2**32) @@ -150,9 +163,13 @@ def train(self, args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため - tokenizer = self.load_tokenizer(args) - tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -194,11 +211,11 @@ def train(self, args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -268,8 +285,9 @@ def train(self, args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -277,9 +295,13 @@ def train(self, args): # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -366,7 +388,11 @@ def train(self, args): optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -878,7 +904,7 @@ def remove_model(old_ckpt_name): os.remove(old_ckpt_file) # For --sample_at_first - self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -933,21 +959,31 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + else: + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + # SD only + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -1026,7 +1062,9 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1082,7 +1120,7 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) # end of epoch diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c36..9044f50df 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,6 +2,7 @@ import math import os from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -15,7 +16,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -103,28 +104,38 @@ def assert_extra_args(self, args, train_dataset_group): def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): pass - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - with torch.enable_grad(): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) - return encoder_hidden_states + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]: + return text_encoders def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -182,8 +193,13 @@ def train(self, args): if args.seed is not None: set_seed(args.seed) - tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer - tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # acceleratorを準備する logger.info("prepare accelerator") @@ -194,14 +210,7 @@ def train(self, args): vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - - if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: - accelerator.print( - "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " - + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" - ) + model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id init_token_ids_list = [] @@ -310,10 +319,10 @@ def train(self, args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) + train_dataset_group = train_util.load_arbitrary_dataset(args) self.assert_extra_args(args, train_dataset_group) @@ -368,11 +377,10 @@ def train(self, args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -387,7 +395,11 @@ def train(self, args): trainable_params += text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -415,20 +427,8 @@ def train(self, args): lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい - if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) - - elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) - - text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - - else: - raise NotImplementedError() + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders] index_no_updates_list = [] orig_embeds_params_list = [] @@ -456,6 +456,9 @@ def train(self, args): else: unet.eval() + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() @@ -510,7 +513,9 @@ def train(self, args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -540,8 +545,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -568,7 +573,12 @@ def remove_model(old_ckpt_name): latents = latents * self.vae_scale_factor # Get the text embedding for conditioning - text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -588,7 +598,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -639,8 +651,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -722,8 +734,8 @@ def remove_model(old_ckpt_name): global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) From 1a977e847a10975c042c0fdacd871a33c9e93900 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jul 2024 13:51:50 +0900 Subject: [PATCH 024/163] fix typos --- library/strategy_base.py | 2 +- library/strategy_sd.py | 2 +- library/strategy_sd3.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 594cca5eb..a99a08290 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -218,7 +218,7 @@ def is_disk_cached_latents_expected( def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError - def _defualt_is_disk_cached_latents_expected( + def _default_is_disk_cached_latents_expected( self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool ): if not self.cache_to_disk: diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 105816145..83ffaa31b 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -125,7 +125,7 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 42630ab22..7491e814f 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -177,7 +177,7 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): From 002d75179ae5a3b165a65c5cf49c00bf8f98e2df Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 29 Jul 2024 23:18:34 +0900 Subject: [PATCH 025/163] sample images for training --- library/sd3_train_utils.py | 348 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 51 +++--- 2 files changed, 367 insertions(+), 32 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 8f99d9474..da0729506 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,14 +1,18 @@ import argparse -import glob import math import os -from typing import List, Optional, Tuple, Union +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union import torch from safetensors.torch import save_file -from accelerate import Accelerator +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image -from library import sd3_models, sd3_utils, train_util +from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -276,10 +280,342 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin ) -def sample_images(*args, **kwargs): - return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +# temporary copied from sd3_minimal_inferece.py +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + return x + + +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + org_vae_device = vae.device # will be on cpu + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + vae.to(org_vae_device) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = te_outputs + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: + neg_te_outputs = sample_prompts_te_outputs[negative_prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) + neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = neg_te_outputs + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + + # latent to image + with torch.no_grad(): + image = vae.decode(latents) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + # region Diffusers diff --git a/sd3_train.py b/sd3_train.py index 617e30271..2f4ea8cb2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -299,6 +299,7 @@ def train(args): t5xxl.eval() # cache text encoder outputs + sample_prompts_te_outputs = None if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad here clip_l.to(accelerator.device) @@ -321,6 +322,22 @@ def train(args): with accelerator.autocast(): train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_list = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + ) + accelerator.wait_for_everyone() # load MMDIT @@ -635,10 +652,8 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) - # # For --sample_at_first - # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit - # ) + # For --sample_at_first + sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) # following function will be moved to sd3_train_utils @@ -831,17 +846,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -900,17 +907,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): vae, ) - # sdxl_train_util.sample_images( - # accelerator, - # args, - # epoch + 1, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) is_main_process = accelerator.is_main_process # if is_main_process: From 231df197ddf4372b3d90751146927f33e1965d1a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:26:30 +0900 Subject: [PATCH 026/163] Fix npz path for verification --- library/strategy_sdxl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index a4513336d..3eb0ab6f6 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -184,20 +184,20 @@ def __init__( def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) + npz = np.load(npz_path) if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: return False except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True From da4d0fe0165b3e0143c237de8cf307d53a9de45a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:51:34 +0900 Subject: [PATCH 027/163] support attn mask for l+g/t5 --- library/strategy_sd3.py | 88 +++++++++++++++++++++++++++++++++------- library/train_util.py | 3 +- sd3_minimal_inference.py | 10 +++-- sd3_train.py | 30 +++++++++++--- 4 files changed, 107 insertions(+), 24 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 7491e814f..a22818903 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -37,11 +37,14 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] l_tokens = l_tokens["input_ids"] g_tokens = g_tokens["input_ids"] t5_tokens = t5_tokens["input_ids"] - return [l_tokens, g_tokens, t5_tokens] + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] class Sd3TextEncodingStrategy(TextEncodingStrategy): @@ -49,11 +52,20 @@ def __init__(self) -> None: pass def encode_tokens( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ clip_l, clip_g, t5xxl = models - l_tokens, g_tokens, t5_tokens = tokens + l_tokens, g_tokens, t5_tokens = tokens[:3] + l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] if l_tokens is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None @@ -61,10 +73,15 @@ def encode_tokens( assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" l_out, l_pooled = clip_l(l_tokens) g_out, g_pooled = clip_g(g_tokens) + if apply_lg_attn_mask: + l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) + g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is not None and t5_tokens is not None: t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) else: t5_out = None @@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) - if "clip_l" not in npz or "clip_g" not in npz: + npz = np.load(npz_path) + if "lg_out" not in npz: return False - if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False # t5xxl is optional except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True + def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: + l_out = lg_out[..., :768] + g_out = lg_out[..., 768:] # 1280 + l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. + g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. + return np.concatenate([l_out, g_out], axis=-1) + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] t5_out = data["t5_out"] if "t5_out" in data else None + + if self.apply_lg_attn_mask: + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + + if self.apply_t5_attn_mask and t5_out is not None: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + return [lg_out, t5_out, lg_pooled] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] - clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask ) if lg_out.dtype == torch.bfloat16: @@ -148,10 +196,22 @@ def cache_batch_outputs( lg_pooled_i = lg_pooled[i] if self.cache_to_disk: + clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] + clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() + clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None kwargs = {} if t5_out is not None: kwargs["t5_out"] = t5_out_i - np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + clip_l_attn_mask=clip_l_attn_mask_i, + clip_g_attn_mask=clip_g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + **kwargs, + ) else: info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) diff --git a/library/train_util.py b/library/train_util.py index a747e0478..fc458a884 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -646,7 +646,7 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' - + self.tokenize_strategy = None self.text_encoder_output_caching_strategy = None self.latents_caching_strategy = None @@ -1486,6 +1486,7 @@ def __getitem__(self, index): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) + text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e9e61af1b..630da7e08 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -146,6 +146,8 @@ def do_sample( parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -323,15 +325,15 @@ def do_sample( logger.info("Encoding prompts...") encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) diff --git a/sd3_train.py b/sd3_train.py index 2f4ea8cb2..9c37cbce6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -172,6 +172,8 @@ def train(args): args.text_encoder_batch_size, False, False, + False, + False, ) ) train_dataset_group.set_current_strategies() @@ -312,6 +314,8 @@ def train(args): args.text_encoder_batch_size, False, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) @@ -335,7 +339,11 @@ def train(args): logger.info(f"cache Text Encoder outputs for prompt: {p}") tokens_list = sd3_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_list, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) accelerator.wait_for_everyone() @@ -748,21 +756,23 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if lg_out is None or (train_clip_l or train_clip_g): # not cached or training, so get from text encoders - input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] + input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] + sd3_tokenize_strategy, + [clip_l, clip_g, None], + [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], ) if t5_out is None: - _, _, input_ids_t5xxl = batch["input_ids_list"] + _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) @@ -969,6 +979,16 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) # TE training is disabled temporarily # parser.add_argument( From 36b2e6fc288c57f496a061e4d638f5641c32c9ea Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 22:56:48 +0900 Subject: [PATCH 028/163] add FLUX.1 LoRA training --- README.md | 20 + flux_minimal_inference.py | 390 ++++++++++++++++ flux_train_network.py | 332 ++++++++++++++ library/flux_models.py | 920 ++++++++++++++++++++++++++++++++++++++ library/flux_utils.py | 215 +++++++++ library/sd3_models.py | 22 +- library/strategy_flux.py | 244 ++++++++++ networks/lora_flux.py | 730 ++++++++++++++++++++++++++++++ sdxl_train_network.py | 5 + train_network.py | 169 ++++--- 10 files changed, 2992 insertions(+), 55 deletions(-) create mode 100644 flux_minimal_inference.py create mode 100644 flux_train_network.py create mode 100644 library/flux_models.py create mode 100644 library/flux_utils.py create mode 100644 library/strategy_flux.py create mode 100644 networks/lora_flux.py diff --git a/README.md b/README.md index d406fecde..a0b02f108 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,25 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## FLUX.1 LoRA training (WIP) + +__Aug 9, 2024__: + +Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +``` + +The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. + +``` +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +``` + +Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py new file mode 100644 index 000000000..f3affca80 --- /dev/null +++ b/flux_minimal_inference.py @@ -0,0 +1,390 @@ +# Minimum Inference Code for FLUX + +import argparse +import datetime +import math +import os +import random +from typing import Callable, Optional, Tuple +import einops +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image +import accelerate + +from library import device_utils +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import networks.lora_flux as lora_flux +from library import flux_models, flux_utils, sd3_utils, strategy_flux + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + + img = img + (t_prev - t_curr) * pred + + return img + + +def do_sample( + accelerator: Optional[accelerate.Accelerator], + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: torch.Tensor, + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, +): + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): + x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + + return x + + +def generate_image( + model, + clip_l, + t5xxl, + ae, + prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: Optional[int], + guidance: float, +): + # make first noise with packed shape + # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 + packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # prepare img and img ids + + # this is needed only for img2img + # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + # if img.shape[0] == 1 and bs > 1: + # img = repeat(img, "1 ... -> bs ...", bs=bs) + + # txt2img only needs img_ids + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare embeddings + logger.info("Encoding prompts...") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + with torch.no_grad(): + if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): + clip_l.to(clip_l_dtype) + t5xxl.to(t5xxl_dtype) + with accelerator.autocast(): + _, t5_out, txt_ids = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + # NaN check + if torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + if args.offload: + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + # del clip_l, t5xxl + device_utils.clean_memory() + + # generate image + logger.info("Generating image...") + model = model.to(device) + if steps is None: + steps = 4 if is_schnell else 50 + + img_ids = img_ids.to(device) + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype + ) + if args.offload: + model = model.cpu() + # del model + device_utils.clean_memory() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if is_fp8(ae_dtype): + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + if args.offload: + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 768 # 1024 + target_width = 1360 # 1024 + + # steps = 50 # 28 # 50 + # guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--ae", type=str, required=False) + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") + parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") + parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") + parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") + parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + guidance_scale = args.guidance + + name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way + is_schnell = name == "schnell" + + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def is_fp8(dt): + return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] + + dtype = str_to_dtype(args.dtype) + clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype) + t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype) + ae_dtype = str_to_dtype(args.ae_dtype, dtype) + flux_dtype = str_to_dtype(args.flux_dtype, dtype) + + logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}") + + loading_device = "cpu" if args.offload else device + + use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]] + if any(use_fp8): + accelerator = accelerate.Accelerator(mixed_precision="bf16") + else: + accelerator = None + + # load clip_l + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l.eval() + + logger.info(f"Loading t5xxl from {args.t5xxl}...") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) + t5xxl.eval() + + if is_fp8(clip_l_dtype): + clip_l = accelerator.prepare(clip_l) + if is_fp8(t5xxl_dtype): + t5xxl = accelerator.prepare(t5xxl) + + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # DiT + model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model.eval() + logger.info(f"Casting model to {flux_dtype}") + model.to(flux_dtype) # make sure model is dtype + if is_fp8(flux_dtype): + model = accelerator.prepare(model) + + # AE + ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) + ae.eval() + if is_fp8(ae_dtype): + ae = accelerator.prepare(ae) + + # LoRA + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + lora_model, weights_sd = lora_flux.create_network_from_weights( + multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True + ) + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + + if not args.interactive: + generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + else: + # loop for interactive + width = target_width + height = target_height + steps = None + guidance = args.guidance + + while True: + print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + for opt in options[1:]: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + + logger.info("Done!") diff --git a/flux_train_network.py b/flux_train_network.py new file mode 100644 index 000000000..7c762c86d --- /dev/null +++ b/flux_train_network.py @@ -0,0 +1,332 @@ +import argparse +import copy +import math +import random +from typing import Any + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + clip_l.eval() + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + t5xxl.eval() + + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + # if we load to cpu, flux.to(fp8) takes a long time + model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + accelerator.wait_for_everyone() + + logger.info("move text encoders back to cpu") + text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + text_encoders[1].to("cpu") # , dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + # logger.warning("Sampling images is not supported for Flux model") + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # copy from sd3_train.py and modified + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids = text_encoder_conds + # print( + # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" + # ) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + # sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument("--clip_l", type=str, help="path to clip_l") + parser.add_argument("--t5xxl", type=str, help="path to t5xxl") + parser.add_argument("--ae", type=str, help="path to ae") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 000000000..d0955e375 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,920 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + + +from dataclasses import dataclass +import math + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # self.gradient_checkpointing = False + + # def enable_gradient_checkpointing(self): + # self.gradient_checkpointing = True + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + # def forward(self, *args, **kwargs): + # if self.training and self.gradient_checkpointing: + # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + # else: + # return self._forward(*args, **kwargs) + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + # self.img_attn.enable_gradient_checkpointing() + # self.txt_attn.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + # self.img_attn.disable_gradient_checkpointing() + # self.txt_attn.disable_gradient_checkpointing() + + def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint( + # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT + # ) + # else: + # return self._forward(img, txt, vec, pe) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x, vec, pe) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 000000000..ba828d508 --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,215 @@ +import json +from typing import Union +import einops +import torch + +from safetensors.torch import load_file +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config + +from library import flux_models + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MODEL_VERSION_FLUX_V1 = "flux1" + + +def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: + logger.info(f"Bulding Flux model {name}") + with torch.device("meta"): + model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return model + + +def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel: + logger.info("Building CLIP") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP: {info}") + return clip + + +def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel: + T5_CONFIG_JSON = """ +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x diff --git a/library/sd3_models.py b/library/sd3_models.py index 28378c73b..ec704dcba 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -15,6 +15,12 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) memory_efficient_attention = None @@ -95,7 +101,9 @@ def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) # truncate to max_length - print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}") + print( + f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}" + ) if truncate_to_max_length and len(batch) > self.max_length: batch = batch[: self.max_length] if truncate_length is not None and len(batch) > truncate_length: @@ -1554,6 +1562,17 @@ def __init__( self.set_clip_options({"layer": layer_idx}) self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def gradient_checkpointing_enable(self): + logger.warning("Gradient checkpointing is not supported for this model") + def set_attn_mode(self, mode): raise NotImplementedError("This model does not support setting the attention mode") @@ -1925,6 +1944,7 @@ def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[s return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG, ) + clip_l.gradient_checkpointing_enable() if state_dict is not None: # update state_dict if provided to include logit_scale and text_projection.weight avoid errors if "logit_scale" not in state_dict: diff --git a/library/strategy_flux.py b/library/strategy_flux.py new file mode 100644 index 000000000..f194ccf6e --- /dev/null +++ b/library/strategy_flux.py @@ -0,0 +1,244 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class FluxTokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, t5_tokens, t5_attn_mask] + + +class FluxTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_t5_attn_mask: bool = False, + ) -> List[torch.Tensor]: + # supports single model inference only + + clip_l, t5xxl = models + l_tokens, t5_tokens = tokens[:2] + t5_attn_mask = tokens[2] if len(tokens) > 2 else None + + if clip_l is not None and l_tokens is not None: + l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] + else: + l_pooled = None + + if t5xxl is not None and t5_tokens is not None: + # t5_out is [1, max length, 4096] + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + else: + t5_out = None + txt_ids = None + + return [l_pooled, t5_out, txt_ids] + + +class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_t5_attn_mask = apply_t5_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + l_pooled = data["l_pooled"] + t5_out = data["t5_out"] + txt_ids = data["txt_ids"] + + if self.apply_t5_attn_mask: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + + return [l_pooled, t5_out, txt_ids] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + ) + + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + + if self.cache_to_disk: + t5_attn_mask = tokens_and_masks[2] + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + ) + else: + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + + +class FluxLatentsCachingStrategy(LatentsCachingStrategy): + FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for FluxTokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = FluxTokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 000000000..141137b46 --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,730 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + varbose=True, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules( + is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + params, descriptions = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + # if self.block_lr: + # is_sdxl = False + # for lora in self.unet_loras: + # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + # is_sdxl = True + # break + + # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + # block_idx_to_lora = {} + # for lora in self.unet_loras: + # idx = get_block_index(lora.lora_name, is_sdxl) + # if idx not in block_idx_to_lora: + # block_idx_to_lora[idx] = [] + # block_idx_to_lora[idx].append(lora) + + # # blockごとにパラメータを設定する + # for idx, block_loras in block_idx_to_lora.items(): + # params, descriptions = assemble_params( + # block_loras, + # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), + # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + # ) + # all_params.extend(params) + # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) + + # else: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 67ccae62c..4d6e3f184 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -52,6 +52,11 @@ def load_target_model(self, args, weight_dtype, accelerator): self.logit_scale = logit_scale self.ckpt_info = ckpt_info + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet def get_tokenize_strategy(self, args): diff --git a/train_network.py b/train_network.py index 3828fed19..48d988624 100644 --- a/train_network.py +++ b/train_network.py @@ -100,6 +100,12 @@ def assert_extra_args(self, args, train_dataset_group): def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet def get_tokenize_strategy(self, args): @@ -147,6 +153,81 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) + # region SD/SDXL + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + return noise_pred, target, timesteps, huber_c, None + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + return loss + + # endregion + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -253,11 +334,6 @@ def train(self, args): # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - # 差分追加学習のためにモデルを読み込む sys.path.append(os.path.dirname(__file__)) accelerator.print("import network module:", args.network_module) @@ -445,16 +521,19 @@ def train(self, args): unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn + unet.to(accelerator.device) # this makes faster `to(dtype)` below + unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) + unet.to(dtype=unet_weight_dtype) # this takes long time and large memory for t_enc in text_encoders: t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + if hasattr(t_enc.text_model, "embeddings"): + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -851,12 +930,7 @@ def load_model_hook(models, input_dir): global_step = 0 - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) if accelerator.is_main_process: init_kwargs = {} @@ -913,6 +987,13 @@ def remove_model(old_ckpt_name): initial_step -= len(train_dataloader) global_step = initial_step + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for t_enc in text_encoders: + logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -940,13 +1021,15 @@ def remove_model(old_ckpt_name): else: with torch.no_grad(): # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) + latents = latents.to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor + + latents = self.shift_scale_latents(args, latents) # get multiplier for each sample if network_has_multiplier: @@ -985,41 +1068,25 @@ def remove_model(old_ckpt_name): if args.full_fp16: text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents + # sample noise, call unet, get target + noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, ) - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + if weighting is not None: + loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -1027,14 +1094,8 @@ def remove_model(old_ckpt_name): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 808d2d1f48e2f4e544d47464edb2727c03da2f53 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 23:02:51 +0900 Subject: [PATCH 029/163] fix typos --- flux_train_network.py | 2 +- library/flux_models.py | 4 ++-- library/flux_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 7c762c86d..e4be97ad8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -250,7 +250,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # ) with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=packed_noisy_model_input, img_ids=img_ids, diff --git a/library/flux_models.py b/library/flux_models.py index d0955e375..92c79bcca 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -685,11 +685,11 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - # calculate the img bloks + # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) - # calculate the txt bloks + # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt diff --git a/library/flux_utils.py b/library/flux_utils.py index ba828d508..166cd833b 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -20,7 +20,7 @@ def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: - logger.info(f"Bulding Flux model {name}") + logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype) From 358f13f2c92a04fb524006f124fc029a9edb0eaf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Aug 2024 14:03:59 +0900 Subject: [PATCH 030/163] fix alpha is ignored --- networks/lora_flux.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 141137b46..332a73d97 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh module_class = LoRAInfModule if for_inference else LoRAModule - network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class) + network = LoRANetwork( + text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) return network, weights_sd @@ -331,6 +333,8 @@ def __init__( conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -348,12 +352,15 @@ def __init__( self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None - logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) - if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -381,13 +388,19 @@ def create_modules( dim = None alpha = None - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha if dim is None or dim == 0: # skipした情報を出力 From 8a0f12dde812994ec3facdcdb7c08b362dbceb0f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Aug 2024 23:42:05 +0900 Subject: [PATCH 031/163] update FLUX LoRA training --- README.md | 29 ++++++++--- flux_train_network.py | 105 ++++++++++++++++++++++++++++++-------- library/sai_model_spec.py | 24 +++++++-- library/strategy_flux.py | 4 +- library/train_util.py | 9 ++-- networks/lora_flux.py | 2 +- train_network.py | 18 +++++-- 7 files changed, 150 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index a0b02f108..1089dd001 100644 --- a/README.md +++ b/README.md @@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif ## FLUX.1 LoRA training (WIP) -__Aug 9, 2024__: +This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. + +Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 ``` +LoRAs for Text Encoders are not tested yet. + +We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: + +- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. +- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). + +`--loss_type` may be useful for FLUX.1 training. The default is `l2`. + +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. + +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` -Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. - ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/flux_train_network.py b/flux_train_network.py index e4be97ad8..69b6e8eaf 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -135,7 +135,7 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke pass def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler @@ -211,21 +211,32 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) + else: + t = torch.rand((bsz,), device=accelerator.device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -264,11 +275,20 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - model_pred = model_pred * (-sigmas) + noisy_model_input - - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + if args.model_prediction_type == "raw": + # use model_pred as is + weighting = None + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + weighting = None + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -278,6 +298,21 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser: default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) return parser diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index af073677e..ad72ec00d 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -59,6 +59,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" ARCH_SD3_M = "stable-diffusion-3-medium" ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -66,6 +68,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -118,10 +121,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - sd3: str = None, + sd3: Optional[str] = None, + flux: Optional[str] = None, ): """ - sd3: only supports "m" + sd3: only supports "m", flux: only supports "dev" """ # if state_dict is None, hash is not calculated @@ -140,6 +144,11 @@ def build_metadata( arch = ARCH_SD3_M else: arch = ARCH_SD3_UNKNOWN + elif flux is not None: + if flux == "dev": + arch = ARCH_FLUX_1_DEV + else: + arch = ARCH_FLUX_1_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -158,7 +167,10 @@ def build_metadata( if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if flux is not None: + # Flux + impl = IMPL_FLUX + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: @@ -216,7 +228,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl or sd3 is not None: + if sdxl or sd3 is not None or flux is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 @@ -227,7 +239,9 @@ def build_metadata( metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - if v_parameterization: + if flux is not None: + del metadata["modelspec.prediction_type"] + elif v_parameterization: metadata["modelspec.prediction_type"] = PRED_TYPE_V else: metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON diff --git a/library/strategy_flux.py b/library/strategy_flux.py index f194ccf6e..13459d32f 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -63,11 +63,11 @@ def encode_tokens( l_pooled = None if t5xxl is not None and t5_tokens is not None: - # t5_out is [1, max length, 4096] + # t5_out is [b, max length, 4096] t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) if apply_t5_attn_mask: t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) - txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None diff --git a/library/train_util.py b/library/train_util.py index fc458a884..6b74bb3fa 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3186,6 +3186,7 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, + flux: str = None, ): timestamp = time.time() @@ -3220,6 +3221,7 @@ def get_sai_model_spec( timesteps=timesteps, clip_skip=args.clip_skip, # None or int sd3=sd3, + flux=flux, ) return metadata @@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", "smooth_l1"], - help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", @@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 ): - if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 332a73d97..a4dab287a 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index 48d988624..367203f54 100644 --- a/train_network.py +++ b/train_network.py @@ -226,6 +226,12 @@ def post_process_loss(self, loss, args, timesteps, noise_scheduler): loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + # endregion def train(self, args): @@ -521,10 +527,13 @@ def train(self, args): unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn - unet.to(accelerator.device) # this makes faster `to(dtype)` below + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) # this takes long time and large memory + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) @@ -718,8 +727,11 @@ def load_model_hook(models, input_dir): "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, + "ss_fp8_base": args.fp8_base, } + self.update_metadata(metadata, args) # architecture specific metadata + if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time @@ -964,7 +976,7 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) From 82314ac2e7926ed15eac6306bebe4ffb78280346 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 11:14:08 +0900 Subject: [PATCH 032/163] update readme for ai toolkit settings --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1089dd001..d016bcec4 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,11 @@ We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_sca `--loss_type` may be useful for FLUX.1 training. The default is `l2`. -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. + +additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). + +Other settings may work better, so please try different settings. We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. From d25ae361d06bb6f49c104ca2e6b4a9188a88c95f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 19:07:07 +0900 Subject: [PATCH 033/163] fix apply_t5_attn_mask to work --- README.md | 2 ++ flux_train_network.py | 6 ++++-- library/strategy_flux.py | 18 +++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d016bcec4..d47776ca6 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. + Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. diff --git a/flux_train_network.py b/flux_train_network.py index 69b6e8eaf..59a666aae 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -67,14 +67,16 @@ def get_latents_caching_strategy(self, args): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy() + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + ) else: return None diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 13459d32f..3880a1e1b 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -41,17 +41,24 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: class FluxTextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_t5_attn_mask = apply_t5_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_t5_attn_mask: bool = False, + apply_t5_attn_mask: Optional[bool] = None, ) -> List[torch.Tensor]: - # supports single model inference only + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask clip_l, t5xxl = models l_tokens, t5_tokens = tokens[:2] @@ -137,8 +144,9 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # attn_mask is not applied when caching to disk: it is applied when loading from disk l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) if l_pooled.dtype == torch.bfloat16: From 74f91c2ff71035db105b218128567e6b8fa6c80d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 21:54:10 +0900 Subject: [PATCH 034/163] correct option name closes #1446 --- docs/train_README-ja.md | 2 +- docs/train_README-zh.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index d186bf243..cfa5a7d1c 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 詳細については各自お調べください。 - 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。 + 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。 ### オプティマイザの指定について diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 7e00278c5..1bc47e0f5 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 有关详细信息,请自行研究。 - 要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。 + 要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。 ### 关于指定优化器 使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。 From 9e09a69df1ea8aa76ec98df3b2eed961c66432e4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Aug 2024 08:19:45 +0900 Subject: [PATCH 035/163] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d47776ca6..ccc83e6e8 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,10 @@ Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to mak Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` LoRAs for Text Encoders are not tested yet. @@ -29,7 +29,7 @@ We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_sca In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). +additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! Other settings may work better, so please try different settings. From 4af36f96320d553025cfdf067cae1e346af44a67 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 12 Aug 2024 13:24:10 +0900 Subject: [PATCH 036/163] update to work interactive mode --- README.md | 2 ++ flux_minimal_inference.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ccc83e6e8..c0d50a5a2 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. +Aug 12: `--interactive` option is now working. + ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index f3affca80..b09f63808 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import math import os import random -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple import einops import numpy as np @@ -121,6 +121,9 @@ def generate_image( steps: Optional[int], guidance: float, ): + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) @@ -183,9 +186,7 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype - ) + x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) if args.offload: model = model.cpu() # del model @@ -255,6 +256,7 @@ def generate_image( default=[], help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -341,6 +343,7 @@ def is_fp8(dt): ae = accelerator.prepare(ae) # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] for weights_file in args.lora_weights: if ";" in weights_file: weights_file, multiplier = weights_file.split(";") @@ -351,7 +354,16 @@ def is_fp8(dt): lora_model, weights_sd = lora_flux.create_network_from_weights( multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True ) - lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) @@ -363,7 +375,9 @@ def is_fp8(dt): guidance = args.guidance while True: - print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + print( + "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + ) prompt = input() if prompt == "": break @@ -384,6 +398,13 @@ def is_fp8(dt): seed = int(opt[1:].strip()) elif opt.startswith("g"): guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) From a7d5dabde3facb57d069eba0aa91e961e04303ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Aug 2024 17:09:19 +0900 Subject: [PATCH 037/163] Update readme --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index c0d50a5a2..19aed2212 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ We have added a new training script for LoRA training. The script is `flux_train accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` +``` + LoRAs for Text Encoders are not tested yet. We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: From 0415d200f5f3db89e33b33c9b36cb3c3e15d0266 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 13 Aug 2024 21:00:16 +0900 Subject: [PATCH 038/163] update dependencies closes #1450 --- requirements.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index e99775b8a..4ee19b3ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.33.0 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning @@ -38,5 +38,7 @@ imagesize==1.4.1 # open-clip-torch==2.20.0 # For logging rich==13.7.0 +# for T5XXL tokenizer (SD3/FLUX) +sentencepiece==0.2.0 # for kohya_ss library -e . From 9711c96f96038df5fa1a15d073244198b93ef0a2 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 13 Aug 2024 21:03:17 +0900 Subject: [PATCH 039/163] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 19aed2212..3eb034ed4 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. +__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. From 56d7651f0895c805c403a8db01083a522503eb7d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 13 Aug 2024 22:28:39 +0900 Subject: [PATCH 040/163] add experimental split mode for FLUX --- README.md | 22 +++++- flux_train_network.py | 110 +++++++++++++++++++++++---- library/flux_models.py | 165 +++++++++++++++++++++++++++++++++++++++++ networks/lora_flux.py | 30 ++++++-- 4 files changed, 304 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 3eb034ed4..64b018804 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,22 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +Aug 13, 2024: + +__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. + +This argument is available even if `--split_mode` is not specified. + +__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. + +This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. + Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ - We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` @@ -19,7 +29,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +``` + +The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" ``` LoRAs for Text Encoders are not tested yet. diff --git a/flux_train_network.py b/flux_train_network.py index 59a666aae..1d1f00d84 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -37,10 +37,16 @@ def assert_extra_args(self, args, train_dataset_group): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + # if we load to cpu, flux.to(fp8) takes a long time + model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + + if args.split_mode: + model = self.prepare_split_model(model, weight_dtype, accelerator) clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") clip_l.eval() @@ -49,13 +55,47 @@ def load_target_model(self, args, weight_dtype, accelerator): t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") t5xxl.eval() - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way - # if we load to cpu, flux.to(fp8) takes a long time - model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + def prepare_split_model(self, model, weight_dtype, accelerator): + from accelerate import init_empty_weights + + logger.info("prepare split model") + with init_empty_weights(): + flux_upper = flux_models.FluxUpper(model.params) + flux_lower = flux_models.FluxLower(model.params) + sd = model.state_dict() + + # lower (trainable) + logger.info("load state dict for lower") + flux_lower.load_state_dict(sd, strict=False, assign=True) + flux_lower.to(dtype=weight_dtype) + + # upper (frozen) + logger.info("load state dict for upper") + flux_upper.load_state_dict(sd, strict=False, assign=True) + + logger.info("prepare upper model") + target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype + flux_upper.to(accelerator.device, dtype=target_dtype) + flux_upper.eval() + + if args.fp8_base: + # this is required to run on fp8 + flux_upper = accelerator.prepare(flux_upper) + + flux_upper.to("cpu") + + self.flux_upper = flux_upper + del model # we don't need model anymore + clean_memory_on_device(accelerator.device) + + logger.info("split model prepared") + + return flux_lower + def get_tokenize_strategy(self, args): return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -262,17 +302,51 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" # ) - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - ) + if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) @@ -331,6 +405,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--split_mode", + action="store_true", + help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + ) # copy from Diffusers parser.add_argument( diff --git a/library/flux_models.py b/library/flux_models.py index 92c79bcca..3c7766b85 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -918,3 +918,168 @@ def forward( img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + + +class FluxUpper(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + return img, txt, vec, pe + + +class FluxLower(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.out_channels = params.in_channels + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + for block in self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + for block in self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor | None = None, + pe: Tensor | None = None, + ) -> Tensor: + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a4dab287a..4da33542f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -252,6 +252,11 @@ def create_network( if module_dropout is not None: module_dropout = float(module_dropout) + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -264,6 +269,7 @@ def create_network( module_dropout=module_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, + train_blocks=train_blocks, varbose=True, ) @@ -314,9 +320,11 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): - FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" @@ -335,6 +343,7 @@ def __init__( module_class: Type[object] = LoRAModule, modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -347,6 +356,7 @@ def __init__( self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -360,7 +370,9 @@ def __init__( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -434,9 +446,17 @@ def create_modules( skipped_te += skipped logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] - self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: From 9760d097b0bd7efbeb065d4320b2216a94e76efd Mon Sep 17 00:00:00 2001 From: DukeG Date: Wed, 14 Aug 2024 19:58:54 +0800 Subject: [PATCH 041/163] Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model' While loading T5 model in GPU. --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 367203f54..405aa747c 100644 --- a/train_network.py +++ b/train_network.py @@ -540,9 +540,13 @@ def train(self, args): # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc.text_model, "embeddings"): + if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): + t_enc.encoder.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: From 7db422211907df3c50703b419655202276a53301 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 14 Aug 2024 22:15:26 +0900 Subject: [PATCH 042/163] add sample image generation during training --- README.md | 2 + flux_train_network.py | 67 +++++++- library/flux_train_utils.py | 297 ++++++++++++++++++++++++++++++++++++ train_network.py | 13 +- 4 files changed, 374 insertions(+), 5 deletions(-) create mode 100644 library/flux_train_utils.py diff --git a/README.md b/README.md index 64b018804..7dc954fbc 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ This feature is experimental. The options and the training script may change in __Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. + Aug 13, 2024: __Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. diff --git a/flux_train_network.py b/flux_train_network.py index 1d1f00d84..b8ea56223 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -10,7 +10,7 @@ init_ipex() -from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network from library.utils import setup_logging @@ -28,6 +28,12 @@ def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + if args.cache_text_encoder_outputs: assert ( train_dataset_group.is_text_encoder_output_cacheable() @@ -139,8 +145,31 @@ def cache_text_encoder_outputs_if_needed( text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + + # cache sample prompts + self.sample_prompts_te_outputs = None + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + accelerator.wait_for_everyone() + # move back to cpu logger.info("move text encoders back to cpu") text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu") # , dtype=torch.float32) @@ -172,9 +201,36 @@ def cache_text_encoder_outputs_if_needed( # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) # return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - # logger.warning("Sampling images is not supported for Flux model") - pass + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + if not args.split_mode: + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + ) + return + + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -389,6 +445,9 @@ def update_metadata(self, metadata, args): metadata["ss_model_prediction_type"] = args.model_prediction_type metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 000000000..91f522389 --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,297 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image + +from library import flux_models, flux_utils, strategy_base +from library.sd3_train_utils import load_prompts +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: List[CLIPTextModel], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + tokens_and_masks = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + l_pooled, t5_out, txt_ids = te_outputs + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + + with accelerator.autocast(), torch.no_grad(): + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + + x = x.float() + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + + img = img + (t_prev - t_curr) * pred + + return img diff --git a/train_network.py b/train_network.py index 367203f54..53d71b57d 100644 --- a/train_network.py +++ b/train_network.py @@ -232,6 +232,9 @@ def get_sai_model_spec(self, args): def update_metadata(self, metadata, args): pass + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + # endregion def train(self, args): @@ -529,7 +532,7 @@ def train(self, args): # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory - + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) @@ -989,6 +992,14 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # if text_encoder is not needed for training, delete it to save memory. + # TODO this can be automated after SDXL sample prompt cache is implemented + if self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) From 8aaa1967bd3d3a9b4b44e97e5432d23f2101cf51 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Aug 2024 22:07:23 +0900 Subject: [PATCH 043/163] fix encoding latents closes #1456 --- flux_train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index b8ea56223..daa65c857 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -238,8 +238,8 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images).latent_dist.sample() - + return vae.encode(images) + def shift_scale_latents(self, args, latents): return latents From 35b6cb0cd1b319d5f34b44a8c24c81c42895fa2e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Aug 2024 22:07:35 +0900 Subject: [PATCH 044/163] update for torchvision --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7dc954fbc..bdb6bf2ed 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,10 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. -__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +The command to install PyTorch is as follows: +`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. From 08ef886bfeb058aa6d6f7e0a19589c0fd80b3757 Mon Sep 17 00:00:00 2001 From: DukeG Date: Fri, 16 Aug 2024 11:00:08 +0800 Subject: [PATCH 045/163] Fix AttributeError: 'FluxNetworkTrainer' object has no attribute 'sample_prompts_te_outputs' Move "self.sample_prompts_te_outputs = None" from Line 150 to Line 26. --- flux_train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index daa65c857..59b9d84b5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -23,6 +23,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() + self.sample_prompts_te_outputs = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -147,7 +148,6 @@ def cache_text_encoder_outputs_if_needed( dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) # cache sample prompts - self.sample_prompts_te_outputs = None if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") From 3921a4efda1cd1d7d873177ea7f51b77c3f15d3d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Aug 2024 17:06:05 +0900 Subject: [PATCH 046/163] add t5xxl max token length, support schnell --- README.md | 8 ++++++++ flux_train_network.py | 32 ++++++++++++++++++++++++++++---- library/flux_models.py | 12 ++++++++---- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index bdb6bf2ed..6fb050dff 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 16, 2024: + +FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. + +Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. + +Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. + Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. Aug 13, 2024: diff --git a/flux_train_network.py b/flux_train_network.py index 59b9d84b5..b9a29c160 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -44,11 +44,18 @@ def assert_extra_args(self, args, train_dataset_group): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + def get_flux_model_name(self, args): + return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + name = self.get_flux_model_name(args) + # if we load to cpu, flux.to(fp8) takes a long time model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") @@ -104,7 +111,18 @@ def prepare_split_model(self, model, weight_dtype, accelerator): return flux_lower def get_tokenize_strategy(self, args): - return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + name = self.get_flux_model_name(args) + + if args.t5xxl_max_token_length is None: + if name == "schnell": + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] @@ -239,7 +257,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images) - + def shift_scale_latents(self, args, latents): return latents @@ -470,7 +488,13 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) - + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) # copy from Diffusers parser.add_argument( "--weighting_scheme", diff --git a/library/flux_models.py b/library/flux_models.py index 3c7766b85..ed0bc8c7d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -863,7 +863,8 @@ def enable_gradient_checkpointing(self): self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() - self.guidance_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: block.enable_gradient_checkpointing() @@ -875,7 +876,8 @@ def disable_gradient_checkpointing(self): self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() - self.guidance_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: block.disable_gradient_checkpointing() @@ -972,7 +974,8 @@ def enable_gradient_checkpointing(self): self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() - self.guidance_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks: block.enable_gradient_checkpointing() @@ -984,7 +987,8 @@ def disable_gradient_checkpointing(self): self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() - self.guidance_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() for block in self.double_blocks: block.disable_gradient_checkpointing() From e45d3f8634c6dd4e358a8c7972f7c851f18f94d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Aug 2024 22:19:21 +0900 Subject: [PATCH 047/163] add merge LoRA script --- README.md | 24 +++ library/train_util.py | 2 +- networks/flux_merge_lora.py | 361 ++++++++++++++++++++++++++++++++++++ 3 files changed, 386 insertions(+), 1 deletion(-) create mode 100644 networks/flux_merge_lora.py diff --git a/README.md b/README.md index 6fb050dff..e231cc24e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ The command to install PyTorch is as follows: Aug 16, 2024: +Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. @@ -80,6 +82,28 @@ Aug 12: `--interactive` option is now working. python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### Merge LoRA to FLUX.1 checkpoint + +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ + +``` +python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +``` + +You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): + +- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. +- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. + +In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. + +The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. + +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/library/train_util.py b/library/train_util.py index 59ec3e56d..fa0eb9e51 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3160,7 +3160,7 @@ def load_metadata_from_safetensors(safetensors_file: str) -> dict: def build_minimum_network_metadata( - v2: Optional[bool], + v2: Optional[str], base_model: Optional[str], network_module: str, network_dim: str, diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py new file mode 100644 index 000000000..c3986ef1f --- /dev/null +++ b/networks/flux_merge_lora.py @@ -0,0 +1,361 @@ +import math +import argparse +import os +import time +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from library import sai_model_spec, train_util +import networks.lora_flux as lora_flux +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + logger.info(f"converting to {dtype}...") + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + logger.info(f"saving to: {file_name}") + save_file(state_dict, file_name, metadata=metadata) + + +def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): + # create module map without loading state_dict + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + lora_name_to_module_key = {} + with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name not in lora_name_to_module_key: + logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + continue + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + module_weight_key = lora_name_to_module_key[lora_name] + if module_weight_key not in flux_state_dict: + weight = flux_file.get_tensor(module_weight_key) + else: + weight = flux_state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + return flux_state_dict + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + logger.info(f"merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata + + +def merge(args): + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + dest_dir = os.path.dirname(args.save_to) + if not os.path.exists(dest_dir): + logger.info(f"creating directory: {dest_dir}") + os.makedirs(dest_dir) + + if args.flux_model is not None: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + + else: + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--flux_model", + type=str, + default=None, + help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", + ) + parser.add_argument( + "--loading_device", + type=str, + default="cpu", + help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます", + ) + parser.add_argument( + "--working_device", + type=str, + default="cpu", + help="device to work (merge). Merging LoRA models are done on CPU." + + " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--models", + type=str, + nargs="*", + help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) From 7367584e6749448cb9b012df0d3bcbe4f0531ea5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 14:38:34 +0900 Subject: [PATCH 048/163] fix sd3 training to work without cachine TE outputs #1465 --- sd3_train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 9c37cbce6..3b6c8a118 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -759,8 +759,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - input_ids_clip_l = input_ids_clip_l.to(accelerator.device) - input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + # text models in sd3_models require "cpu" for input_ids + input_ids_clip_l = input_ids_clip_l.to("cpu") + input_ids_clip_g = input_ids_clip_g.to("cpu") lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], @@ -770,7 +771,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) From 400955d3ea4088e8da7a3917dec9b0664424e24a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 15:36:18 +0900 Subject: [PATCH 049/163] add fine tuning FLUX.1 (WIP) --- flux_train.py | 729 ++++++++++++++++++++++++++++++++++++ flux_train_network.py | 168 +-------- library/flux_train_utils.py | 270 ++++++++++++- library/train_util.py | 2 +- 4 files changed, 1007 insertions(+), 162 deletions(-) create mode 100644 flux_train.py diff --git a/flux_train.py b/flux_train.py new file mode 100644 index 000000000..2ca20ded2 --- /dev/null +++ b/flux_train.py @@ -0,0 +1,729 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + ) + ) + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator.is_main_process) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if name == "schnell": + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + + # load FLUX + # if we load to cpu, flux.to(fp8) takes a long time + flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing() + + flux.requires_grad_(True) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + flux = accelerator.prepare(flux) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"]) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids = text_encoder_conds + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + flux = accelerator.unwrap_model(flux) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index b9a29c160..002252c87 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -274,85 +274,14 @@ def get_noise_pred_and_target( weight_dtype, train_unet, ): - # copy from sd3_train.py and modified - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None - ): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling - if args.timestep_sampling == "sigmoid": - # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) - else: - t = torch.rand((bsz,), device=accelerator.device) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise - else: - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -425,20 +354,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - if args.model_prediction_type == "raw": - # use model_pred as is - weighting = None - elif args.model_prediction_type == "additive": - # add the model_pred to the noisy_model_input - model_pred = model_pred + noisy_model_input - weighting = None - elif args.model_prediction_type == "sigma_scaled": - # apply sigma scaling - model_pred = model_pred * (-sigmas) + noisy_model_input - - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -469,83 +386,14 @@ def is_text_encoder_not_needed_for_training(self, args): def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() - # sdxl_train_util.add_sdxl_training_arguments(parser) - parser.add_argument("--clip_l", type=str, help="path to clip_l") - parser.add_argument("--t5xxl", type=str, help="path to t5xxl") - parser.add_argument("--ae", type=str, help="path to ae") - parser.add_argument("--apply_t5_attn_mask", action="store_true") - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument( "--split_mode", action="store_true", help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) - parser.add_argument( - "--t5xxl_max_token_length", - type=int, - default=None, - help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" - " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", - ) - # copy from Diffusers - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=3.5, - help="the FLUX.1 dev variant is a guidance distilled model", - ) - - parser.add_argument( - "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], - default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", - ) - parser.add_argument( - "--sigmoid_scale", - type=float, - default=1.0, - help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', - ) - parser.add_argument( - "--model_prediction_type", - choices=["raw", "additive", "sigma_scaled"], - default="sigma_scaled", - help="How to interpret and process the model prediction: " - "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." - " / モデル予測の解釈と処理方法:" - "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", - ) - parser.add_argument( - "--discrete_flow_shift", - type=float, - default=3.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", - ) return parser diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 91f522389..167d61c7e 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -12,8 +12,9 @@ from transformers import CLIPTextModel from tqdm import tqdm from PIL import Image +from safetensors.torch import save_file -from library import flux_models, flux_utils, strategy_base +from library import flux_models, flux_utils, strategy_base, train_util from library.sd3_train_utils import load_prompts from library.device_utils import init_ipex, clean_memory_on_device @@ -27,6 +28,9 @@ logger = logging.getLogger(__name__) +# region sample images + + def sample_images( accelerator: Accelerator, args: argparse.Namespace, @@ -295,3 +299,267 @@ def denoise( img = img + (t_prev - t_curr) * pred return img + + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) diff --git a/library/train_util.py b/library/train_util.py index fa0eb9e51..f4ac8740a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2629,7 +2629,7 @@ def __getitem__(self, idx): raise NotImplementedError -def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: +def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) dataset_class = args.dataset_class.split(".")[-1] module = importlib.import_module(module) From 25f77f6ef04ee760506338e7e7f9835c28657c59 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 17 Aug 2024 15:54:32 +0900 Subject: [PATCH 050/163] fix flux fine tuning to work --- README.md | 4 ++++ flux_train.py | 6 ++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e231cc24e..2b7b110f3 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` + +Aug 17. 2024: +Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. + Aug 16, 2024: Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. diff --git a/flux_train.py b/flux_train.py index 2ca20ded2..d2a9b3f32 100644 --- a/flux_train.py +++ b/flux_train.py @@ -674,9 +674,7 @@ def optimizer_hook(parameter: torch.Tensor): # if is_main_process: flux = accelerator.unwrap_model(flux) clip_l = accelerator.unwrap_model(clip_l) - clip_g = accelerator.unwrap_model(clip_g) - if t5xxl is not None: - t5xxl = accelerator.unwrap_model(t5xxl) + t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -686,7 +684,7 @@ def optimizer_hook(parameter: torch.Tensor): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) logger.info("model saved.") From 7e688913aef4c852f54a703c9f91d135b17dff87 Mon Sep 17 00:00:00 2001 From: exveria1015 Date: Sun, 18 Aug 2024 12:38:05 +0900 Subject: [PATCH 051/163] =?UTF-8?q?fix:=20Flux=20=E3=81=AE=20LoRA=20?= =?UTF-8?q?=E3=83=9E=E3=83=BC=E3=82=B8=E6=A9=9F=E8=83=BD=E3=82=92=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/flux_merge_lora.py | 364 +++++++++++++++++++++++++++++------- 1 file changed, 297 insertions(+), 67 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index c3986ef1f..df0ba606a 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -1,13 +1,14 @@ -import math import argparse +import math import os import time + import torch -from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm + +import lora_flux as lora_flux from library import sai_model_spec, train_util -import networks.lora_flux as lora_flux from library.utils import setup_logging setup_logging() @@ -42,34 +43,181 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): - # create module map without loading state_dict +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + + def create_key_map(n_double_layers, n_single_layers, hidden_size): + key_map = {} + for index in range(n_double_layers): + prefix_from = f"transformer_blocks.{index}" + prefix_to = f"double_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv_img = f"{prefix_to}.img_attn.qkv.{end}" + qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" + + key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) + key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) + key_map[f"{k}add_k_proj.{end}"] = ( + qkv_txt, + (0, hidden_size, hidden_size), + ) + key_map[f"{k}add_v_proj.{end}"] = ( + qkv_txt, + (0, hidden_size * 2, hidden_size), + ) + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + for index in range(n_single_layers): + prefix_from = f"single_transformer_blocks.{index}" + prefix_to = f"single_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv = f"{prefix_to}.linear1.{end}" + key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) + key_map[f"{prefix_from}.proj_mlp.{end}"] = ( + qkv, + (0, hidden_size * 3, hidden_size * 4), + ) + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + return key_map + + key_map = create_key_map( + 18, 1, 2048 + ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + + def find_matching_key(flux_dict, lora_key): + lora_key = lora_key.replace("diffusion_model.", "") + lora_key = lora_key.replace("transformer.", "") + lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up") + lora_key = lora_key.replace("single_transformer_blocks", "single_blocks") + lora_key = lora_key.replace("transformer_blocks", "double_blocks") + + double_block_map = { + "attn.to_out.0": "img_attn.proj", + "norm1.linear": "img_mod.lin", + "norm1_context.linear": "txt_mod.lin", + "attn.to_add_out": "txt_attn.proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + "attn.norm_q": "img_attn.norm.query_norm", + "attn.norm_k": "img_attn.norm.key_norm", + "attn.norm_added_q": "txt_attn.norm.query_norm", + "attn.norm_added_k": "txt_attn.norm.key_norm", + "attn.to_q": "img_attn.qkv", + "attn.to_k": "img_attn.qkv", + "attn.to_v": "img_attn.qkv", + "attn.add_q_proj": "txt_attn.qkv", + "attn.add_k_proj": "txt_attn.qkv", + "attn.add_v_proj": "txt_attn.qkv", + } + + single_block_map = { + "norm.linear": "modulation.lin", + "proj_out": "linear2", + "attn.norm_q": "norm.query_norm", + "attn.norm_k": "norm.key_norm", + "attn.to_q": "linear1", + "attn.to_k": "linear1", + "attn.to_v": "linear1", + } + + for old, new in double_block_map.items(): + lora_key = lora_key.replace(old, new) + + for old, new in single_block_map.items(): + lora_key = lora_key.replace(old, new) + + if lora_key in key_map: + flux_key = key_map[lora_key] + if isinstance(flux_key, tuple): + flux_key = flux_key[0] + logger.info(f"Found matching key: {flux_key}") + return flux_key + + # If not found in key_map, try partial matching + potential_key = lora_key + ".weight" + logger.info(f"Searching for key: {potential_key}") + matches = [k for k in flux_dict.keys() if potential_key in k] + if matches: + logger.info(f"Found matching key: {matches[0]}") + return matches[0] + return None + + merged_keys = set() for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") - lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + lora_sd, _ = load_state_dict(model, merge_dtype) - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - lora_name = key[: key.rfind(".lora_down")] - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + if "lora_down" in key or "lora_A" in key: + lora_name = key[ + : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") + ] + up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") + alpha_key = ( + key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + + "alpha" + ) + + logger.info(f"Processing LoRA key: {lora_name}") + flux_key = find_matching_key(flux_state_dict, lora_name) + + if flux_key is None: + logger.warning(f"no module found for LoRA weight: {key}") continue + logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}") + down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -77,40 +225,74 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim - # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = flux_state_dict[flux_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + if lora_name.startswith("transformer."): + if "qkv" in flux_key: + hidden_size = weight.size(-1) // 3 + update = ratio * (up_weight @ down_weight) * scale + + if "img_attn" in flux_key or "txt_attn" in flux_key: + q, k, v = torch.chunk(weight, 3, dim=-1) + if "to_q" in lora_name or "add_q_proj" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name or "add_k_proj" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name or "add_v_proj" in lora_name: + v += update.reshape(v.shape) + weight = torch.cat([q, k, v], dim=-1) + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) + merged_keys.add(flux_key) del up_weight del down_weight del weight + logger.info(f"Merged keys: {sorted(list(merged_keys))}") return flux_state_dict @@ -126,7 +308,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + base_model = lora_metadata.get( + train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None + ) # get alpha and dim alphas = {} # alpha for current model @@ -152,10 +336,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info( + f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" + ) # merge - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -173,14 +359,19 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + scale = ( + abs(scale) if "lora_up" in key else scale + ) # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None - ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + merged_sd[key].size() == lora_sd[key].size() + or concat_dim is not None + ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + merged_sd[key] = torch.cat( + [merged_sd[key], lora_sd[key] * scale], dim=concat_dim + ) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -199,7 +390,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info( + f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" + ) # check all dims are same dims_list = list(set(base_dims.values())) @@ -218,15 +411,17 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + metadata = train_util.build_minimum_network_metadata( + str(False), base_model, "networks.lora", dims, alphas, None + ) return merged_sd, metadata def merge(args): - assert len(args.models) == len( - args.ratios - ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert ( + len(args.models) == len(args.ratios) + ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -249,27 +444,48 @@ def str_to_dtype(p): if args.flux_model is not None: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, ) if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + merged_from = sai_model_spec.build_merged_from( + [args.flux_model] + args.models + ) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + None, + False, + False, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) logger.info(f"saving FLUX model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models( + args.models, args.ratios, merge_dtype, args.concat, args.shuffle + ) - logger.info(f"calculating hashes and creating metadata...") + logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( + state_dict, metadata + ) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -277,7 +493,16 @@ def str_to_dtype(p): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + state_dict, + False, + False, + False, + True, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) metadata.update(sai_metadata) @@ -332,7 +557,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--ratios", + type=float, + nargs="*", + help="ratios for each model / それぞれのLoRAモデルの比率", + ) parser.add_argument( "--no_metadata", action="store_true", From ef535ec6bb99918027afc1e31efa72cd3761d453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Aug 2024 16:54:18 +0900 Subject: [PATCH 052/163] add memory efficient training for FLUX.1 --- README.md | 64 ++++++++++++-- flux_train.py | 187 +++++++++++++++++++++++++++++------------ library/flux_models.py | 182 ++++++++++++++++++++++++++++++++++----- 3 files changed, 354 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 2b7b110f3..521e82e86 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,11 @@ The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 17. 2024: +Aug 18, 2024: +Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + +Aug 17, 2024: Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. Aug 16, 2024: @@ -39,11 +43,23 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. + +### FLUX.1 LoRA training + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid +--model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +(The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: @@ -80,12 +96,44 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. -Aug 12: `--interactive` option is now working. - ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### FLUX.1 fine-tuning + +Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 +--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing +``` + +(Combine the command into one line.) + +Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. + +`--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. + +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizer`. + +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. + +All these options are experimental and may change in the future. + +The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. + +Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. + +The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ @@ -298,7 +346,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available. - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. @@ -308,7 +356,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. - - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. - LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! @@ -361,7 +409,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。 - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 diff --git a/flux_train.py b/flux_train.py index d2a9b3f32..ecb3c7dda 100644 --- a/flux_train.py +++ b/flux_train.py @@ -1,5 +1,15 @@ # training with captions +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse import copy import math @@ -54,6 +64,12 @@ def train(args): ) args.cache_text_encoder_outputs = True + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -232,16 +248,25 @@ def train(args): # now we can delete Text Encoders to free memory clip_l = None t5xxl = None + clean_memory_on_device(accelerator.device) # load FLUX # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: - flux.enable_gradient_checkpointing() + flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) flux.requires_grad_(True) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info( + f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" + ) + flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + if not cache_latents: # load VAE here if not cached ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") @@ -265,40 +290,43 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups. currently different learning rates are not supported grouped_params = [] - param_group = [] - param_group_lr = -1 + param_group = {} for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr - - param_group.append(p) - - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 - - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "single" + else: + block_idx = -1 + + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") # prepare optimizers for each group optimizers = [] @@ -307,7 +335,7 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) @@ -341,7 +369,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -414,7 +442,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) - elif args.fused_optimizer_groups: + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -429,22 +457,46 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - parameter.register_post_accumulate_grad_hook(optimizer_hook) + block_type, block_idx = block_types_and_indices[opt_idx] + + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + # swap blocks if necessary + if btype == "double" and double_blocks_to_swap: + if bidx >= num_double_blocks - double_blocks_to_swap: + bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + elif btype == "single" and single_blocks_to_swap: + if bidx >= num_single_blocks - single_blocks_to_swap: + bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -487,6 +539,9 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + flux.prepare_block_swap_before_forward() + # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -502,7 +557,7 @@ def optimizer_hook(parameter: torch.Tensor): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): @@ -591,7 +646,7 @@ def optimizer_hook(parameter: torch.Tensor): # backward accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -604,7 +659,7 @@ def optimizer_hook(parameter: torch.Tensor): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -614,7 +669,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step += 1 flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) # 指定ステップごとにモデルを保存 @@ -673,8 +728,6 @@ def optimizer_hook(parameter: torch.Tensor): is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) - clip_l = accelerator.unwrap_model(clip_l) - t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -707,13 +760,43 @@ def setup_parser() -> argparse.ArgumentParser: "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index ed0bc8c7d..3f44068f9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -4,6 +4,11 @@ from dataclasses import dataclass import math +from typing import Optional + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() import torch from einops import rearrange @@ -466,6 +471,33 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso # region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() @@ -648,16 +680,15 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: ) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True - # self.img_attn.enable_gradient_checkpointing() - # self.txt_attn.enable_gradient_checkpointing() + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - # self.img_attn.disable_gradient_checkpointing() - # self.txt_attn.disable_gradient_checkpointing() + self.cpu_offload_checkpointing = False def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) @@ -694,11 +725,24 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, *args, **kwargs): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + else: - return self._forward(*args, **kwargs) + return self._forward(img, txt, vec, pe) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -747,12 +791,15 @@ def __init__( self.modulation = Modulation(hidden_size, double=False) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) @@ -768,11 +815,24 @@ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, *args, **kwargs): + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) else: - return self._forward(*args, **kwargs) + return self._forward(x, vec, pe) # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -849,6 +909,9 @@ def __init__(self, params: FluxParams): self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.double_blocks_to_swap = None + self.single_blocks_to_swap = None @property def device(self): @@ -858,8 +921,9 @@ def device(self): def dtype(self): return next(self.parameters()).dtype - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() @@ -867,12 +931,13 @@ def enable_gradient_checkpointing(self): self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: - block.enable_gradient_checkpointing() + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) - print("FLUX: Gradient checkpointing enabled.") + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() @@ -884,6 +949,24 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") + def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): + self.double_blocks_to_swap = double_blocks + self.single_blocks_to_swap = single_blocks + + def prepare_block_swap_before_forward(self): + # move last n blocks to cpu: they are on cuda + if self.double_blocks_to_swap: + for i in range(len(self.double_blocks) - self.double_blocks_to_swap): + self.double_blocks[i].to(self.device) + for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): + self.double_blocks[i].to("cpu") # , non_blocking=True) + if self.single_blocks_to_swap: + for i in range(len(self.single_blocks) - self.single_blocks_to_swap): + self.single_blocks[i].to(self.device) + for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): + self.single_blocks[i].to("cpu") # , non_blocking=True) + clean_memory_on_device(self.device) + def forward( self, img: Tensor, @@ -910,14 +993,75 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + if not self.double_blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.double_blocks_to_swap): + block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") + + block = self.double_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved double block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.double_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved double block {block_idx} to cuda.") + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + if moving: + self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved double block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + + if not self.single_blocks_to_swap: + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.single_blocks_to_swap): + block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") + + block = self.single_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved single block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.single_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved single block {block_idx} to cuda.") + + img = block(img, vec=vec, pe=pe) + + if moving: + self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved single block {to_cpu_block_index} to cpu.") + img = img[:, txt.shape[1] :, ...] + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img From a45048892802dce43e86a7e377ba84e89b51fdf5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Aug 2024 16:56:50 +0900 Subject: [PATCH 053/163] update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 521e82e86..df2a612d7 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,8 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` - Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - +Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. Aug 17, 2024: Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. @@ -118,6 +116,8 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t (Combine the command into one line.) +Sample image generation during training is not tested yet. + Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. From d034032a5dff4a5ee1a108e4f1cec41d8efadab0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 13:08:49 +0900 Subject: [PATCH 054/163] update README fix option name --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index df2a612d7..9a603b281 100644 --- a/README.md +++ b/README.md @@ -105,24 +105,24 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py --pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft ---mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 ---blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing +--blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing ``` (Combine the command into one line.) Sample image generation during training is not tested yet. -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. -`--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. +`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizer`. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. From 6e72a799c8f55f148a248693d2c0c3fb1912b04e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 21:55:28 +0900 Subject: [PATCH 055/163] reduce peak VRAM usage by excluding some blocks to cuda --- flux_train.py | 15 +++++++++------ library/flux_models.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/flux_train.py b/flux_train.py index ecb3c7dda..b294ce42a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -251,7 +251,6 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: @@ -259,7 +258,8 @@ def train(args): flux.requires_grad_(True) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info( @@ -412,8 +412,11 @@ def train(args): training_models = [ds_model] else: - # acceleratorがなんかよろしくやってくれるらしい - flux = accelerator.prepare(flux) + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -539,7 +542,7 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs=init_kwargs, ) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + if is_swapping_blocks: flux.prepare_block_swap_before_forward() # For --sample_at_first @@ -595,7 +598,7 @@ def optimizer_hook(parameter: torch.Tensor): # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) # pack latents and get img_ids diff --git a/library/flux_models.py b/library/flux_models.py index 3f44068f9..11ef647ad 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -953,6 +953,22 @@ def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optiona self.double_blocks_to_swap = double_blocks self.single_blocks_to_swap = single_blocks + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.double_blocks_to_swap: + save_double_blocks = self.double_blocks + self.double_blocks = None + if self.single_blocks_to_swap: + save_single_blocks = self.single_blocks + self.single_blocks = None + + self.to(device) + + if self.double_blocks_to_swap: + self.double_blocks = save_double_blocks + if self.single_blocks_to_swap: + self.single_blocks = save_single_blocks + def prepare_block_swap_before_forward(self): # move last n blocks to cpu: they are on cuda if self.double_blocks_to_swap: From 486fe8f70a53166f21f08b1c896bd9ba1e31d7e7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 22:30:24 +0900 Subject: [PATCH 056/163] feat: reduce memory usage and add memory efficient option for model saving --- README.md | 5 +++ flux_train.py | 6 +++ library/flux_train_utils.py | 21 ++++++++--- library/utils.py | 75 ++++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9a603b281..51e4635bb 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 19, 2024: +In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. + +An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. + Aug 18, 2024: Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. diff --git a/flux_train.py b/flux_train.py index b294ce42a..669963856 100644 --- a/flux_train.py +++ b/flux_train.py @@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser: add_custom_train_arguments(parser) # TODO remove this from here flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + parser.add_argument( "--fused_optimizer_groups", type=int, diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 167d61c7e..3f9e8660f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -20,7 +20,7 @@ init_ipex() -from .utils import setup_logging +from .utils import setup_logging, mem_eff_save_file setup_logging() import logging @@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): return model_pred, weighting -def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): state_dict = {} def update_sd(prefix, sd): for k, v in sd.items(): key = prefix + k - if save_dtype is not None: + if save_dtype is not None and v.dtype != save_dtype: v = v.detach().clone().to("cpu").to(save_dtype) state_dict[key] = v update_sd("", flux.state_dict()) - save_file(state_dict, ckpt_path, metadata=sai_metadata) + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) def save_flux_model_on_train_end( @@ -429,7 +438,7 @@ def save_flux_model_on_train_end( ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") - save_models(ckpt_file, flux, sai_metadata, save_dtype) + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) @@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise( ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") - save_models(ckpt_file, flux, sai_metadata, save_dtype) + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_epoch_end_or_stepwise_common( args, diff --git a/library/utils.py b/library/utils.py index 3037c055d..7de22d5a9 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,9 +1,12 @@ import logging import sys import threading +from typing import * +import json +import struct + import torch from torchvision import transforms -from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput @@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" Date: Tue, 20 Aug 2024 08:19:00 +0900 Subject: [PATCH 057/163] Fix debug_dataset to work --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 086b314a5..cab0ec52e 100644 --- a/train_network.py +++ b/train_network.py @@ -313,6 +313,7 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: From c62c95e8626bdb727cedc8f037c82ab3a8e66059 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 08:21:01 +0900 Subject: [PATCH 058/163] update about multi-resolution training in FLUX.1 --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/README.md b/README.md index 51e4635bb..165eed341 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024: +FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). + +The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. + +We will support multi-resolution caching to disk in the near future. + Aug 19, 2024: In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. @@ -159,6 +166,51 @@ In the case of LoRA models are trained with `bf16`, we are not sure which is bet The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. +### FLUX.1 Multi-resolution training + +You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__ + +The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. + +``` +[general] +# define common settings here +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" + +[[datasets]] +# define the first resolution here +batch_size = 2 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the second resolution here +batch_size = 3 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the third resolution here +batch_size = 4 +enable_bucket = true +resolution = [512, 512] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 ``` ## SD3 training From 6f6faf9b5a99b7f741f657a06a42f63754e450c0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:16:25 +0900 Subject: [PATCH 059/163] fix to work with ai-toolkit LoRA --- networks/flux_merge_lora.py | 163 +++++++++++++++--------------------- 1 file changed, 68 insertions(+), 95 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index df0ba606a..1ba1f314d 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -7,8 +7,6 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm -import lora_flux as lora_flux -from library import sai_model_spec, train_util from library.utils import setup_logging setup_logging() @@ -16,6 +14,9 @@ logger = logging.getLogger(__name__) +import lora_flux as lora_flux +from library import sai_model_spec, train_util + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -43,13 +44,11 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype -): +def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): logger.info(f"loading keys from FLUX.1 model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) - def create_key_map(n_double_layers, n_single_layers, hidden_size): + def create_key_map(n_double_layers, n_single_layers): key_map = {} for index in range(n_double_layers): prefix_from = f"transformer_blocks.{index}" @@ -60,18 +59,12 @@ def create_key_map(n_double_layers, n_single_layers, hidden_size): qkv_img = f"{prefix_to}.img_attn.qkv.{end}" qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" - key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) - key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) - key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) - key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) - key_map[f"{k}add_k_proj.{end}"] = ( - qkv_txt, - (0, hidden_size, hidden_size), - ) - key_map[f"{k}add_v_proj.{end}"] = ( - qkv_txt, - (0, hidden_size * 2, hidden_size), - ) + key_map[f"{k}to_q.{end}"] = qkv_img + key_map[f"{k}to_k.{end}"] = qkv_img + key_map[f"{k}to_v.{end}"] = qkv_img + key_map[f"{k}add_q_proj.{end}"] = qkv_txt + key_map[f"{k}add_k_proj.{end}"] = qkv_txt + key_map[f"{k}add_v_proj.{end}"] = qkv_txt block_map = { "attn.to_out.0.weight": "img_attn.proj.weight", @@ -106,13 +99,10 @@ def create_key_map(n_double_layers, n_single_layers, hidden_size): for end in ("weight", "bias"): k = f"{prefix_from}.attn." qkv = f"{prefix_to}.linear1.{end}" - key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) - key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) - key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) - key_map[f"{prefix_from}.proj_mlp.{end}"] = ( - qkv, - (0, hidden_size * 3, hidden_size * 4), - ) + key_map[f"{k}to_q.{end}"] = qkv + key_map[f"{k}to_k.{end}"] = qkv + key_map[f"{k}to_v.{end}"] = qkv + key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv block_map = { "norm.linear.weight": "modulation.lin.weight", @@ -126,11 +116,14 @@ def create_key_map(n_double_layers, n_single_layers, hidden_size): for k, v in block_map.items(): key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + # add as-is keys + values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())]) + values.sort() + key_map.update({v: v for v in values}) + return key_map - key_map = create_key_map( - 18, 1, 2048 - ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + key_map = create_key_map(18, 38) # 18 double layers, 38 single layers def find_matching_key(flux_dict, lora_key): lora_key = lora_key.replace("diffusion_model.", "") @@ -159,7 +152,6 @@ def find_matching_key(flux_dict, lora_key): "attn.add_k_proj": "txt_attn.qkv", "attn.add_v_proj": "txt_attn.qkv", } - single_block_map = { "norm.linear": "modulation.lin", "proj_out": "linear2", @@ -168,18 +160,22 @@ def find_matching_key(flux_dict, lora_key): "attn.to_q": "linear1", "attn.to_k": "linear1", "attn.to_v": "linear1", + "proj_mlp": "linear1", } + # same key exists in both single_block_map and double_block_map, so we must care about single/double + # print("lora_key before double_block_map", lora_key) for old, new in double_block_map.items(): - lora_key = lora_key.replace(old, new) - + if "double" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key before single_block_map", lora_key) for old, new in single_block_map.items(): - lora_key = lora_key.replace(old, new) + if "single" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key after mapping", lora_key) if lora_key in key_map: flux_key = key_map[lora_key] - if isinstance(flux_key, tuple): - flux_key = flux_key[0] logger.info(f"Found matching key: {flux_key}") return flux_key @@ -198,16 +194,11 @@ def find_matching_key(flux_dict, lora_key): lora_sd, _ = load_state_dict(model, merge_dtype) logger.info("merging...") - for key in tqdm(lora_sd.keys()): + for key in lora_sd.keys(): if "lora_down" in key or "lora_A" in key: - lora_name = key[ - : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") - ] + lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")] up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") - alpha_key = ( - key[: key.index("lora_down" if "lora_down" in key else "lora_A")] - + "alpha" - ) + alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha" logger.info(f"Processing LoRA key: {lora_name}") flux_key = find_matching_key(flux_state_dict, lora_name) @@ -231,20 +222,35 @@ def find_matching_key(flux_dict, lora_key): up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) + # print(up_weight.size(), down_weight.size(), weight.size()) + if lora_name.startswith("transformer."): - if "qkv" in flux_key: - hidden_size = weight.size(-1) // 3 + if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp update = ratio * (up_weight @ down_weight) * scale + # print(update.shape) if "img_attn" in flux_key or "txt_attn" in flux_key: - q, k, v = torch.chunk(weight, 3, dim=-1) + q, k, v = torch.chunk(weight, 3, dim=0) if "to_q" in lora_name or "add_q_proj" in lora_name: q += update.reshape(q.shape) elif "to_k" in lora_name or "add_k_proj" in lora_name: k += update.reshape(k.shape) elif "to_v" in lora_name or "add_v_proj" in lora_name: v += update.reshape(v.shape) - weight = torch.cat([q, k, v], dim=-1) + weight = torch.cat([q, k, v], dim=0) + elif "linear1" in flux_key: + q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0) + mlp = weight[int(update.shape[-1] * 3) :] + # print(q.shape, k.shape, v.shape, mlp.shape) + if "to_q" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name: + v += update.reshape(v.shape) + elif "proj_mlp" in lora_name: + mlp += update.reshape(mlp.shape) + weight = torch.cat([q, k, v, mlp], dim=0) else: if len(weight.size()) == 2: weight = weight + ratio * (up_weight @ down_weight) * scale @@ -252,18 +258,11 @@ def find_matching_key(flux_dict, lora_key): weight = ( weight + ratio - * ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale ) else: - conved = torch.nn.functional.conv2d( - down_weight.permute(1, 0, 2, 3), up_weight - ).permute(1, 0, 2, 3) + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale else: if len(weight.size()) == 2: @@ -272,18 +271,11 @@ def find_matching_key(flux_dict, lora_key): weight = ( weight + ratio - * ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale ) else: - conved = torch.nn.functional.conv2d( - down_weight.permute(1, 0, 2, 3), up_weight - ).permute(1, 0, 2, 3) + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) @@ -308,9 +300,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get( - train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None - ) + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # get alpha and dim alphas = {} # alpha for current model @@ -336,9 +326,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info( - f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" - ) + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge logger.info("merging...") @@ -359,19 +347,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = ( - abs(scale) if "lora_up" in key else scale - ) # マイナスの重みに対応する。 + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() - or concat_dim is not None + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat( - [merged_sd[key], lora_sd[key] * scale], dim=concat_dim - ) + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -390,9 +373,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info( - f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" - ) + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -411,16 +392,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata( - str(False), base_model, "networks.lora", dims, alphas, None - ) + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) return merged_sd, metadata def merge(args): - assert ( - len(args.models) == len(args.ratios) + assert len(args.models) == len( + args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): @@ -456,9 +435,7 @@ def str_to_dtype(p): if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from( - [args.flux_model] + args.models - ) + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( None, @@ -477,15 +454,11 @@ def str_to_dtype(p): save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models( - args.models, args.ratios, merge_dtype, args.concat, args.shuffle - ) + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( - state_dict, metadata - ) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash From 9381332020b7089a41eb8d041938f8ba417528d1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:32:26 +0900 Subject: [PATCH 060/163] revert merge function add add option to use new func --- README.md | 3 + networks/flux_merge_lora.py | 120 +++++++++++++++++++++++++++--------- 2 files changed, 94 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 165eed341..3f5c4daa5 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024 (update 2): +`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! + Aug 20, 2024: FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 1ba1f314d..fd9cc4e3a 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -4,6 +4,7 @@ import time import torch +from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -45,6 +46,81 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): + # create module map without loading state_dict + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + lora_name_to_module_key = {} + with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name not in lora_name_to_module_key: + logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + continue + + down_weight = lora_sd.pop(key) + up_weight = lora_sd.pop(up_key) + + dim = down_weight.size()[0] + alpha = lora_sd.pop(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + module_weight_key = lora_name_to_module_key[lora_name] + if module_weight_key not in flux_state_dict: + weight = flux_file.get_tensor(module_weight_key) + else: + weight = flux_state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + if len(lora_sd) > 0: + logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") + + return flux_state_dict + + +def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): logger.info(f"loading keys from FLUX.1 model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) @@ -422,15 +498,14 @@ def str_to_dtype(p): os.makedirs(dest_dir) if args.flux_model is not None: - state_dict = merge_to_flux_model( - args.loading_device, - args.working_device, - args.flux_model, - args.models, - args.ratios, - merge_dtype, - save_dtype, - ) + if not args.diffusers: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + else: + state_dict = merge_to_flux_model_diffusers( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) if args.no_metadata: sai_metadata = None @@ -438,16 +513,7 @@ def str_to_dtype(p): merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, - False, - False, - False, - False, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) logger.info(f"saving FLUX model to: {args.save_to}") @@ -466,16 +532,7 @@ def str_to_dtype(p): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, - False, - False, - False, - True, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) @@ -553,6 +610,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) + parser.add_argument( + "--diffusers", + action="store_true", + help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする", + ) return parser From dbed5126bd1133da832dae31ce73ba6c41afc9d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:33:47 +0900 Subject: [PATCH 061/163] chore: formatting --- networks/flux_merge_lora.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index fd9cc4e3a..d5e82920d 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -113,7 +113,7 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati del up_weight del down_weight del weight - + if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") @@ -587,12 +587,7 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument( - "--ratios", - type=float, - nargs="*", - help="ratios for each model / それぞれのLoRAモデルの比率", - ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( "--no_metadata", action="store_true", From 6ab48b09d8e46973d5e5fa47baeae3a464d06d04 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 21:39:43 +0900 Subject: [PATCH 062/163] feat: Support multi-resolution training with caching latents to disk --- README.md | 11 +++- library/strategy_base.py | 112 ++++++++++++++++++++++++++------------- library/strategy_flux.py | 11 +++- library/train_util.py | 2 +- 4 files changed, 93 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 3f5c4daa5..1d44c9e58 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024 (update 3): +__Experimental__ The multi-resolution training is now supported with caching latents to disk. + +The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). + +See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. + Aug 20, 2024 (update 2): `flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! Aug 20, 2024: FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). -The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. +The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. We will support multi-resolution caching to disk in the near future. @@ -171,7 +178,7 @@ The script can merge multiple LoRA models. If you want to merge multiple LoRA mo ### FLUX.1 Multi-resolution training -You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__ +You can define multiple resolutions in the dataset configuration file. The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. diff --git a/library/strategy_base.py b/library/strategy_base.py index a99a08290..e7d3a97ef 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -219,7 +219,13 @@ def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mas raise NotImplementedError def _default_is_disk_cached_latents_expected( - self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + self, + latents_stride: int, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + multi_resolution: bool = False, ): if not self.cache_to_disk: return False @@ -230,25 +236,17 @@ def _default_is_disk_cached_latents_expected( expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + # e.g. "_32x64", HxW + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + try: npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: + if "latents" + key_reso_suffix not in npz: + return False + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -257,7 +255,15 @@ def _default_is_disk_cached_latents_expected( # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( - self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. @@ -287,8 +293,13 @@ def _default_cache_batch_latents( original_size = original_sizes[i] crop_ltrb = crop_ltrbs[i] + latents_size = latents.shape[1:3] # H, W + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW + if self.cache_to_disk: - self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + self.save_latents_to_disk( + info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix + ) else: info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -298,31 +309,56 @@ def _default_cache_batch_latents( info.alpha_mask = alpha_mask def load_latents_from_disk( - self, npz_path: str + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + for SD/SDXL/SD3.0 + """ + return self._default_load_latents_from_disk(None, npz_path, bucket_reso) + + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( - self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + self, + npz_path, + latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + alpha_mask=None, + key_reso_suffix="", ): kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 3880a1e1b..5c620f3d6 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -200,7 +200,12 @@ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -208,7 +213,9 @@ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index f4ac8740a..8929c192f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1381,7 +1381,7 @@ def __getitem__(self, index): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( - self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) ) if flipped: latents = flipped_latents From 7e459c00b2e142e40a9452341934c2eb9f70a172 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 08:02:33 +0900 Subject: [PATCH 063/163] Update T5 attention mask handling in FLUX --- README.md | 3 +++ flux_minimal_inference.py | 33 +++++++++++++++++++----- flux_train.py | 6 ++++- flux_train_network.py | 13 +++++----- library/flux_models.py | 51 +++++++++++++++++++++---------------- library/flux_train_utils.py | 20 ++++++++++++--- library/strategy_flux.py | 25 ++++++++++-------- 7 files changed, 101 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 1d44c9e58..43edbbed6 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024: +The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. + Aug 20, 2024 (update 3): __Experimental__ The multi-resolution training is now supported with caching latents to disk. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index b09f63808..5b8aa2506 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -70,12 +70,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -92,6 +102,7 @@ def do_sample( txt_ids: torch.Tensor, num_steps: int, guidance: float, + t5_attn_mask: Optional[torch.Tensor], is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, @@ -101,10 +112,14 @@ def do_sample( # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) return x @@ -156,14 +171,14 @@ def generate_image( clip_l.to(clip_l_dtype) t5xxl.to(t5xxl_dtype) with accelerator.autocast(): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) else: with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) @@ -186,7 +201,11 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) + t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + ) if args.offload: model = model.cpu() # del model diff --git a/flux_train.py b/flux_train.py index 669963856..ecb8a1086 100644 --- a/flux_train.py +++ b/flux_train.py @@ -610,7 +610,10 @@ def optimizer_hook(parameter: torch.Tensor): guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) # call model - l_pooled, t5_out, txt_ids = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( @@ -621,6 +624,7 @@ def optimizer_hook(parameter: torch.Tensor): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # unpack latents diff --git a/flux_train_network.py b/flux_train_network.py index 002252c87..49bd270c7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -233,11 +233,11 @@ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.Fl self.flux_lower = flux_lower self.target_device = device - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) @@ -300,10 +300,9 @@ def get_noise_pred_and_target( guidance_vec.requires_grad_(True) # Predict the noise residual - l_pooled, t5_out, txt_ids = text_encoder_conds - # print( - # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" - # ) + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None if not args.split_mode: # normal forward @@ -317,6 +316,7 @@ def get_noise_pred_and_target( y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) else: # split forward to reduce memory usage @@ -337,6 +337,7 @@ def get_noise_pred_and_target( y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # move flux upper back to cpu, and then move flux lower to gpu diff --git a/library/flux_models.py b/library/flux_models.py index 11ef647ad..6f28da603 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -440,10 +440,10 @@ class ModelSpec: # region math -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x @@ -607,11 +607,7 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) - # self.gradient_checkpointing = False - - # def enable_gradient_checkpointing(self): - # self.gradient_checkpointing = True - + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) @@ -620,12 +616,6 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor: x = self.proj(x) return x - # def forward(self, *args, **kwargs): - # if self.training and self.gradient_checkpointing: - # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) - # else: - # return self._forward(*args, **kwargs) - @dataclass class ModulationOut: @@ -690,7 +680,9 @@ def disable_gradient_checkpointing(self): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -713,7 +705,18 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + attn_mask = txt_attention_mask # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img blocks @@ -725,10 +728,12 @@ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[T txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing def create_custom_forward(func): @@ -739,10 +744,10 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) else: - return self._forward(img, txt, vec, pe) + return self._forward(img, txt, vec, pe, txt_attention_mask) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -992,6 +997,7 @@ def forward( timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1011,7 +1017,7 @@ def forward( if not self.double_blocks_to_swap: for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.double_blocks_to_swap): @@ -1033,7 +1039,7 @@ def forward( block.to(self.device) # move to cuda # print(f"Moved double block {block_idx} to cuda.") - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1164,6 +1170,7 @@ def forward( timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1182,7 +1189,7 @@ def forward( pe = self.pe_embedder(ids) for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) return img, txt, vec, pe diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 3f9e8660f..1d3f80d72 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -190,9 +190,10 @@ def sample_image_inference( te_outputs = sample_prompts_te_outputs[prompt] else: tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - l_pooled, t5_out, txt_ids = te_outputs + l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -208,9 +209,10 @@ def sample_image_inference( ) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -289,12 +291,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--apply_t5_attn_mask", action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5c620f3d6..737af390a 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -64,22 +64,25 @@ def encode_tokens( l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None + # clip_l is None when using T5 only if clip_l is not None and l_tokens is not None: l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] else: l_pooled = None + # t5xxl is None when using CLIP only if t5xxl is not None and t5_tokens is not None: # t5_out is [b, max length, 4096] - t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) - if apply_t5_attn_mask: - t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -115,6 +118,8 @@ def is_disk_cached_outputs_expected(self, npz_path: str): return False if "txt_ids" not in npz: return False + if "t5_attn_mask" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -129,12 +134,12 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_attn_mask = data["t5_attn_mask"] t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -145,7 +150,7 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): # attn_mask is not applied when caching to disk: it is applied when loading from disk - l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) @@ -159,15 +164,15 @@ def cache_batch_outputs( l_pooled = l_pooled.cpu().numpy() t5_out = t5_out.cpu().numpy() txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() for i, info in enumerate(infos): l_pooled_i = l_pooled[i] t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] if self.cache_to_disk: - t5_attn_mask = tokens_and_masks[2] - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() np.savez( info.text_encoder_outputs_npz, l_pooled=l_pooled_i, @@ -176,7 +181,7 @@ def cache_batch_outputs( t5_attn_mask=t5_attn_mask_i, ) else: - info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) class FluxLatentsCachingStrategy(LatentsCachingStrategy): From e17c42cb0de8a1303a607ecc75af092dc12dc272 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 12:28:45 +0900 Subject: [PATCH 064/163] Add BFL/Diffusers LoRA converter #1467 #1458 #1483 --- networks/convert_flux_lora.py | 403 ++++++++++++++++++++++++++++++++++ 1 file changed, 403 insertions(+) create mode 100644 networks/convert_flux_lora.py diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py new file mode 100644 index 000000000..dd962ebfe --- /dev/null +++ b/networks/convert_flux_lora.py @@ -0,0 +1,403 @@ +# convert key mapping and data format from some LoRA format to another +""" +Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module +alpha is scalar for each LoRA module + +0 to 18 +lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4]) +lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4]) + +0 to 37 +lora_unet_single_blocks_0_linear1.alpha torch.Size([]) +lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4]) +lora_unet_single_blocks_0_linear2.alpha torch.Size([]) +lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360]) +lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4]) +lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([]) +lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4]) +""" +""" +ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules. +A is down, B is up. No alpha for each LoRA module. + +0 to 18 +transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16]) +transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16]) + +0 to 37 +transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16]) +transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16]) +transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360]) +transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16]) +""" +""" +xlabs: Unknown format. +0 to 18 +double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16]) +double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16]) +""" + + +import argparse +from safetensors.torch import save_file +from safetensors import safe_open +import torch + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key): + ait_down_key = ait_key + ".lora_A.weight" + if ait_down_key not in ait_sd: + return + ait_up_key = ait_key + ".lora_B.weight" + + down_weight = ait_sd.pop(ait_down_key) + sds_sd[sds_key + ".lora_down.weight"] = down_weight + sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key) + rank = down_weight.shape[0] + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device) + + +def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys): + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + if ait_down_keys[0] not in ait_sd: + return + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + down_weights = [ait_sd.pop(k) for k in ait_down_keys] + up_weights = [ait_sd.pop(k) for k in ait_up_keys] + + # lora_down is concatenated along dim=0, so rank is multiplied by the number of splits + rank = down_weights[0].shape[0] + num_splits = len(ait_keys) + sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0) + + merged_up_weights = torch.zeros( + (sum(w.shape[0] for w in up_weights), rank * num_splits), + dtype=up_weights[0].dtype, + device=up_weights[0].device, + ) + + i = 0 + for j, up_weight in enumerate(up_weights): + merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight + i += up_weight.shape[0] + + sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights + + # set alpha to new_rank + new_rank = rank * num_splits + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device) + + +def convert_ai_toolkit_to_sd_scripts(ait_sd): + sds_sd = {} + for i in range(19): + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(ait_sd) > 0: + logger.warning(f"Unsuppored keys for sd-scripts: {ait_sd.keys()}") + return sds_sd + + +def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + # print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}") + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + +def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + num_splits = len(ait_keys) + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + + # down_weight is copied to each split + ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + + # calculate dims if not provided + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # up_weight is split to each split + ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + + +def convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + return ait_sd + + +def main(args): + # load source safetensors + logger.info(f"Loading source file {args.src_path}") + state_dict = {} + with safe_open(args.src_path, framework="pt") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + logger.info(f"Converting {args.src} to {args.dst} format") + if args.src == "ai-toolkit" and args.dst == "sd-scripts": + state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) + elif args.src == "sd-scripts" and args.dst == "ai-toolkit": + state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + else: + raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") + + # save destination safetensors + logger.info(f"Saving destination file {args.dst_path}") + save_file(state_dict, args.dst_path, metadata=metadata) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA format") + parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts") + parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts") + parser.add_argument("--src_path", type=str, default=None, help="source path") + parser.add_argument("--dst_path", type=str, default=None, help="destination path") + args = parser.parse_args() + main(args) From 2b07a92c8d970a8538a47dd1bcad3122da4e195a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 12:30:23 +0900 Subject: [PATCH 065/163] Fix error in applying mask in Attention and add LoRA converter script --- README.md | 6 ++++++ library/flux_models.py | 5 +++-- networks/convert_flux_lora.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 43edbbed6..f4056851f 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 2): +Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. + +Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. + + Aug 21, 2024: The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. diff --git a/library/flux_models.py b/library/flux_models.py index 6f28da603..e38119cd7 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -708,9 +708,10 @@ def _forward( # make attention mask if not None attn_mask = None if txt_attention_mask is not None: - attn_mask = txt_attention_mask # b, seq_len + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len attn_mask = torch.cat( - (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 ) # b, seq_len + img_len # broadcast attn_mask to all heads diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index dd962ebfe..e9743534d 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -248,7 +248,7 @@ def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 scale_down = scale From e1cd19c0c0ef55709e8eb1e5babe25045f65031f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 21:04:10 +0900 Subject: [PATCH 066/163] add stochastic rounding, fix single block --- README.md | 19 ++++++-- flux_train.py | 95 ++++++++++++++++++++++++++++++++++---- library/adafactor_fused.py | 36 ++++++++++++++- library/flux_models.py | 1 + 4 files changed, 136 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index f4056851f..45349ba38 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 3): +- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ +- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is +based on the code provided by 2kpr. Thank you so much! + - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. + - Please note that `--fused_backward_pass` is only supported with Adafactor. +- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. +- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. + Aug 21, 2024 (update 2): Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. @@ -142,7 +151,7 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 ---blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing +--fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` (Combine the command into one line.) @@ -151,9 +160,13 @@ Sample image generation during training is not tested yet. Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. -`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. +`--full_bf16` enables the training with bf16 (weights and gradients). + +`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. + +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. diff --git a/flux_train.py b/flux_train.py index ecb8a1086..bcf4b9564 100644 --- a/flux_train.py +++ b/flux_train.py @@ -277,7 +277,10 @@ def train(args): training_models = [] params_to_optimize = [] training_models.append(flux) - params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate}) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] # calculate number of trainable parameters n_params = 0 @@ -433,17 +436,89 @@ def train(args): import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: - if parameter.requires_grad: - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + handled_double_block_indices = set() + handled_single_block_indices = set() - parameter.register_post_accumulate_grad_hook(__grad_hook) + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + grad_hook = None + + if double_blocks_to_swap: + if param_name.startswith("double_blocks"): + block_idx = int(param_name.split(".")[1]) + if ( + block_idx not in handled_double_block_indices + and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 + and block_idx < num_double_blocks - 1 + ): + # swap next (already backpropagated) block + handled_double_block_indices.add(block_idx) + block_idx_cpu = block_idx + 1 + block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) + + # create swap hook + def create_double_swap_grad_hook(bidx, bidx_cuda): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # swap blocks if necessary + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + + return __grad_hook + + grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) + if single_blocks_to_swap: + if param_name.startswith("single_blocks"): + block_idx = int(param_name.split(".")[1]) + if ( + block_idx not in handled_single_block_indices + and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 + and block_idx < num_single_blocks - 1 + ): + handled_single_block_indices.add(block_idx) + block_idx_cpu = block_idx + 1 + block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) + # print(param_name, block_idx_cpu, block_idx_cuda) + + # create swap hook + def create_single_swap_grad_hook(bidx, bidx_cuda): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # swap blocks if necessary + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return __grad_hook + + grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + + if grad_hook is None: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + grad_hook = __grad_hook + + parameter.register_post_accumulate_grad_hook(grad_hook) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index bdfc32ced..b5afa236b 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -2,6 +2,32 @@ import torch from transformers import Adafactor +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! + +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -48,7 +74,7 @@ def adafactor_step_param(self, p, group): lr = Adafactor._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad ** 2) + group["eps"][0] + update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] @@ -78,7 +104,12 @@ def adafactor_step_param(self, p, group): p_data_fp32.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: p.copy_(p_data_fp32) @@ -101,6 +132,7 @@ def adafactor_step(self, closure=None): return loss + def patch_adafactor_fused(optimizer: Adafactor): optimizer.step_param = adafactor_step_param.__get__(optimizer) optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/flux_models.py b/library/flux_models.py index e38119cd7..c98d52ec0 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1078,6 +1078,7 @@ def forward( if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) # print(f"Moved single block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = img[:, txt.shape[1] :, ...] From 98c91a762513bbce9ebce137da720a448a3da6c9 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 22 Aug 2024 12:37:41 +0900 Subject: [PATCH 067/163] Fix bug in FLUX multi GPU training --- README.md | 6 +++ flux_train.py | 29 ++++++------- flux_train_network.py | 10 +++-- library/flux_models.py | 6 ++- library/flux_utils.py | 40 ++++++++++++++---- library/strategy_flux.py | 4 +- library/train_util.py | 10 ++--- library/utils.py | 89 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 156 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 45349ba38..5125c6631 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024: +Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. + + Aug 21, 2024 (update 3): - There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ - Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is diff --git a/flux_train.py b/flux_train.py index bcf4b9564..e7d45e04d 100644 --- a/flux_train.py +++ b/flux_train.py @@ -174,7 +174,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -199,8 +199,8 @@ def train(args): strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) # load clip_l, t5xxl for caching text encoder outputs - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) clip_l.eval() t5xxl.eval() clip_l.requires_grad_(False) @@ -228,7 +228,6 @@ def train(args): if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() prompts = load_prompts(args.sample_prompts) @@ -238,9 +237,9 @@ def train(args): for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: if p not in sample_prompts_te_outputs: logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) + tokens_and_masks = flux_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) accelerator.wait_for_everyone() @@ -251,7 +250,9 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + flux = flux_utils.load_flow_model( + name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) if args.gradient_checkpointing: flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) @@ -419,7 +420,7 @@ def train(args): # if we doesn't swap blocks, we can move the model to device flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) if is_swapping_blocks: - flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -439,8 +440,8 @@ def train(args): double_blocks_to_swap = args.double_blocks_to_swap single_blocks_to_swap = args.single_blocks_to_swap - num_double_blocks = len(flux.double_blocks) - num_single_blocks = len(flux.single_blocks) + num_double_blocks = 19 # len(flux.double_blocks) + num_single_blocks = 38 # len(flux.single_blocks) handled_double_block_indices = set() handled_single_block_indices = set() @@ -537,8 +538,8 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): double_blocks_to_swap = args.double_blocks_to_swap single_blocks_to_swap = args.single_blocks_to_swap - num_double_blocks = len(flux.double_blocks) - num_single_blocks = len(flux.single_blocks) + num_double_blocks = 19 # len(flux.double_blocks) + num_single_blocks = 38 # len(flux.single_blocks) for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -618,7 +619,7 @@ def optimizer_hook(parameter: torch.Tensor): ) if is_swapping_blocks: - flux.prepare_block_swap_before_forward() + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -660,7 +661,7 @@ def optimizer_hook(parameter: torch.Tensor): with torch.no_grad(): input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] text_encoder_conds = text_encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) if args.full_fp16: text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] diff --git a/flux_train_network.py b/flux_train_network.py index 49bd270c7..3e2057e91 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -57,19 +57,21 @@ def load_target_model(self, args, weight_dtype, accelerator): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time - model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + model = flux_utils.load_flow_model( + name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model diff --git a/library/flux_models.py b/library/flux_models.py index c98d52ec0..c045aef6b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -745,7 +745,9 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) else: return self._forward(img, txt, vec, pe, txt_attention_mask) @@ -836,7 +838,7 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) else: return self._forward(x, vec, pe) diff --git a/library/flux_utils.py b/library/flux_utils.py index 166cd833b..37166933a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -9,7 +9,7 @@ from library import flux_models -from library.utils import setup_logging +from library.utils import setup_logging, MemoryEfficientSafeOpen setup_logging() import logging @@ -19,32 +19,54 @@ MODEL_VERSION_FLUX_V1 = "flux1" -def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: +# temporary copy from sd3_utils TODO refactor +def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + return load_file(path, device=device) + except: + return load_file(path) # prevent device invalid Error + + +def load_flow_model( + name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return model -def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder: +def load_ae( + name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.AutoEncoder: logger.info("Building AutoEncoder") with torch.device("meta"): ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae -def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel: +def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel: logger.info("Building CLIP") CLIPL_CONFIG = { "_name_or_path": "clip-vit-large-patch14/", @@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev clip = CLIPTextModel._from_config(config) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = clip.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded CLIP: {info}") return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel: +def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi t5xxl = T5EncoderModel._from_config(config) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = t5xxl.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded T5xxl: {info}") return t5xxl diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 737af390a..b3643cbfc 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -137,7 +137,7 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! return [l_pooled, t5_out, txt_ids, t5_attn_mask] @@ -149,7 +149,7 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk + # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) diff --git a/library/train_util.py b/library/train_util.py index 8929c192f..989758ad5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1104,10 +1104,6 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() batch_size = caching_strategy.batch_size or self.batch_size - # if cache to disk, don't cache TE outputs in non-main process - if caching_strategy.cache_to_disk and not is_main_process: - return - logger.info("caching Text Encoder outputs with caching strategy.") image_infos = list(self.image_data.values()) @@ -1120,9 +1116,9 @@ def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: boo # check disk cache exists and size of latents if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available: # do not add to batch + if cache_available or not is_main_process: # do not add to batch continue batch.append(info) @@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index 7de22d5a9..a16209979 100644 --- a/library/utils.py +++ b/library/utils.py @@ -153,6 +153,95 @@ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: v.contiguous().view(torch.uint8).numpy().tofile(f) +class MemoryEfficientSafeOpen: + # does not support metadata loading + def __init__(self, filename): + self.filename = filename + self.header, self.header_size = self._read_header() + self.file = open(filename, "rb") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def keys(self): + return [k for k in self.header.keys() if k != "__metadata__"] + + def get_tensor(self, key): + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + + if offset_start == offset_end: + tensor_bytes = None + else: + # adjust offset by header size + self.file.seek(self.header_size + 8 + offset_start) + tensor_bytes = self.file.read(offset_end - offset_start) + + return self._deserialize_tensor(tensor_bytes, metadata) + + def _read_header(self): + with open(self.filename, "rb") as f: + header_size = struct.unpack(" Date: Thu, 22 Aug 2024 19:55:31 +0900 Subject: [PATCH 068/163] Fix --debug_dataset to work. --- flux_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flux_train.py b/flux_train.py index e7d45e04d..410728d44 100644 --- a/flux_train.py +++ b/flux_train.py @@ -142,6 +142,12 @@ def train(args): args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return From 2d8fa3387a4adfdc2e36f2582e4ffc21864569f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:56:27 +0900 Subject: [PATCH 069/163] Fix to remove zero pad for t5xxl output --- README.md | 5 +++++ library/strategy_flux.py | 23 +++++++++++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5125c6631..33b3a9a99 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024 (update 2): +Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. + +Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. + Aug 22, 2024: Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. diff --git a/library/strategy_flux.py b/library/strategy_flux.py index b3643cbfc..d52b3b8dd 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -22,7 +22,7 @@ class FluxTokenizeStrategy(TokenizeStrategy): - def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: self.t5xxl_max_length = t5xxl_max_length self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) @@ -120,25 +120,24 @@ def is_disk_cached_outputs_expected(self, npz_path: str): return False if "t5_attn_mask" not in npz: return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] t5_attn_mask = data["t5_attn_mask"] - - if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! - + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( @@ -149,10 +148,8 @@ def cache_batch_outputs( tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading - l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk - ) + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) if l_pooled.dtype == torch.bfloat16: l_pooled = l_pooled.float() @@ -171,6 +168,7 @@ def cache_batch_outputs( t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask if self.cache_to_disk: np.savez( @@ -179,6 +177,7 @@ def cache_batch_outputs( t5_out=t5_out_i, txt_ids=txt_ids_i, t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) From b0a980844a2e02b1b1ae4cf615ae489dbf8ece67 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:57:29 +0900 Subject: [PATCH 070/163] added a script to extract LoRA --- networks/flux_extract_lora.py | 219 ++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 networks/flux_extract_lora.py diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py new file mode 100644 index 000000000..3ee6e816d --- /dev/null +++ b/networks/flux_extract_lora.py @@ -0,0 +1,219 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from library import flux_utils, sai_model_spec, model_util, sdxl_model_util +import lora +from library.utils import MemoryEfficientSafeOpen +from library.utils import setup_logging +from networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + store_device = "cpu" + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as fo: + # filter keys + keys = [] + for key in fo.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + + with open_fn(model_tuned) as ft: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = fo.get_tensor(key) + value_t = ft.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev") + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) From bf9f798985dd75fc2dd1fbc8c8dc775c92176854 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:59:38 +0900 Subject: [PATCH 071/163] chore: fix typos, remove debug print --- networks/flux_extract_lora.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py index 3ee6e816d..63ab2960c 100644 --- a/networks/flux_extract_lora.py +++ b/networks/flux_extract_lora.py @@ -68,10 +68,10 @@ def str_to_dtype(p): logger.info("Using memory efficient safe_open") open_fn = lambda fn: MemoryEfficientSafeOpen(fn) - with open_fn(model_org) as fo: + with open_fn(model_org) as f_org: # filter keys keys = [] - for key in fo.keys(): + for key in f_org.keys(): if not ("single_block" in key or "double_block" in key): continue if ".bias" in key: @@ -80,11 +80,11 @@ def str_to_dtype(p): continue keys.append(key) - with open_fn(model_tuned) as ft: + with open_fn(model_tuned) as f_tuned: for key in tqdm(keys): # get tensors and calculate difference - value_o = fo.get_tensor(key) - value_t = ft.get_tensor(key) + value_o = f_org.get_tensor(key) + value_t = f_tuned.get_tensor(key) mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) del value_o, value_t @@ -114,7 +114,7 @@ def str_to_dtype(p): U = U.to(store_device, dtype=save_dtype).contiguous() Vh = Vh.to(store_device, dtype=save_dtype).contiguous() - print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + # print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") lora_weights[key] = (U, Vh) del mat, U, S, Vh From afb971f9c36823040eaba3c9e02fdfa0928cd4ee Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 21:33:15 +0900 Subject: [PATCH 072/163] fix SD1.5 LoRA extraction #1490 --- networks/lora.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 82b8b5b47..6f33f1a1e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -815,7 +815,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh weights_sd = torch.load(file, map_location="cpu") # if keys are Diffusers based, convert to SAI based - convert_diffusers_to_sai_if_needed(weights_sd) + if is_sdxl: + convert_diffusers_to_sai_if_needed(weights_sd) # get dim/alpha mapping modules_dim = {} @@ -840,7 +841,13 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoder, + unet, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + is_sdxl=is_sdxl, ) # block lr From 1e8108fec9962333e4cf2a8db1dcedf657049900 Mon Sep 17 00:00:00 2001 From: liesen Date: Sat, 24 Aug 2024 01:38:17 +0300 Subject: [PATCH 073/163] Handle args.v_parameterization properly for MinSNR and changed prediction target --- sdxl_train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..14b259657 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -590,7 +590,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise if ( args.min_snr_gamma @@ -606,7 +610,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: From 2e89cd2cc634c27add7a04c21fcb6d0e16716a2b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 12:39:54 +0900 Subject: [PATCH 074/163] Fix issue with attention mask not being applied in single blocks --- README.md | 3 ++ flux_train_network.py | 4 +-- library/flux_models.py | 62 +++++++++++++++++++++--------------------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 33b3a9a99..4151bf44e 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024: +Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. + Aug 22, 2024 (update 2): Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. diff --git a/flux_train_network.py b/flux_train_network.py index 3e2057e91..82f77a77e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -243,7 +243,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) @@ -352,7 +352,7 @@ def get_noise_pred_and_target( intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) diff --git a/library/flux_models.py b/library/flux_models.py index c045aef6b..b5726c298 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -752,18 +752,6 @@ def custom_forward(*inputs): else: return self._forward(img, txt, vec, pe, txt_attention_mask) - # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint( - # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT - # ) - # else: - # return self._forward(img, txt, vec, pe) - class SingleStreamBlock(nn.Module): """ @@ -809,7 +797,7 @@ def disable_gradient_checkpointing(self): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -817,16 +805,35 @@ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + # compute attention - attn = attention(q, k, v, pe=pe) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing @@ -838,19 +845,11 @@ def custom_forward(*inputs): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) else: - return self._forward(x, vec, pe) - - # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) - # else: - # return self._forward(x, vec, pe) + return self._forward(x, vec, pe, txt_attention_mask) class LastLayer(nn.Module): @@ -1053,7 +1052,7 @@ def forward( if not self.single_blocks_to_swap: for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.single_blocks_to_swap): @@ -1075,7 +1074,7 @@ def forward( block.to(self.device) # move to cuda # print(f"Moved single block {block_idx} to cuda.") - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1250,10 +1249,11 @@ def forward( txt: Tensor, vec: Tensor | None = None, pe: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: img = torch.cat((txt, img), 1) for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) From cf689e7aa697877a0eee58622035ab702ce59d3e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:35:43 +0900 Subject: [PATCH 075/163] feat: Add option to split projection layers and apply LoRA --- README.md | 14 ++ networks/check_lora_weights.py | 2 +- networks/convert_flux_lora.py | 51 ++++-- networks/lora_flux.py | 326 +++++++++++++++++++++++++++------ 4 files changed, 325 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index 4151bf44e..7d326a867 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024 (update 2): + +__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + +The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. + +This implementation is experimental, so it may be deprecated or changed in the future. + +The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. + +Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + Aug 24, 2024: Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c94..b5b5e61ae 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index e9743534d..bd4c1cf78 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if sds_key + ".lora_down.weight" not in sds_sd: return down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] # scale weight by alpha and dim - rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha") - scale = alpha / rank + scale = alpha / sd_lora_rank # calculate scale_down and scale_up scale_down = scale @@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): scale_down *= 2 scale_up /= 2 - ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] - ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] - - num_splits = len(ait_keys) - up_weight = sds_sd.pop(sds_key + ".lora_up.weight") - - # down_weight is copied to each split - ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up # calculate dims if not provided + num_splits = len(ait_keys) if dims is None: dims = [up_weight.shape[0] // num_splits] * num_splits else: assert sum(dims) == up_weight.shape[0] - # up_weight is split to each split - ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] def convert_sd_scripts_to_ai_toolkit(sds_sd): diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 4da33542f..efc7847ed 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -39,6 +39,7 @@ def __init__( dropout=None, rank_dropout=None, module_dropout=None, + split_dims: Optional[List[int]] = None, ): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() @@ -52,16 +53,34 @@ def __init__( out_dim = org_module.out_features self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error @@ -70,9 +89,6 @@ def __init__( self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - self.multiplier = multiplier self.org_module = org_module # remove in applying self.dropout = dropout @@ -92,30 +108,56 @@ def forward(self, x): if torch.rand(1) < self.module_dropout: return org_forwarded - lx = self.lora_down(x) - - # normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask + lx = self.lora_up(lx) - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + return org_forwarded + lx * self.multiplier * scale else: - scale = self.scale + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale - lx = self.lora_up(lx) + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] - return org_forwarded + lx * self.multiplier * scale + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale class LoRAInfModule(LoRAModule): @@ -152,31 +194,50 @@ def merge_to(self, sd, dtype, device): if device is None: device = org_device - # get up/down weight - up_weight = sd["lora_up.weight"].to(torch.float).to(device) - down_weight = sd["lora_down.weight"].to(torch.float).to(device) - - # merge weight - if len(weight.size()) == 2: - # linear - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + self.multiplier * conved * self.scale + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - # set weight to org_module - org_sd["weight"] = weight.to(dtype) - self.org_module.load_state_dict(org_sd) + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) # 復元できるマージのため、このモジュールのweightを返す def get_weight(self, multiplier=None): @@ -211,7 +272,14 @@ def set_region(self, region): def default_forward(self, x): # logger.info(f"default_forward {self.lora_name} {x.size()}") - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale def forward(self, x): if not self.enabled: @@ -257,6 +325,11 @@ def create_network( if train_blocks is not None: assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -270,6 +343,7 @@ def create_network( conv_lora_dim=conv_dim, conv_alpha=conv_alpha, train_blocks=train_blocks, + split_qkv=split_qkv, varbose=True, ) @@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, ) return network, weights_sd @@ -344,6 +442,7 @@ def __init__( modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, + split_qkv: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -357,6 +456,7 @@ def __init__( self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -373,6 +473,8 @@ def __init__( logger.info( f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") # create module instances def create_modules( @@ -420,6 +522,14 @@ def create_modules( skipped.append(lora_name) continue + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + lora = module_class( lora_name, child_module, @@ -429,6 +539,7 @@ def create_modules( dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + split_dims=split_dims, ) loras.append(lora) return loras, skipped @@ -492,6 +603,111 @@ def load_weights(self, file): info = self.load_state_dict(weights_sd, False) return info + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to splitted qkv weight + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") From 5639c2adc0085e2e995bb3eee5a278aace397e7a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:37:49 +0900 Subject: [PATCH 076/163] fix typo --- networks/lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index efc7847ed..07a80f0bf 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -604,7 +604,7 @@ def load_weights(self, file): return info def load_state_dict(self, state_dict, strict=True): - # override to convert original weight to splitted qkv weight + # override to convert original weight to split qkv if not self.split_qkv: return super().load_state_dict(state_dict, strict) From d5c076cf9007f86f6dd1b9ecdfc5531336774b2f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 21:21:39 +0900 Subject: [PATCH 077/163] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 946df58f3..81a549378 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. From 72287d39c76176c0e1c16e8da4f5ddc6f94ea7d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 25 Aug 2024 16:01:24 +0900 Subject: [PATCH 078/163] feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training --- README.md | 4 ++++ library/flux_train_utils.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 282f3b3bd..562dcdb2a 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 25, 2024: +Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. +Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` + Aug 24, 2024 (update 2): __Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d3f80d72..75f70a54f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps( t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], + choices=["sigma", "uniform", "sigmoid", "shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", ) parser.add_argument( "--sigmoid_scale", From 0087a46e14c8e568982cbe3a5d9b9c561b175abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 19:59:40 +0900 Subject: [PATCH 079/163] FLUX.1 LoRA supports CLIP-L --- README.md | 8 ++++ flux_train_network.py | 40 +++++++++++++----- library/flux_train_utils.py | 8 ++-- library/strategy_flux.py | 3 +- networks/lora_flux.py | 4 +- train_network.py | 81 ++++++++++++++++++++++++------------- 6 files changed, 101 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 562dcdb2a..1203b5ebc 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024: + +- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. + - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. +- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. + +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). + Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` diff --git a/flux_train_network.py b/flux_train_network.py index 82f77a77e..1a40de61a 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -40,9 +40,13 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - assert ( - args.network_train_unet_only or not args.cache_text_encoder_outputs - ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + # assert ( + # args.network_train_unet_only or not args.cache_text_encoder_outputs + # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if not args.network_train_unet_only: + logger.info( + "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" + ) if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -137,12 +141,25 @@ def get_text_encoding_strategy(self, args): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): - return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + if args.cache_text_encoder_outputs: + if self.is_train_text_encoder(args): + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return text_encoders # ignored + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [True, False] if self.is_train_text_encoder(args) else [False, False] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + args.cache_text_encoder_outputs_to_disk, + None, + False, + is_partial=self.is_train_text_encoder(args), + apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: return None @@ -190,9 +207,11 @@ def cache_text_encoder_outputs_if_needed( accelerator.wait_for_everyone() # move back to cpu - logger.info("move text encoders back to cpu") - text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU - text_encoders[1].to("cpu") # , dtype=torch.float32) + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") clean_memory_on_device(accelerator.device) if not args.lowram: @@ -297,7 +316,8 @@ def get_noise_pred_and_target( if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - t.requires_grad_(True) + if t.dtype.is_floating_point: + t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) @@ -384,7 +404,7 @@ def update_metadata(self, metadata, args): metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def setup_parser() -> argparse.ArgumentParser: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 75f70a54f..a8e94ac00 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -58,7 +58,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -66,7 +66,8 @@ def sample_images( # unwrap unet and text_encoder(s) flux = accelerator.unwrap_model(flux) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = load_prompts(args.sample_prompts) @@ -134,7 +135,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, flux: flux_models.Flux, - text_encoders: List[CLIPTextModel], + text_encoders: Optional[List[CLIPTextModel]], ae: flux_models.AutoEncoder, save_dir, prompt_dict, @@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps( elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8dd..5d0839132 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -60,7 +60,7 @@ def encode_tokens( if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - clip_l, t5xxl = models + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None @@ -81,6 +81,7 @@ def encode_tokens( else: t5_out = None txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 07a80f0bf..fcb56a467 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -401,7 +401,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( # single_qkv_rank is not None and single_qkv_rank != rank # ) - split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index cab0ec52e..048c7e7bd 100644 --- a/train_network.py +++ b/train_network.py @@ -127,8 +127,15 @@ def get_text_encoder_outputs_caching_strategy(self, args): return None def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + """ return text_encoders + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + def is_train_text_encoder(self, args): return not args.network_train_unet_only @@ -136,11 +143,6 @@ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, tex for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -313,7 +315,7 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -437,8 +439,10 @@ def train(self, args): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect @@ -522,14 +526,17 @@ def train(self, args): unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory @@ -546,19 +553,18 @@ def train(self, args): t_enc.to(dtype=te_weight_dtype) if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,11 +577,14 @@ def train(self, args): else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set @@ -587,11 +596,11 @@ def train(self, args): if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: + if frag: t_enc.text_model.embeddings.requires_grad_(True) else: @@ -736,6 +745,7 @@ def load_model_hook(models, input_dir): "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, "ss_fp8_base": args.fp8_base, + "ss_fp8_base_unet": args.fp8_base_unet, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1004,6 +1014,7 @@ def remove_model(old_ckpt_name): for t_enc in text_encoders: del t_enc text_encoders = [] + text_encoder = None # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) @@ -1018,7 +1029,7 @@ def remove_model(old_ckpt_name): # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1073,12 +1084,17 @@ def remove_model(old_ckpt_name): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - else: + if ( + text_encoder_conds is None + or len(text_encoder_conds) == 0 + or text_encoder_conds[0] is None + or train_text_encoder + ): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: # SD only - text_encoder_conds = get_weighted_text_embeddings( + encoded_text_encoder_conds = get_weighted_text_embeddings( tokenizers[0], text_encoder, batch["captions"], @@ -1088,13 +1104,18 @@ def remove_model(old_ckpt_name): ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( @@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" From 3be712e3e011b0378fad389641cec0c1869555ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:40:02 +0900 Subject: [PATCH 080/163] feat: Update direct loading fp8 ckpt for LoRA training --- README.md | 7 +++- flux_minimal_inference.py | 27 +----------- flux_train_network.py | 16 +++++++- library/flux_utils.py | 12 ++++-- library/utils.py | 62 +++++++++++++++++++++++++++- networks/flux_merge_lora.py | 82 ++++++++++++++++++++++++++----------- 6 files changed, 151 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 1203b5ebc..0108ada59 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,18 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024 (update 2): +In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. + +In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. + Aug 27, 2024: - FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. - `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 5b8aa2506..56c1b1982 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -10,7 +10,6 @@ import numpy as np import torch -from safetensors.torch import safe_open, load_file from tqdm import tqdm from PIL import Image import accelerate @@ -21,7 +20,7 @@ init_ipex() -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype setup_logging() import logging @@ -288,28 +287,6 @@ def generate_image( name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way is_schnell = name == "schnell" - def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: - if s is None: - return default_dtype - if s in ["bf16", "bfloat16"]: - return torch.bfloat16 - elif s in ["fp16", "float16"]: - return torch.float16 - elif s in ["fp32", "float32"]: - return torch.float32 - elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: - return torch.float8_e4m3fn - elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: - return torch.float8_e4m3fnuz - elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: - return torch.float8_e5m2 - elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: - return torch.float8_e5m2fnuz - elif s in ["fp8", "float8"]: - return torch.float8_e4m3fn # default fp8 - else: - raise ValueError(f"Unsupported dtype: {s}") - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -348,7 +325,7 @@ def is_fp8(dt): encoding_strategy = strategy_flux.FluxTextEncodingStrategy() # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train_network.py b/flux_train_network.py index 1a40de61a..4a63c2de4 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -29,6 +29,9 @@ def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning( "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" @@ -61,9 +64,20 @@ def load_target_model(self, args, weight_dtype, accelerator): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time + if args.fp8_base: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) diff --git a/library/flux_utils.py b/library/flux_utils.py index 37166933a..680836168 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,5 +1,5 @@ import json -from typing import Union +from typing import Optional, Union import einops import torch @@ -20,7 +20,9 @@ # temporary copy from sd3_utils TODO refactor -def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +): if disable_mmap: # return safetensors.torch.load(open(path, "rb").read()) # use experimental loader @@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: def load_flow_model( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + model = flux_models.Flux(flux_models.configs[name].params) + if dtype is not None: + model = model.to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") diff --git a/library/utils.py b/library/utils.py index a16209979..d355cb109 100644 --- a/library/utils.py +++ b/library/utils.py @@ -82,6 +82,66 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): """ memory efficient save file @@ -198,7 +258,7 @@ def _deserialize_tensor(self, tensor_bytes, metadata): if tensor_bytes is None: byte_tensor = torch.empty(0, dtype=torch.uint8) else: - tensor_bytes = bytearray(tensor_bytes) # make it writable + tensor_bytes = bytearray(tensor_bytes) # make it writable byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8) # process float8 types diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index d5e82920d..2e0d4c297 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -8,7 +8,7 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() import logging @@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): +def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") - for key in list(state_dict.keys()): + for key in tqdm(list(state_dict.keys())): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") - save_file(state_dict, file_name, metadata=metadata) + if mem_eff_save: + mem_eff_save_file(state_dict, file_name, metadata=metadata) + else: + save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): # create module map without loading state_dict logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} @@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU @@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati return flux_state_dict -def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model_diffusers( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) def create_key_map(n_double_layers, n_single_layers): key_map = {} @@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): + if args.models is None: + args.models = [] + if args.ratios is None: + args.ratios = [] + assert len(args.models) == len( args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - merge_dtype = str_to_dtype(args.precision) save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -500,11 +516,25 @@ def str_to_dtype(p): if args.flux_model is not None: if not args.diffusers: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) else: state_dict = merge_to_flux_model_diffusers( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) if args.no_metadata: @@ -517,7 +547,7 @@ def str_to_dtype(p): ) logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) @@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser: "--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + help="precision in saving, same to merging if omitted. supported types: " + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", ) parser.add_argument( "--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( @@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) parser.add_argument( "--loading_device", type=str, From a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:44:10 +0900 Subject: [PATCH 081/163] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0108ada59..7b1d9cc6c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: Aug 27, 2024 (update 2): In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. -In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. +In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. Aug 27, 2024: From 6c0e8a5a1740dbd50a0a45ec1f08983877605cd7 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 14:50:29 +0800 Subject: [PATCH 082/163] make guidance_scale keep float in args --- flux_train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index 4a63c2de4..354a8c6f3 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -324,7 +324,8 @@ def get_noise_pred_and_target( img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # ensure the hidden state will require grad if args.gradient_checkpointing: From a0cfb0894c4be4ea27412e4c12ed13f68b57094b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 21:20:33 +0900 Subject: [PATCH 083/163] Cleaned up README --- README.md | 281 +++++++++++++++++++++++++++--------------------------- 1 file changed, 143 insertions(+), 138 deletions(-) diff --git a/README.md b/README.md index 7b1d9cc6c..a73eead0b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 LoRA training (WIP) +## FLUX.1 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,127 +9,24 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 27, 2024 (update 2): -In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. +- [FLUX.1 LoRA training](#flux1-lora-training) + - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) + - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 fine-tuning](#flux1-fine-tuning) + - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) +- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) +- [Convert FLUX LoRA](#convert-flux-lora) +- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) +- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) -In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. - -Aug 27, 2024: - -- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. -- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. - -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. - -Aug 25, 2024: -Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. -Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` - -Aug 24, 2024 (update 2): - -__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). - -The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. - -This implementation is experimental, so it may be deprecated or changed in the future. - -The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. - -Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. - -The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. - -Aug 24, 2024: -Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. - -Aug 22, 2024 (update 2): -Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. - -Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. - -Aug 22, 2024: -Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. - -`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. - - -Aug 21, 2024 (update 3): -- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ -- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is -based on the code provided by 2kpr. Thank you so much! - - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. - - Please note that `--fused_backward_pass` is only supported with Adafactor. -- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. -- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. - -Aug 21, 2024 (update 2): -Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. - -Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. - - -Aug 21, 2024: -The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. - -Aug 20, 2024 (update 3): -__Experimental__ The multi-resolution training is now supported with caching latents to disk. - -The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). - -See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -Aug 20, 2024 (update 2): -`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! - -Aug 20, 2024: -FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). - -The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -We will support multi-resolution caching to disk in the near future. - -Aug 19, 2024: -In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. - -An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. - -Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - -Aug 17, 2024: -Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. - -Aug 16, 2024: - -Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. - -Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. - -Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. - -Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. - -Aug 13, 2024: - -__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. - -This argument is available even if `--split_mode` is not specified. - -__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. - -This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. - -Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. - -Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. +### FLUX.1 LoRA training +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. -### FLUX.1 LoRA training +FLUX.1 model, CLIP-L, and T5XXL models are recommended to be in bf16/fp16 format. If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. +Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py @@ -137,45 +34,106 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 ---network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml ---output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid ---model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +--output_dir path/to/output/dir --output_name flux-lora-name +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ``` (The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -LoRAs for Text Encoders are not tested yet. +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +#### Key Options for FLUX.1 LoRA training -We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: +There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. -- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). +- `--timestep_sampling` is the method to sample timesteps (0-1): + - `sigma`: sigma-based, same as SD3 + - `uniform`: uniform random + - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. + - `shift`: shifts the value of sigmoid of normal distribution random number - `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. -- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). + - This option is effective even when`--timestep_sampling shift` is specified. + - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. +- `--model_prediction_type` is how to interpret and process the model prediction: + - `raw`: use as is, same as x-flux + - `additive`: add to noisy input + - `sigma_scaled`: apply sigma scaling, same as SD3 - `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). -`--loss_type` may be useful for FLUX.1 training. The default is `l2`. +The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. + +The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). Other settings may work better, so please try different settings. -We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. +Other options are described below. -The trained LoRA model can be used with ComfyUI. +#### Distribution of timesteps + +`--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. + +The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): + +The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: + +The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): + +#### Key Features for FLUX.1 LoRA training + +1. CLIP-L LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L LoRA. + - Remove `--network_train_unet_only` from your command. + - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - The trained LoRA can be used with ComfyUI. + - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. + - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them. + - Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + - May increase expressiveness but also training time. + - The trained model is compatible with normal LoRA models in sd-scripts and can be used in environments like ComfyUI. + - Converting to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. + +4. T5 Attention Mask Application: + - T5 attention mask is applied when `--apply_t5_attn_mask` is specified. + - Now applies mask when encoding T5 and in the attention of Double and Single Blocks + - Affects fine-tuning, LoRA training, and inference in `flux_minimal_inference.py`. + +5. Multi-resolution Training Support: + - FLUX.1 now supports multi-resolution training, even with caching latents to disk. + + +Technical details of Q/K/V split: + +In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + +### Inference for FLUX.1 with LoRA model The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. @@ -185,6 +143,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safete ### FLUX.1 fine-tuning +The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -195,15 +155,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" ---timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` +(The command is multi-line for readability. Please combine it into one line.) -(Combine the command into one line.) - -Sample image generation during training is not tested yet. - -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). @@ -223,6 +181,53 @@ Swap 6 double blocks and use cpu offload checkpointing may be a good starting po The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### Key Features for FLUX.1 fine-tuning + +1. Sample Image Generation: + - Sample image generation during training is now supported. + - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. + - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. + - Note: It will be very slow when `--split_mode` is specified. + +2. Experimental Memory-Efficient Saving: + - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). + - This is a custom implementation and may cause unexpected issues. Use with caution. + +3. T5XXL Token Length Control: + - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. + - Default is 512 in dev and 256 in schnell models. + +4. Multi-GPU Training Support: + - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +5. Disable mmap Load for Safetensors: + - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. + - Speeds up model loading during training in WSL2. + - Effective in reducing memory usage when loading models during multi-GPU training. + + +### Extract LoRA from FLUX.1 Models + +Script: `networks/flux_extract_lora.py` + +Extracts LoRA from the difference between two FLUX.1 models. + +Offers memory-efficient option with `--mem_eff_safe_open`. + +CLIP-L LoRA is not supported. + +### Convert FLUX LoRA + +Script: `convert_flux_lora.py` + +Converts LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). + +If you use LoRA in the inference environment, converting it to AI-toolkit format may reduce temporary memory usage. + +Note that re-conversion will increase the size of LoRA. + +CLIP-L LoRA is not supported. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ From daa6ad516581872aa6acaa15c0d24aad4f998838 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:25:30 +0900 Subject: [PATCH 084/163] Update README.md --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a73eead0b..6e2ae3376 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,9 @@ There are many unknown points in FLUX.1 training, so some settings can be specif The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. -~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.~~ -In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type) seems to work better. The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). @@ -92,10 +92,13 @@ Other options are described below. `--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): +![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6) -The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: +The difference between `--timestep_sampling sigmoid` and `--timestep_sampling uniform` (when `--timestep_sampling sigmoid` or `uniform` is specified, `--discrete_flow_shift` is ignored): +![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad) The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): +![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc) #### Key Features for FLUX.1 LoRA training From 8ecf0fc4bfd1b03cfc6fd4055af0b3363f5d1f38 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:10:57 +0900 Subject: [PATCH 085/163] Refactor code to ensure args.guidance_scale is always a float #1525 --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index 410728d44..32a36f036 100644 --- a/flux_train.py +++ b/flux_train.py @@ -688,8 +688,8 @@ def optimizer_hook(parameter: torch.Tensor): packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds From 8fdfd8c857a88aaa78ac9c2488432ef8115982f2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:26:29 +0900 Subject: [PATCH 086/163] Update safetensors to version 0.4.4 in requirements.txt #1524 --- README.md | 7 +++++++ requirements.txt | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6e2ae3376..30264e738 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +### Recent Updates + +Aug 29, 2024: +Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. + +### Contents + - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) diff --git a/requirements.txt b/requirements.txt index 4ee19b3ee..4c1bc3922 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard -safetensors==0.4.2 +safetensors==0.4.4 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 From 34f2315047f8d5b89b7a8a6093bb56679bff13c3 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 22:33:37 +0800 Subject: [PATCH 087/163] fix: text_encoder_conds referenced before assignment --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 048c7e7bd..628c421cb 100644 --- a/train_network.py +++ b/train_network.py @@ -1081,12 +1081,12 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) + text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if ( - text_encoder_conds is None - or len(text_encoder_conds) == 0 + len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder ): From 35882f8d5bbd076a97622cf6193c988621481803 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 23:03:43 +0800 Subject: [PATCH 088/163] fix --- train_network.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 628c421cb..4204bce34 100644 --- a/train_network.py +++ b/train_network.py @@ -1112,10 +1112,14 @@ def remove_model(old_ckpt_name): if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( From 2a3aefb4e44dce1f189677d0a996ba0244633956 Mon Sep 17 00:00:00 2001 From: Nando Metzger <42088121+nandometzger@users.noreply.github.com> Date: Fri, 30 Aug 2024 08:15:05 +0200 Subject: [PATCH 089/163] Update train_util.py, bug fix --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..0fec565db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1489,7 +1489,7 @@ def read_caption(img_path, caption_extension, enable_wildcard): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): logger.warning(f"not directory: {subset.image_dir}") - return [], [] + return [], [], [] info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info From 3a6154b7b0dbcae82d24adacf5a76f75288b98f4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 31 Aug 2024 06:21:16 +0000 Subject: [PATCH 090/163] Bump opencv-python from 4.7.0.68 to 4.8.1.78 Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78. - [Release notes](https://github.com/opencv/opencv-python/releases) - [Commits](https://github.com/opencv/opencv-python/commits) --- updated-dependencies: - dependency-name: opencv-python dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e99775b8a..977c5cd91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers==4.36.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.7.0.68 +opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.43.0 From 25c9040f4fbbcbddc0297895369337846152fea4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 31 Aug 2024 03:05:19 +0800 Subject: [PATCH 091/163] Update flux_train_utils.py --- library/flux_train_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index a8e94ac00..735bcced7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz = latents.shape[0] + bsz, _, H, W = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -392,6 +392,16 @@ def get_noisy_model_input_and_timesteps( timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + timesteps = time_shift(mu, 1.0, timesteps) + t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 noisy_model_input = (1 - t) * latents + t * noise @@ -571,7 +581,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid", "shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", From 1bcf8d600bfb9f4314a41a12a5e7b272a17ceaed Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:33:04 +0000 Subject: [PATCH 092/163] Bump crate-ci/typos from 1.19.0 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483f..0149dcdd3 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.24.3 From ef510b3cb94427d72df681389e1214251813b1a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 1 Sep 2024 17:41:01 +0800 Subject: [PATCH 093/163] Sd3 freeze x_block (#1417) * Update sd3_train.py * add freeze block lr * Update train_util.py * update --- library/train_util.py | 21 +++++++++++++++++++++ sd3_train.py | 9 ++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 989758ad5..74aae0a79 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="num_last_block_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5758,6 +5764,21 @@ def sample_image_inference( pass +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): + + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + # endregion diff --git a/sd3_train.py b/sd3_train.py index 3b6c8a118..ce9500b0b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -368,12 +368,19 @@ def train(args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) From 92e7600cc2fea604321004f260e7db76c764f388 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 18:57:07 +0900 Subject: [PATCH 094/163] Move freeze_blocks to sd3_train because it's only for sd3 --- README.md | 3 +++ library/train_util.py | 21 --------------------- sd3_train.py | 22 ++++++++++++++++++++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 30264e738..d96367194 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,9 @@ resolution = [512, 512] SD3 training is done with `sd3_train.py`. +__Sep 1, 2024__: +- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! + __Jul 27, 2024__: - Latents and text encoder outputs caching mechanism is refactored significantly. - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. diff --git a/library/train_util.py b/library/train_util.py index 74aae0a79..989758ad5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,12 +3246,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) - parser.add_argument( - "--num_last_block_to_freeze", - type=int, - default=None, - help="num_last_block_to_freeze", - ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5764,21 +5758,6 @@ def sample_image_inference( pass -def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): - - filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] - print(f"filtered_blocks: {len(filtered_blocks)}") - - num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) - - print(f"freeze_blocks: {num_blocks_to_freeze}") - - start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) - - for i in range(start_freezing_from, len(filtered_blocks)): - _, param = filtered_blocks[i] - param.requires_grad = False - # endregion diff --git a/sd3_train.py b/sd3_train.py index ce9500b0b..87011b215 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -373,7 +373,20 @@ def train(args): mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared if args.num_last_block_to_freeze: - train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + # freeze last n blocks of MM-DIT + block_name = "x_block" + filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name] + accelerator.print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze) + + accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False training_models = [] params_to_optimize = [] @@ -1033,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する", + ) return parser From 4f6d915d15262447b1049a78a55678b2825784a3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 19:12:29 +0900 Subject: [PATCH 095/163] update help and README --- README.md | 5 +++++ library/flux_train_utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d96367194..331951ef4 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 1, 2024: +- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! + - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. + Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. @@ -73,6 +77,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `uniform`: uniform random - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. - `shift`: shifts the value of sigmoid of normal distribution random number + - `flux_shift`: shifts the value of sigmoid of normal distribution random number, depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified. - `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. - This option is effective even when`--timestep_sampling shift` is specified. - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 735bcced7..9dad4baa2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, _, H, W = latents.shape + bsz, _, h, w = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -399,7 +399,7 @@ def get_noisy_model_input_and_timesteps( logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() - mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) t = timesteps.view(-1, 1, 1, 1) @@ -583,8 +583,8 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", ) parser.add_argument( "--sigmoid_scale", From 6abacf04da756808ffca567f6660445ecdf478bd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 2 Sep 2024 13:05:26 +0900 Subject: [PATCH 096/163] update README --- README.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 331951ef4..5dd916aa0 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. @@ -198,24 +198,32 @@ The learning rate and the number of epochs are not optimized yet. Please adjust #### Key Features for FLUX.1 fine-tuning -1. Sample Image Generation: +1. Technical details of double/single block swap: + - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. + - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. + - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. + - Since the transfer between CPU and GPU takes time, the training will be slower. + - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. + - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + +2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - Note: It will be very slow when `--split_mode` is specified. -2. Experimental Memory-Efficient Saving: +3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). - This is a custom implementation and may cause unexpected issues. Use with caution. -3. T5XXL Token Length Control: +4. T5XXL Token Length Control: - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. - Default is 512 in dev and 256 in schnell models. -4. Multi-GPU Training Support: +5. Multi-GPU Training Support: - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. -5. Disable mmap Load for Safetensors: +6. Disable mmap Load for Safetensors: - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. - Speeds up model loading during training in WSL2. - Effective in reducing memory usage when loading models during multi-GPU training. From b65ae9b439e4324359014d6d720aa01def3a19fc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:33:17 +0900 Subject: [PATCH 097/163] T5XXL LoRA training, fp8 T5XXL support --- README.md | 45 +++++++++++---- flux_train_network.py | 112 +++++++++++++++++++++++++++++------- library/flux_train_utils.py | 23 ++++++-- library/flux_utils.py | 9 ++- library/strategy_flux.py | 13 ++++- networks/lora_flux.py | 39 ++++++++++--- train_network.py | 48 ++++++++++------ 7 files changed, 222 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 5dd916aa0..840655705 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 4, 2024: +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. +- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. + Sep 1, 2024: - `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. @@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors ---ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base @@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI. There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. +- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`). + - `--timestep_sampling` is the method to sample timesteps (0-1): - `sigma`: sigma-based, same as SD3 - `uniform`: uniform random @@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times #### Key Features for FLUX.1 LoRA training -1. CLIP-L LoRA Support: - - FLUX.1 LoRA training now supports CLIP-L LoRA. +1. CLIP-L and T5XXL LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training. - Remove `--network_train_unet_only` from your command. - - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. + - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - The trained LoRA can be used with ComfyUI. - - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |FLUX.1|`--network_train_unet_only`|-|o| + |FLUX.1 + CLIP-L|-|-|o (*2)| + |FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L LoRA training. + - *3: Not tested yet. 2. Experimental FP8/FP16 mixed training: - - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. - - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL. + - FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16. - When specifying this option, the `--fp8_base` option is automatically enabled. 3. Split Q/K/V Projection Layers (Experimental): @@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 +python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` ### FLUX.1 fine-tuning @@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name @@ -256,7 +279,7 @@ CLIP-L LoRA is not supported. `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ ``` -python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu ``` You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. diff --git a/flux_train_network.py b/flux_train_network.py index 354a8c6f3..2fc0f3234 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -43,13 +43,9 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # assert ( - # args.network_train_unet_only or not args.cache_text_encoder_outputs - # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - if not args.network_train_unet_only: - logger.info( - "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" - ) + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -63,12 +59,10 @@ def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models name = self.get_flux_model_name(args) - # if we load to cpu, flux.to(fp8) takes a long time - if args.fp8_base: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future model = flux_utils.load_flow_model( name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) @@ -85,9 +79,21 @@ def load_target_model(self, args, weight_dtype, accelerator): clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) @@ -154,25 +160,35 @@ def get_latents_caching_strategy(self, args): def get_text_encoding_strategy(self, args): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: - if self.is_train_text_encoder(args): + if self.train_clip_l and not self.train_t5xxl: return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached else: - return text_encoders # ignored + return None # no text encoders are needed for encoding because both are cached else: return text_encoders # both CLIP-L and T5XXL are needed for encoding def get_text_encoders_train_flags(self, args, text_encoders): - return [True, False] if self.is_train_text_encoder(args) else [False, False] + return [self.train_clip_l, self.train_t5xxl] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, None, False, - is_partial=self.is_train_text_encoder(args), + is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: @@ -193,8 +209,16 @@ def cache_text_encoder_outputs_if_needed( # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) @@ -235,7 +259,7 @@ def cache_text_encoder_outputs_if_needed( else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -255,9 +279,12 @@ def cache_text_encoder_outputs_if_needed( # return noise_pred def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + if not args.split_mode: flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) return @@ -281,7 +308,7 @@ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_a wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) @@ -421,6 +448,47 @@ def update_metadata(self, metadata, args): def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9dad4baa2..0b5d4d90e 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -85,7 +85,7 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -187,14 +187,27 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_conds = [] if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] - l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds # sample image weight_dtype = ae.dtype # TOFO give dtype as argument diff --git a/library/flux_utils.py b/library/flux_utils.py index 680836168..7b0a41a8a 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: +def load_t5xxl( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi return t5xxl +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5d0839132..6c9ef5e4a 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -5,8 +5,7 @@ import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast -from library import sd3_utils, train_util -from library import sd3_models +from library import flux_utils, train_util from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy from library.utils import setup_logging @@ -100,6 +99,8 @@ def __init__( super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) self.apply_t5_attn_mask = apply_t5_attn_mask + self.warn_fp8_weights = False + def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -144,6 +145,14 @@ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + if not self.warn_fp8_weights: + if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。" + ) + self.warn_fp8_weights = True + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] diff --git a/networks/lora_flux.py b/networks/lora_flux.py index fcb56a467..295267beb 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -330,6 +330,11 @@ def create_network( if split_qkv is not None: split_qkv = True if split_qkv == "True" else False + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -344,6 +349,7 @@ def create_network( conv_alpha=conv_alpha, train_blocks=train_blocks, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, varbose=True, ) @@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh else: weights_sd = torch.load(file, map_location="cpu") - # get dim/alpha mapping + # get dim/alpha mapping, and train t5xxl modules_dim = {} modules_alpha = {} + train_t5xxl = None for key, value in weights_sd.items(): if "." not in key: continue @@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + if train_t5xxl is None: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + # # split qkv # double_qkv_rank = None # single_qkv_rank = None @@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_alpha=modules_alpha, module_class=module_class, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, ) return network, weights_sd @@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" - LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible def __init__( self, @@ -443,6 +457,7 @@ def __init__( modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, split_qkv: bool = False, + train_t5xxl: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -457,6 +472,7 @@ def __init__( self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -469,12 +485,16 @@ def __init__( logger.info( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) - if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) if self.split_qkv: logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") # create module instances def create_modules( @@ -550,12 +570,15 @@ def create_modules( skipped_te = [] for i, text_encoder in enumerate(text_encoders): index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + logger.info(f"create LoRA for Text Encoder {index+1}:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # create LoRA for U-Net if self.train_blocks == "all": diff --git a/train_network.py b/train_network.py index 4204bce34..a68ccfcc4 100644 --- a/train_network.py +++ b/train_network.py @@ -157,6 +157,9 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke # region SD/SDXL + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False @@ -237,6 +240,13 @@ def update_metadata(self, metadata, args): def is_text_encoder_not_needed_for_training(self, args): return False # use for sample images + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + # endregion def train(self, args): @@ -329,7 +339,7 @@ def train(self, args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -428,12 +438,15 @@ def train(self, args): ) args.scale_weight_norms = False + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: - # FIXME consider alpha of weights + # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -533,7 +546,7 @@ def train(self, args): ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - + if not args.fp8_base_unet: accelerator.print("enable fp8 training for Text Encoder.") te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn @@ -545,17 +558,16 @@ def train(self, args): unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: + for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) - elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + + # nn.Embedding not support FP8 + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -596,12 +608,12 @@ def train(self, args): if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works if frag: - t_enc.text_model.embeddings.requires_grad_(True) + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() @@ -1028,8 +1040,12 @@ def remove_model(old_ckpt_name): # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") - for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1085,11 +1101,7 @@ def remove_model(old_ckpt_name): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - if ( - len(text_encoder_conds) == 0 - or text_encoder_conds[0] is None - or train_text_encoder - ): + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: From b7cff0a7548e5e33f735f06293ba24119fdaa585 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:35:47 +0900 Subject: [PATCH 098/163] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 840655705..c0acfa1d2 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. - Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. From 56cb2fc885d818e9c4493fb2843870d7a141db1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 23:15:27 +0900 Subject: [PATCH 099/163] support T5XXL LoRA, reduce peak memory usage #1560 --- flux_minimal_inference.py | 73 +++++++++++++++++++++++++++++++-------- networks/lora_flux.py | 2 +- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 56c1b1982..1c194e7c1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import math import os import random -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional import einops import numpy as np @@ -13,6 +13,7 @@ from tqdm import tqdm from PIL import Image import accelerate +from transformers import CLIPTextModel from library import device_utils from library.device_utils import init_ipex, get_preferred_device @@ -125,7 +126,7 @@ def do_sample( def generate_image( model, - clip_l, + clip_l: CLIPTextModel, t5xxl, ae, prompt: str, @@ -141,12 +142,13 @@ def generate_image( # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise_dtype = torch.float32 if is_fp8(dtype) else dtype noise = torch.randn( 1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, - dtype=dtype, + dtype=noise_dtype, generator=torch.Generator(device=device).manual_seed(seed), ) @@ -166,9 +168,48 @@ def generate_image( clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) with torch.no_grad(): - if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): - clip_l.to(clip_l_dtype) - t5xxl.to(t5xxl_dtype) + if is_fp8(clip_l_dtype): + param_itr = clip_l.parameters() + param_itr.__next__() # skip first + param_2nd = param_itr.__next__() + if param_2nd.dtype != clip_l_dtype: + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + text_encoder.fp8_prepared = True + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + with accelerator.autocast(): _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask @@ -315,10 +356,10 @@ def is_fp8(dt): t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) t5xxl.eval() - if is_fp8(clip_l_dtype): - clip_l = accelerator.prepare(clip_l) - if is_fp8(t5xxl_dtype): - t5xxl = accelerator.prepare(t5xxl) + # if is_fp8(clip_l_dtype): + # clip_l = accelerator.prepare(clip_l) + # if is_fp8(t5xxl_dtype): + # t5xxl = accelerator.prepare(t5xxl) t5xxl_max_length = 256 if is_schnell else 512 tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) @@ -329,14 +370,16 @@ def is_fp8(dt): model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype - if is_fp8(flux_dtype): - model = accelerator.prepare(model) + # if is_fp8(flux_dtype): + # model = accelerator.prepare(model) + # if args.offload: + # model = model.to("cpu") # AE ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) ae.eval() - if is_fp8(ae_dtype): - ae = accelerator.prepare(ae) + # if is_fp8(ae_dtype): + # ae = accelerator.prepare(ae) # LoRA lora_models: List[lora_flux.LoRANetwork] = [] @@ -360,7 +403,7 @@ def is_fp8(dt): lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 295267beb..ab9ccc4d8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) - if train_t5xxl is None: + if train_t5xxl is None or train_t5xxl is False: train_t5xxl = "lora_te3" in lora_name if train_t5xxl is None: From 90ed2dfb526168b2e77b8d367e928d8cc44b4278 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 08:39:29 +0900 Subject: [PATCH 100/163] feat: Add support for merging CLIP-L and T5XXL LoRA models --- README.md | 22 ++++- networks/flux_merge_lora.py | 182 ++++++++++++++++++++++++++++-------- 2 files changed, 163 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index c0acfa1d2..fa81f6c0f 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024: +The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + Sep 4, 2024: - T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. @@ -276,7 +279,7 @@ CLIP-L LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint -`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint, CLIP-L or T5XXL models. __The script is experimental.__ ``` python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu @@ -284,13 +287,24 @@ python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. -`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): +CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, `--t5xxl` and `--t5xxl_save_to` are for T5XXL. Sample command is below. + +``` +--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors +``` + +FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency. + +An experimental option `--mem_eff_load_save` is available. This option is for memory-efficient loading and saving. It may also speed up loading and saving. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`, `float32` will consume more memory): - 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. - 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. -- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. +- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. -In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. +`--save_precision` is the precision to save the merged model. In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 2e0d4c297..5e100a3ba 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -2,6 +2,7 @@ import math import os import time +from typing import Any, Dict, Union import torch from safetensors import safe_open @@ -34,11 +35,11 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): +def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") for key in tqdm(list(state_dict.keys())): - if type(state_dict[key]) == torch.Tensor: + if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") @@ -49,26 +50,76 @@ def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False + loading_device, + working_device, + flux_path: str, + clip_l_path: str, + t5xxl_path: str, + models, + ratios, + merge_dtype, + save_dtype, + mem_eff_load_save=False, ): # create module map without loading state_dict - logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key - + if flux_path is not None: + logger.info(f"loading keys from FLUX.1 model: {flux_path}") + with safe_open(flux_path, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + lora_name_to_clip_l_key = {} + if clip_l_path is not None: + logger.info(f"loading keys from clip_l model: {clip_l_path}") + with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file: + keys = list(clip_l_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_") + lora_name_to_clip_l_key[lora_name] = key + + lora_name_to_t5xxl_key = {} + if t5xxl_path is not None: + logger.info(f"loading keys from t5xxl model: {t5xxl_path}") + with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file: + keys = list(t5xxl_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_") + lora_name_to_t5xxl_key[lora_name] = key + + flux_state_dict = {} + clip_l_state_dict = {} + t5xxl_state_dict = {} if mem_eff_load_save: - flux_state_dict = {} - with MemoryEfficientSafeOpen(flux_model) as flux_file: - for key in tqdm(flux_file.keys()): - flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + if flux_path is not None: + with MemoryEfficientSafeOpen(flux_path) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + + if clip_l_path is not None: + with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file: + for key in tqdm(clip_l_file.keys()): + clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device) + + if t5xxl_path is not None: + with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file: + for key in tqdm(t5xxl_file.keys()): + t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device) else: - flux_state_dict = load_file(flux_model, device=loading_device) + if flux_path is not None: + flux_state_dict = load_file(flux_path, device=loading_device) + if clip_l_path is not None: + clip_l_state_dict = load_file(clip_l_path, device=loading_device) + if t5xxl_path is not None: + t5xxl_state_dict = load_file(t5xxl_path, device=loading_device) for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -81,8 +132,20 @@ def merge_to_flux_model( up_key = key.replace("lora_down", "lora_up") alpha_key = key[: key.index("lora_down")] + "alpha" - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + if lora_name in lora_name_to_module_key: + module_weight_key = lora_name_to_module_key[lora_name] + state_dict = flux_state_dict + elif lora_name in lora_name_to_clip_l_key: + module_weight_key = lora_name_to_clip_l_key[lora_name] + state_dict = clip_l_state_dict + elif lora_name in lora_name_to_t5xxl_key: + module_weight_key = lora_name_to_t5xxl_key[lora_name] + state_dict = t5xxl_state_dict + else: + logger.warning( + f"no module found for LoRA weight: {key}. Skipping..." + f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。" + ) continue down_weight = lora_sd.pop(key) @@ -93,11 +156,7 @@ def merge_to_flux_model( scale = alpha / dim # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = state_dict[module_weight_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) @@ -121,7 +180,7 @@ def merge_to_flux_model( # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + state_dict[module_weight_key] = weight.to(loading_device, save_dtype) del up_weight del down_weight del weight @@ -129,7 +188,7 @@ def merge_to_flux_model( if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") - return flux_state_dict + return flux_state_dict, clip_l_state_dict, t5xxl_state_dict def merge_to_flux_model_diffusers( @@ -508,17 +567,28 @@ def merge(args): if save_dtype is None: save_dtype = merge_dtype - dest_dir = os.path.dirname(args.save_to) + assert ( + args.save_to or args.clip_l_save_to or args.t5xxl_save_to + ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください" + dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to) if not os.path.exists(dest_dir): logger.info(f"creating directory: {dest_dir}") os.makedirs(dest_dir) - if args.flux_model is not None: + if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None: if not args.diffusers: - state_dict = merge_to_flux_model( + assert (args.clip_l is None and args.clip_l_save_to is None) or ( + args.clip_l is not None and args.clip_l_save_to is not None + ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください" + assert (args.t5xxl is None and args.t5xxl_save_to is None) or ( + args.t5xxl is not None and args.t5xxl_save_to is not None + ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください" + flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model( args.loading_device, args.working_device, args.flux_model, + args.clip_l, + args.t5xxl, args.models, args.ratios, merge_dtype, @@ -526,7 +596,10 @@ def merge(args): args.mem_eff_load_save, ) else: - state_dict = merge_to_flux_model_diffusers( + assert ( + args.clip_l is None and args.t5xxl is None + ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません" + flux_state_dict = merge_to_flux_model_diffusers( args.loading_device, args.working_device, args.flux_model, @@ -536,8 +609,10 @@ def merge(args): save_dtype, args.mem_eff_load_save, ) + clip_l_state_dict = None + t5xxl_state_dict = None - if args.no_metadata: + if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0): sai_metadata = None else: merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) @@ -546,15 +621,24 @@ def merge(args): None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) - logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + if flux_state_dict is not None and len(flux_state_dict) > 0: + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + + if clip_l_state_dict is not None and len(clip_l_state_dict) > 0: + logger.info(f"saving clip_l model to: {args.clip_l_save_to}") + save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save) + + if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0: + logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}") + save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -562,12 +646,12 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, flux_state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -592,6 +676,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--clip_l", + type=str, + default=None, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl", + type=str, + default=None, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)", + ) parser.add_argument( "--mem_eff_load_save", action="store_true", @@ -617,6 +713,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", ) + parser.add_argument( + "--clip_l_save_to", + type=str, + default=None, + help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--t5xxl_save_to", + type=str, + default=None, + help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル", + ) parser.add_argument( "--models", type=str, From d9129522a6effea7077f18cdea0ee733a5ac7cb0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 12:20:07 +0900 Subject: [PATCH 101/163] set dtype before calling ae closes #1562 --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 32a36f036..0293b7be3 100644 --- a/flux_train.py +++ b/flux_train.py @@ -651,7 +651,7 @@ def optimizer_hook(parameter: torch.Tensor): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"]) + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): From 2889108d858880589d362e06e98eeadf4682476a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 20:58:33 +0900 Subject: [PATCH 102/163] feat: Add --cpu_offload_checkpointing option to LoRA training --- README.md | 7 +++++++ flux_train.py | 2 +- flux_train_network.py | 5 +++++ train_network.py | 12 +++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fa81f6c0f..e8a12089f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024 (update 1): + +Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + Sep 5, 2024: + The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. Sep 4, 2024: @@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_ --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. The trained LoRA model can be used with ComfyUI. diff --git a/flux_train.py b/flux_train.py index 0293b7be3..0edc83a9f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -261,7 +261,7 @@ def train(args): ) if args.gradient_checkpointing: - flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) flux.requires_grad_(True) diff --git a/flux_train_network.py b/flux_train_network.py index 2fc0f3234..a6e57eede 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -50,6 +50,11 @@ def assert_extra_args(self, args, train_dataset_group): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert not args.split_mode or not args.cpu_offload_checkpointing, ( + "split_mode and cpu_offload_checkpointing cannot be used together" + " / split_modeとcpu_offload_checkpointingは同時に使用できません" + ) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def get_flux_model_name(self, args): diff --git a/train_network.py b/train_network.py index a68ccfcc4..ad97491df 100644 --- a/train_network.py +++ b/train_network.py @@ -451,7 +451,11 @@ def train(self, args): accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): if flag: if t_enc.supports_gradient_checkpointing: @@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" ) From 0005867ba509d2e1a5674b267e8286b561c0ed71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Sep 2024 10:45:18 +0900 Subject: [PATCH 103/163] update README, format code --- README.md | 5 +++++ library/train_util.py | 4 ++-- library/utils.py | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81a549378..16ab80e7a 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! + +- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! + - `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! + - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. diff --git a/library/train_util.py b/library/train_util.py index 102d39ed7..1441e74f6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2094,7 +2094,7 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2432,7 +2432,7 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index a219f6cb7..5b7e657b2 100644 --- a/library/utils.py +++ b/library/utils.py @@ -11,6 +11,7 @@ from PIL import Image import numpy as np + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -80,8 +81,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation=Image.LANCZOS): pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # use Pillow resize @@ -92,6 +93,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 + # TODO make inf_utils.py From d29af146b8d4c4d028f8752657bd1349c8cd3509 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Sep 2024 23:01:15 +0900 Subject: [PATCH 104/163] add negative prompt for flux inference script --- README.md | 3 + flux_minimal_inference.py | 289 ++++++++++++++++++++++++++------------ 2 files changed, 206 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 2f010f499..126516f95 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 9, 2024: +Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. + Sep 5, 2024 (update 1): Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 1c194e7c1..de607c52a 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -71,22 +71,57 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + neg_txt: Optional[torch.Tensor] = None, + neg_vec: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): # this is ignored for schnell + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + # prepare classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img_ids = torch.cat([img_ids, img_ids], dim=0) + b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) + b_txt = torch.cat([neg_txt, txt], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) + if t5_attn_mask is not None and neg_t5_attn_mask is not None: + b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + else: + b_t5_attn_mask = None + else: + b_img_ids = img_ids + b_txt_ids = txt_ids + b_txt = txt + b_vec = vec + b_t5_attn_mask = t5_attn_mask + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img = torch.cat([img, img], dim=0) + else: + b_img = img + pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + img=b_img, + img_ids=b_img_ids, + txt=b_txt, + txt_ids=b_txt_ids, + y=b_vec, timesteps=t_vec, guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + txt_attention_mask=b_t5_attn_mask, ) + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + img = img + (t_prev - t_curr) * pred return img @@ -106,19 +141,48 @@ def do_sample( is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, + neg_l_pooled: Optional[torch.Tensor] = None, + neg_t5_out: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): + logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) return x @@ -135,6 +199,8 @@ def generate_image( image_height: int, steps: Optional[int], guidance: float, + negative_prompt: Optional[str], + cfg_scale: float, ): seed = seed if seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {seed}") @@ -162,65 +228,73 @@ def generate_image( # txt2img only needs img_ids img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + # prepare fp8 models + if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + clip_l.fp8_prepared = True + + if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + t5xxl.fp8_prepared = True + # prepare embeddings logger.info("Encoding prompts...") - tokens_and_masks = tokenize_strategy.tokenize(prompt) clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) - with torch.no_grad(): - if is_fp8(clip_l_dtype): - param_itr = clip_l.parameters() - param_itr.__next__() # skip first - param_2nd = param_itr.__next__() - if param_2nd.dtype != clip_l_dtype: - logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") - clip_l.to(clip_l_dtype) # fp8 - clip_l.text_model.embeddings.to(dtype=torch.bfloat16) - - with accelerator.autocast(): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - if is_fp8(t5xxl_dtype): - if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): - logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") - - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - text_encoder.fp8_prepared = True - - t5xxl.to(t5xxl_dtype) - prepare_fp8(t5xxl.encoder, torch.bfloat16) - - with accelerator.autocast(): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) - else: - with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) + def encode(prpt: str): + tokens_and_masks = tokenize_strategy.tokenize(prpt) + with torch.no_grad(): + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + with accelerator.autocast(): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + return l_pooled, t5_out, txt_ids, t5_attn_mask + + l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) + if negative_prompt: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt) + else: + neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None # NaN check if torch.isnan(l_pooled).any(): @@ -244,7 +318,23 @@ def forward(hidden_states): t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + accelerator, + model, + noise, + img_ids, + l_pooled, + t5_out, + txt_ids, + steps, + guidance, + t5_attn_mask, + is_schnell, + device, + flux_dtype, + neg_l_pooled, + neg_t5_out, + neg_t5_attn_mask, + cfg_scale, ) if args.offload: model = model.cpu() @@ -307,6 +397,8 @@ def forward(hidden_states): parser.add_argument("--seed", type=int, default=None) parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--negative_prompt", type=str, default=None) + parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument( "--lora_weights", @@ -403,19 +495,34 @@ def is_fp8(dt): lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: - generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + generate_image( + model, + clip_l, + t5xxl, + ae, + args.prompt, + args.seed, + args.width, + args.height, + args.steps, + args.guidance, + args.negative_prompt, + args.cfg_scale, + ) else: # loop for interactive width = target_width height = target_height steps = None guidance = args.guidance + cfg_scale = args.cfg_scale while True: print( "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + " --n , `-` for empty negative prompt --c " ) prompt = input() if prompt == "": @@ -425,26 +532,36 @@ def is_fp8(dt): options = prompt.split("--") prompt = options[0].strip() seed = None + negative_prompt = None for opt in options[1:]: - opt = opt.strip() - if opt.startswith("w"): - width = int(opt[1:].strip()) - elif opt.startswith("h"): - height = int(opt[1:].strip()) - elif opt.startswith("s"): - steps = int(opt[1:].strip()) - elif opt.startswith("d"): - seed = int(opt[1:].strip()) - elif opt.startswith("g"): - guidance = float(opt[1:].strip()) - elif opt.startswith("m"): - mutipliers = opt[1:].strip().split(",") - if len(mutipliers) != len(lora_models): - logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - continue - for i, lora_model in enumerate(lora_models): - lora_model.set_multiplier(float(mutipliers[i])) - - generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) logger.info("Done!") From d10ff62a78b15d0bb55f443cc2849c460300131b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 20:32:09 +0900 Subject: [PATCH 105/163] support individual LR for CLIP-L/T5XXL --- README.md | 4 +++ networks/lora_flux.py | 71 +++++++++++++++---------------------------- train_network.py | 32 ++++++++++++------- 3 files changed, 49 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 126516f95..b5799dd6f 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 10, 2024: +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. + Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ab9ccc4d8..d540c2215 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -786,28 +786,23 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) - # if ( - # self.loraplus_lr_ratio is not None - # or self.loraplus_text_encoder_lr_ratio is not None - # or self.loraplus_unet_lr_ratio is not None - # ): - # assert ( - # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() - # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + if text_encoder_lr is None or len(text_encoder_lr) == 0: + text_encoder_lr = [default_lr, default_lr] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] self.requires_grad_(True) all_params = [] lr_descriptions = [] - def assemble_params(loras, lr, ratio): + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: + if loraplus_ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param @@ -822,7 +817,7 @@ def assemble_params(loras, lr, ratio): if lr is not None: if key == "plus": - param_data["lr"] = lr * ratio + param_data["lr"] = lr * loraplus_ratio else: param_data["lr"] = lr @@ -836,41 +831,23 @@ def assemble_params(loras, lr, ratio): return params, descriptions if self.text_encoder_loras: - params, descriptions = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, - ) - all_params.extend(params) - lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) if self.unet_loras: - # if self.block_lr: - # is_sdxl = False - # for lora in self.unet_loras: - # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: - # is_sdxl = True - # break - - # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - # block_idx_to_lora = {} - # for lora in self.unet_loras: - # idx = get_block_index(lora.lora_name, is_sdxl) - # if idx not in block_idx_to_lora: - # block_idx_to_lora[idx] = [] - # block_idx_to_lora[idx].append(lora) - - # # blockごとにパラメータを設定する - # for idx, block_loras in block_idx_to_lora.items(): - # params, descriptions = assemble_params( - # block_loras, - # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), - # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, - # ) - # all_params.extend(params) - # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) - - # else: params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, diff --git a/train_network.py b/train_network.py index ad97491df..e45db0525 100644 --- a/train_network.py +++ b/train_network.py @@ -466,9 +466,17 @@ def train(self, args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] @@ -476,11 +484,7 @@ def train(self, args): trainable_params = results lr_descriptions = None except TypeError as e: - # logger.warning(f"{e}") - # accelerator.print( - # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - # ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: @@ -713,7 +717,7 @@ def load_model_hook(models, input_dir): "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, @@ -760,8 +764,8 @@ def load_model_hook(models, input_dir): "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, - "ss_fp8_base": args.fp8_base, - "ss_fp8_base_unet": args.fp8_base_unet, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), } self.update_metadata(metadata, args) # architecture specific metadata @@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) parser.add_argument( "--fp8_base_unet", action="store_true", From 65b8a064f6bb9a403374d4b08f4003037df42f8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 21:20:38 +0900 Subject: [PATCH 106/163] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b5799dd6f..caea59b7e 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -145,7 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. From fd68703f3795b3e9c75409ac5452807d056b928f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Wed, 11 Sep 2024 20:25:45 +0800 Subject: [PATCH 107/163] Add New lr scheduler (#1393) * add new lr scheduler * fix bugs and use num_cycles / 2 * Update requirements.txt * add num_cycles for min lr * keep PIECEWISE_CONSTANT * allow use float with warmup or decay ratio. * Update train_util.py --- library/train_util.py | 80 ++++++++++++++++++++++++++++++++++++++----- requirements.txt | 6 ++-- 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c7b73ee37..340f6d640 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,8 @@ from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, DDPMScheduler, @@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): + def int_or_float(value): + if value.endswith('%'): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1: + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + parser.add_argument( "--optimizer_type", type=str, @@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( "--lr_warmup_steps", - type=int, + type=int_or_float, + default=0, + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): Unified API to get any scheduler from its name. """ name = args.lr_scheduler - num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: @@ -4332,13 +4369,13 @@ def wrap_check_needless_num_warmup_steps(return_vals): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + name = SchedulerType(name) or DiffusersSchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == SchedulerType.PIECEWISE_CONSTANT: + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs # All other schedulers require `num_warmup_steps` @@ -4348,6 +4385,9 @@ def wrap_check_needless_num_warmup_steps(return_vals): if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") @@ -4366,7 +4406,31 @@ def wrap_check_needless_num_warmup_steps(return_vals): optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index 977c5cd91..d2a2fbb8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.30.0 +transformers==4.41.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.23.3 # for Image utils imagesize==1.4.1 # for BLIP captioning From 6dbfd47a59cdb91be2077e1d0dec0f94698348dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 21:44:36 +0900 Subject: [PATCH 108/163] Fix to work PIECEWISE_CONSTANT, update requirement.txt and README #1393 --- README.md | 9 ++++++ library/train_util.py | 66 ++++++++++++++++++++++++++++--------------- requirements.txt | 4 +-- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 16ab80e7a..011141bf1 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. + - transformers, accelerate and huggingface_hub are updated. + - If you encounter any issues, please report them. + +- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! + - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. + - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. + +https://github.com/kohya-ss/sd-scripts/pull/1393 - When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! - Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! diff --git a/library/train_util.py b/library/train_util.py index 340f6d640..e65760bae 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,10 @@ from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, @@ -2974,7 +2977,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): def int_or_float(value): - if value.endswith('%'): + if value.endswith("%"): try: return float(value[:-1]) / 100.0 except ValueError: @@ -3041,13 +3044,15 @@ def int_or_float(value): "--lr_warmup_steps", type=int_or_float, default=0, - help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_decay_steps", type=int_or_float, default=0, - help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3071,13 +3076,16 @@ def int_or_float(value): "--lr_scheduler_timescale", type=int, default=None, - help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + , ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, - help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) @@ -4327,8 +4335,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps - num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps - num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power @@ -4369,15 +4381,17 @@ def wrap_check_needless_num_warmup_steps(return_vals): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) or DiffusersSchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: - return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs - # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") @@ -4408,11 +4422,11 @@ def wrap_check_needless_num_warmup_steps(return_vals): if name == SchedulerType.COSINE_WITH_MIN_LR: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, num_cycles=num_cycles / 2, - min_lr_rate=min_lr_ratio, + min_lr_rate=min_lr_ratio, **lr_scheduler_kwargs, ) @@ -4421,16 +4435,22 @@ def wrap_check_needless_num_warmup_steps(return_vals): raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") if name == SchedulerType.WARMUP_STABLE_DECAY: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_stable_steps=num_stable_steps, - num_decay_steps=num_decay_steps, - num_cycles=num_cycles / 2, + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, **lr_scheduler_kwargs, ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index d2a2fbb8a..15e6e58f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ accelerate==0.30.0 -transformers==4.41.2 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.23.3 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning From 8311e88225fef377591e5be19eb1f50fe7a2941f Mon Sep 17 00:00:00 2001 From: cocktailpeanut Date: Wed, 11 Sep 2024 09:02:29 -0400 Subject: [PATCH 109/163] typo fix --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c38864fe6..f682dcbfb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3355,15 +3355,14 @@ def int_or_float(value): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From c7c666b1829a7c1f3435558efa425b08b50fab41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:12:31 +0900 Subject: [PATCH 110/163] fix typo --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e65760bae..a46d94877 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3077,15 +3077,14 @@ def int_or_float(value): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From a823fd9fb8d219b5b4c57df12eed41ae34fdf843 Mon Sep 17 00:00:00 2001 From: Plat <60182057+p1atdev@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:21:16 +0900 Subject: [PATCH 111/163] Improve wandb logging (#1576) * fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified * fix: checking of whether wandb is enabled * feat: log images to wandb with their positive prompt as captions * feat: logging sample images' caption for sd3 and flux * fix: import wandb before use --- fine_tune.py | 7 +++++-- flux_train.py | 7 +++++-- library/flux_train_utils.py | 20 +++++++++++--------- library/sd3_train_utils.py | 20 +++++++++++--------- library/train_util.py | 20 +++++++++++--------- sd3_train.py | 7 +++++-- sdxl_train.py | 7 +++++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- train_controlnet.py | 7 +++++-- train_db.py | 7 +++++-- train_network.py | 7 +++++-- train_textual_inversion.py | 8 ++++++-- train_textual_inversion_XTI.py | 4 ++-- 14 files changed, 80 insertions(+), 49 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c9102f6c0..fb6b3ed69 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -337,6 +337,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -456,7 +459,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -469,7 +472,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/flux_train.py b/flux_train.py index 0edc83a9f..33481df8f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -629,6 +629,9 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 @@ -777,7 +780,7 @@ def optimizer_hook(parameter: torch.Tensor): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) @@ -791,7 +794,7 @@ def optimizer_hook(parameter: torch.Tensor): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0b5d4d90e..f77d4b585 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -254,17 +254,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) def time_shift(mu: float, sigma: float, t: torch.Tensor): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index da0729506..e819d440c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -604,17 +604,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index f682dcbfb..742d057e0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5832,17 +5832,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # endregion diff --git a/sd3_train.py b/sd3_train.py index 87011b215..5120105f2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -682,6 +682,9 @@ def optimizer_hook(parameter: torch.Tensor): # For --sample_at_first sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # following function will be moved to sd3_train_utils @@ -901,7 +904,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) @@ -915,7 +918,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train.py b/sdxl_train.py index b2c62dd11..7291ddd2f 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -617,6 +617,9 @@ def optimizer_hook(parameter: torch.Tensor): sdxl_train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -797,7 +800,7 @@ def optimizer_hook(parameter: torch.Tensor): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} if block_lrs is None: train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) @@ -814,7 +817,7 @@ def optimizer_hook(parameter: torch.Tensor): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0eaec29b8..9d1cfc63e 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -541,14 +541,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 292a0463a..6fa1d6096 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -480,14 +480,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a8..57f0d263f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -409,6 +409,9 @@ def remove_model(old_ckpt_name): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -542,14 +545,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_db.py b/train_db.py index 7caee6647..d42afd89a 100644 --- a/train_db.py +++ b/train_db.py @@ -315,6 +315,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -445,7 +448,7 @@ def train(args): ) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -458,7 +461,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_network.py b/train_network.py index e45db0525..34385ae08 100644 --- a/train_network.py +++ b/train_network.py @@ -1038,6 +1038,9 @@ def remove_model(old_ckpt_name): # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1224,7 +1227,7 @@ def remove_model(old_ckpt_name): if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm ) @@ -1233,7 +1236,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9044f50df..956c78603 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -550,6 +550,9 @@ def remove_model(old_ckpt_name): unet, prompt_replacement, ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -684,7 +687,7 @@ def remove_model(old_ckpt_name): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -702,7 +705,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -739,6 +742,7 @@ def remove_model(old_ckpt_name): unet, prompt_replacement, ) + accelerator.log({}) # end of epoch diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index efb59137b..ca0b603fb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -538,7 +538,7 @@ def remove_model(old_ckpt_name): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -556,7 +556,7 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) From 237317fffd060bcfb078b770ccd2df18bc4dd3a6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:23:43 +0900 Subject: [PATCH 112/163] update README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 2b3d0d5a8..d3481b6ae 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 11, 2024: +Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! + Sep 10, 2024: In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. From cefe52629e1901dd8192b0487afd5e9f089e3519 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 12 Sep 2024 12:36:07 +0900 Subject: [PATCH 113/163] fix to work old notation for TE LR in .toml --- networks/lora_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index d540c2215..dd267de0f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -788,8 +788,11 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): # make sure text_encoder_lr as list of two elements - if text_encoder_lr is None or len(text_encoder_lr) == 0: + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float): + text_encoder_lr = [text_encoder_lr, text_encoder_lr] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From 1d7118a62268f12ebfd81c10db53bd85ef9d7631 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:01:36 +0900 Subject: [PATCH 114/163] Support : OFT merge to base model (#1580) * Support : OFT merge to base model * Fix typo * Fix typo_2 * Delete unused parameter 'eye' --- networks/sdxl_merge_lora.py | 192 +++++++++++++++++++++++++++--------- 1 file changed, 144 insertions(+), 48 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 3383a80de..2c998c8cb 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,10 +8,12 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora +import oft from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +import concurrent.futures def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) +def detect_method_from_training_model(models, dtype): + for model in models: + lora_sd, _ = load_state_dict(model, dtype) + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) + + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + logger.info(f"method:{method}") # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + if method == 'LoRA': + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = ( + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + elif method == 'OFT': + prefix = oft.OFTNetwork.OFT_PREFIX_UNET + target_replace_modules = ( + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - + if method == 'LoRA': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + elif method == 'OFT': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + name_to_module[oft_name] = child_module + + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if method == 'LoRA': + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + + elif method == 'OFT': + + multiplier=1.0 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + for key in tqdm(lora_sd.keys()): + if "oft_blocks" in key: + oft_blocks = lora_sd[key] + dim = oft_blocks.shape[0] + break + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + oft_blocks = lora_sd[key] + alpha = oft_blocks.item() + break + + def merge_to(key): + if "alpha" in key: + return + + # find original module for this OFT + module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue + return module = name_to_module[module_name] - # logger.info(f"apply {key} to {module}") - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + # logger.info(f"apply {key} to {module}") + + oft_blocks = lora_sd[key] + + if isinstance(module, torch.nn.Linear): + out_dim = module.out_features + elif isinstance(module, torch.nn.Conv2d): + out_dim = module.out_channels + + num_blocks = dim + block_size = out_dim // dim + constraint = (0 if alpha is None else alpha) * out_dim + + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * I + R = torch.block_diag(*block_R_weighted) + + # get org weight + org_sd = module.state_dict() + org_weight = org_sd["weight"].to(device) + + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - + weight = torch.einsum("oi, op -> pi", org_weight, R) + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) + with concurrent.futures.ThreadPoolExecutor() as executor: + list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) + def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model From 57ae44eb6138fe4a3864fffa62090f9d0113417d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:00 +0900 Subject: [PATCH 115/163] refactor to make safer --- networks/sdxl_merge_lora.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 2c998c8cb..d5a54e02a 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -44,11 +44,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) - for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' - elif "oft_blocks" in key: - return 'OFT' + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) @@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ ) elif method == 'OFT': prefix = oft.OFTNetwork.OFT_PREFIX_UNET + # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) @@ -83,17 +84,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if method == 'LoRA': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - elif method == 'OFT': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - oft_name = prefix + "." + name + "." + child_name - oft_name = oft_name.replace(".", "_") - name_to_module[oft_name] = child_module - + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -168,6 +163,7 @@ def merge_to(key): # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: + logger.info(f"no module found for OFT weight: {key}") return module = name_to_module[module_name] @@ -208,7 +204,9 @@ def merge_to(key): module.weight = torch.nn.Parameter(weight) - with concurrent.futures.ThreadPoolExecutor() as executor: + # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) From 3387dc7306087b84646666e49323980c89d14945 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:42 +0900 Subject: [PATCH 116/163] formatting, update README --- README.md | 6 +++ networks/sdxl_merge_lora.py | 86 +++++++++++++++++++++---------------- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index fd81a781f..d5d2a7f73 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Sep 13, 2024 / 2024-09-13: + +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. + +- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 + ### Jun 23, 2024 / 2024-06-23: - Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5a54e02a..d5b6f7f34 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -10,11 +10,14 @@ import lora import oft from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) import concurrent.futures + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) @@ -41,20 +44,22 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) + def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' + if "lora_up" in key or "lora_down" in key: + return "LoRA" elif "oft_blocks" in key: - return 'OFT' + return "OFT" + def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) - + # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") @@ -62,7 +67,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if method == 'LoRA': + if method == "LoRA": if i <= 1: if i == 0: prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 @@ -72,9 +77,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: prefix = lora.LoRANetwork.LORA_PREFIX_UNET target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) - elif method == 'OFT': + elif method == "OFT": prefix = oft.OFTNetwork.OFT_PREFIX_UNET # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( @@ -88,15 +93,14 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - - + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - if method == 'LoRA': + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -139,12 +143,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - - elif method == 'OFT': - - multiplier=1.0 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + elif method == "OFT": + + multiplier = 1.0 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for key in tqdm(lora_sd.keys()): if "oft_blocks" in key: oft_blocks = lora_sd[key] @@ -154,12 +157,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ if "alpha" in key: oft_blocks = lora_sd[key] alpha = oft_blocks.item() - break - + break + def merge_to(key): if "alpha" in key: return - + # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: @@ -168,18 +171,18 @@ def merge_to(key): module = name_to_module[module_name] # logger.info(f"apply {key} to {module}") - + oft_blocks = lora_sd[key] - + if isinstance(module, torch.nn.Linear): out_dim = module.out_features elif isinstance(module, torch.nn.Conv2d): out_dim = module.out_channels - + num_blocks = dim block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim - + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -188,24 +191,24 @@ def merge_to(key): block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) - + # get org weight org_sd = module.state_dict() org_weight = org_sd["weight"].to(device) R = R.to(org_weight.device, dtype=org_weight.dtype) - + if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: weight = torch.einsum("oi, op -> pi", org_weight, R) - - weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor - + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough - max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) @@ -258,7 +261,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): for key in tqdm(lora_sd.keys()): if "alpha" in key: continue - + if "lora_up" in key and concat: concat_dim = 1 elif "lora_down" in key and concat: @@ -272,8 +275,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -295,7 +298,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): dim = merged_sd[key_down].shape[0] perm = torch.randperm(dim) merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -323,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -410,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( @@ -431,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) return parser From 734d2e5b2b7a1551f3750a15e71060f3beed98e9 Mon Sep 17 00:00:00 2001 From: terracottahaniwa <57107346+terracottahaniwa@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:45:35 +0900 Subject: [PATCH 117/163] Support Lora Block Weight (LBW) to svd_merge_lora.py (#1575) * support lora block weight * solve license incompatibility * Fix issue: lbw index calculation --- networks/svd_merge_lora.py | 150 ++++++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 4 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index cb00a6000..6e163aecf 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,5 +1,8 @@ import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -14,6 +17,106 @@ CLAMP_QUANTILE = 0.99 +ACCEPTABLE = [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False +} + +LAYER17 = { + "BASE": True, + "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, +} + +LAYER26 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 + block_idx = -1 # invalid lora name + if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + up_down = g[0] + i = int(g[1]) + j = int(g[3]) + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name + + if g[0] == "down": + block_idx = 1 + idx # 1-based index, down block index + elif g[0] == "up": + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + + elif "mid_block_" in lora_name: + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block + else: + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10 + block_idx = 10 + elif name.startswith("output_blocks_"): # 0-8 to 11-19 + block_idx = 11 + int(name.split("_")[2]) + elif name.startswith("out_"): # 20, No LoRA in sd-scripts + block_idx = 20 + + return block_idx + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -42,12 +145,34 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None - for model, ratio in zip(models, ratios): + + if lbws: + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) @@ -57,6 +182,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + print(dict(zip(LAYER26.keys(), lbw_weights))) + # merge logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): @@ -93,6 +224,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -170,6 +307,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -187,7 +328,7 @@ def str_to_dtype(p): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) logger.info(f"calculating hashes and creating metadata...") @@ -237,6 +378,7 @@ def setup_parser() -> argparse.ArgumentParser: "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank", From f4a0bea6dce152e2210f611f94acfdfaa72068fe Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:26:06 +0900 Subject: [PATCH 118/163] format by black --- networks/svd_merge_lora.py | 188 +++++++++++++++++++++++++++++-------- 1 file changed, 147 insertions(+), 41 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 6e163aecf..0decd9048 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -11,8 +11,10 @@ import library.model_util as model_util import lora from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -22,38 +24,118 @@ LAYER12 = { "BASE": True, - "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": False, + "IN02": False, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": False, + "OUT07": False, + "OUT08": False, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER17 = { "BASE": True, - "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": True, + "IN02": True, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": False, + "OUT01": False, + "OUT02": False, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } LAYER20 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER26 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": True, + "IN10": True, + "IN11": True, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } assert len([v for v in LAYER12.values() if v]) == 12 @@ -145,6 +227,33 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) +def format_lbws(lbws): + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all( + len(lbw) in ACCEPTABLE for lbw in lbws + ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all( + all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws + ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + return lbws, is_sdxl, LBW_TARGET_IDX + + def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} @@ -152,25 +261,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer base_model = None if lbws: - try: - # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している - lbws = [json.loads(lbw) for lbw in lbws] - except Exception: - raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") - assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" - assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" - assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" - assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" - - layer_num = len(lbws[0]) - is_sdxl = True if layer_num in SDXL_LAYER_NUM else False - FLAGS = { - "12": LAYER12.values(), - "17": LAYER17.values(), - "20": LAYER20.values(), - "26": LAYER26.values(), - }[str(layer_num)] - LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws) + else: + is_sdxl = False + LBW_TARGET_IDX = [] for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") @@ -186,7 +280,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer lbw_weights = [1] * 26 for index, value in zip(LBW_TARGET_IDX, lbw): lbw_weights[index] = value - print(dict(zip(LAYER26.keys(), lbw_weights))) + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") # merge logger.info(f"merging...") @@ -306,9 +400,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" if args.lbws: - assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" else: args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく @@ -372,10 +470,16 @@ def setup_parser() -> argparse.ArgumentParser: help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") @@ -386,7 +490,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) parser.add_argument( "--no_metadata", action="store_true", From b755ebd0a4dd2967171b6b5909624325359a2aa0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:29:31 +0900 Subject: [PATCH 119/163] add LBW support for SDXL merge LoRA --- README.md | 14 +++++-- networks/sdxl_merge_lora.py | 75 ++++++++++++++++++++++++++++++++----- 2 files changed, 77 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index d5d2a7f73..0be2f9a70 100644 --- a/README.md +++ b/README.md @@ -139,9 +139,17 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Sep 13, 2024 / 2024-09-13: -- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. - -- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). +- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details. +- `sdxl_merge_lora.py` also supports LBW. +- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW. +- These will be included in the next release. + +- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。 +- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。 +- `sdxl_merge_lora.py` でも LBW がサポートされました。 +- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。 +- 以上は次回リリースに含まれます。 ### Jun 23, 2024 / 2024-06-23: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5b6f7f34..62f5a87d4 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -1,7 +1,9 @@ +import itertools import math import argparse import os import time +import concurrent.futures import torch from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -9,13 +11,13 @@ import library.model_util as model_util import lora import oft +from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26 from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) -import concurrent.futures def load_state_dict(file_name, dtype): @@ -47,6 +49,7 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: + # TODO It is better to use key names to detect the method lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): if "lora_up" in key or "lora_down" in key: @@ -55,15 +58,20 @@ def detect_method_from_training_model(models, dtype): return "OFT" -def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): - text_encoder1.to(merge_dtype) +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype): text_encoder1.to(merge_dtype) + text_encoder2.to(merge_dtype) unet.to(merge_dtype) # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): @@ -94,12 +102,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: @@ -121,6 +135,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + # W <- W + U * D weight = module.weight # logger.info(module_name, down_weight.size(), up_weight.size()) @@ -145,7 +165,6 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ elif method == "OFT": - multiplier = 1.0 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for key in tqdm(lora_sd.keys()): @@ -183,6 +202,13 @@ def merge_to(key): block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim + multiplier = 1 + if lbw: + index = get_lbw_block_index(key, False) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + multiplier *= lbw_weights[index] + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -213,17 +239,35 @@ def merge_to(key): list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): +def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + if method == "OFT": + raise ValueError( + "OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません" + ) + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + merged_sd = {} v2 = None base_model = None - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if lora_metadata is not None: if v2 is None: v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず @@ -277,6 +321,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): scale = math.sqrt(alpha / base_alpha) * ratio scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -329,6 +379,12 @@ def merge(args): assert len(args.models) == len( args.ratios ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -356,7 +412,7 @@ def str_to_dtype(p): ckpt_info, ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") - merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype) if args.no_metadata: sai_metadata = None @@ -372,7 +428,7 @@ def str_to_dtype(p): args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) logger.info(f"calculating hashes and creating metadata...") @@ -427,6 +483,7 @@ def setup_parser() -> argparse.ArgumentParser: help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument( "--no_metadata", action="store_true", From 93d9fbf60761fc1158e37f45f0d0c142913d70f5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 22:37:11 +0900 Subject: [PATCH 120/163] improve OFT implementation closes #944 --- README.md | 26 ++++++++- gen_img.py | 3 +- networks/check_lora_weights.py | 2 +- networks/oft.py | 96 +++++++++++++++++++++------------- 4 files changed, 89 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 0130ccffc..def528a22 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. -- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! +- Improvements in OFT (Orthogonal Finetuning) Implementation + 1. Optimization of Calculation Order: + - Changed the calculation order in the forward method from (Wx)R to W(xR). + - This has improved computational efficiency and processing speed. + 2. Correction of Bias Application: + - In the previous implementation, R was incorrectly applied to the bias. + - The new implementation now correctly handles bias by using F.conv2d and F.linear. + 3. Efficiency Enhancement in Matrix Operations: + - Introduced einsum in both the forward and merge_to methods. + - This has optimized matrix operations, resulting in further speed improvements. + 4. Proper Handling of Data Types: + - Improved to use torch.float32 during calculations and convert results back to the original data type. + - This maintains precision while ensuring compatibility with the original model. + 5. Unified Processing for Conv2d and Linear Layers: + - Implemented a consistent method for applying OFT to both layer types. + - These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability. + + - Additional Information + * Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation. + + * Performance Improvement: Training speed has been improved by approximately 30%. + + * Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL). + +- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. diff --git a/gen_img.py b/gen_img.py index d0a8f8141..59bcd5b09 100644 --- a/gen_img.py +++ b/gen_img.py @@ -86,7 +86,8 @@ """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c94..f8eab53ba 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/oft.py b/networks/oft.py index 461a98698..6321def3b 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -4,13 +4,17 @@ import os from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL +import einops from transformers import CLIPTextModel import numpy as np import torch +import torch.nn.functional as F import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -45,11 +49,16 @@ def __init__( if type(alpha) == torch.Tensor: alpha = alpha.detach().numpy() - self.constraint = alpha * out_dim + + # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility + # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) self.block_size = out_dim // self.num_blocks self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu self.out_dim = out_dim self.shape = org_module.weight.shape @@ -69,27 +78,36 @@ def get_weight(self, multiplier=None): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I - R = torch.block_diag(*block_R_weighted) - - return R + if self.I.device != block_Q.device: + self.I = self.I.to(block_Q.device) + I = self.I + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + return block_R_weighted def forward(self, x, scale=None): - x = self.org_forward(x) if self.multiplier == 0.0: - return x - - R = self.get_weight().to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + return self.org_forward(x) + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight().to(torch.float32) + W = org_module.weight.to(torch.float32) + + if len(W.shape) == 4: # Conv2d + W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped) + RW = einops.rearrange(RW, "k m ... -> (k m) ...") + result = F.conv2d( + x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups + ) + else: # Linear + W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped) + RW = einops.rearrange(RW, "k m p -> (k m) p") + result = F.linear(x, RW.to(org_dtype), org_module.bias) + return result class OFTInfModule(OFTModule): @@ -115,18 +133,19 @@ def forward(self, x, scale=None): return self.org_forward(x) return super().forward(x, scale) - def merge_to(self, multiplier=None, sign=1): - R = self.get_weight(multiplier) * sign - + def merge_to(self, multiplier=None): # get org weight org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - R = R.to(org_weight.device, dtype=org_weight.dtype) + org_weight = org_sd["weight"].to(torch.float32) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + R = self.get_weight(multiplier).to(torch.float32) + + weight = org_weight.reshape(self.num_blocks, self.block_size, -1) + weight = torch.einsum("k n m, k n ... -> k m ...", R, weight) + weight = weight.reshape(org_weight.shape) + + # convert back to original dtype + weight = weight.to(org_sd["weight"].dtype) # set weight to org_module org_sd["weight"] = weight @@ -145,8 +164,16 @@ def create_network( ): if network_dim is None: network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) enable_all_linear = kwargs.get("enable_all_linear", None) enable_conv = kwargs.get("enable_conv", None) @@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: if dim is None: dim = param.size()[0] - if has_conv2d is None and param.dim() == 4: + if has_conv2d is None and "in_layers_2" in name: has_conv2d = True - if all_linear is None: - if param.dim() == 3 and "attn" not in name: - all_linear = True - if dim is not None and alpha is not None and has_conv2d is not None: + if all_linear is None and "_ff_" in name: + all_linear = True + if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None: break if has_conv2d is None: has_conv2d = False @@ -241,7 +267,7 @@ def __init__( self.alpha = alpha logger.info( - f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}" ) # create module instances From 2d8ee3c28007393386528cfeec0a9b714dafd85b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 15:48:16 +0900 Subject: [PATCH 121/163] OFT for FLUX.1 --- flux_minimal_inference.py | 20 +- networks/lora_flux.py | 6 +- networks/oft.py | 2 +- networks/oft_flux.py | 482 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 504 insertions(+), 6 deletions(-) create mode 100644 networks/oft_flux.py diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index de607c52a..2f1b9a377 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -14,9 +14,11 @@ from PIL import Image import accelerate from transformers import CLIPTextModel +from safetensors.torch import load_file from library import device_utils from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux init_ipex() @@ -405,7 +407,7 @@ def encode(prpt: str): type=str, nargs="*", default=[], - help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) @@ -482,9 +484,19 @@ def is_fp8(dt): else: multiplier = 1.0 - lora_model, weights_sd = lora_flux.create_network_from_weights( - multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True - ) + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + if args.merge_lora_weights: lora_model.merge_to([clip_l, t5xxl], model, weights_sd) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index dd267de0f..ea7df8b4d 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -41,7 +41,11 @@ def __init__( module_dropout=None, split_dims: Optional[List[int]] = None, ): - """if alpha == 0 or None, alpha is rank (no scaling).""" + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ super().__init__() self.lora_name = lora_name diff --git a/networks/oft.py b/networks/oft.py index 6321def3b..0c3a5393f 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -51,7 +51,7 @@ def __init__( alpha = alpha.detach().numpy() # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility - # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha self.constraint = alpha * out_dim self.register_buffer("alpha", torch.tensor(alpha)) diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 000000000..27b8b637a --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False From c9ff4de90597e933b441502d45c175fe46b99714 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:17:52 +0900 Subject: [PATCH 122/163] Add support for specifying rank for each layer in FLUX.1 --- README.md | 61 ++++++++++++++++++++++++ networks/lora_flux.py | 107 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 161 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 6e32fa31d..9a9794796 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 14, 2024: +- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. +- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. + Sep 11, 2024: Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! @@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 OFT training](#flux1-oft-training) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/ The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. +#### Specify rank for each layer in FLUX.1 + +You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|img_attn_dim|img_attn in DoubleStreamBlock| +|txt_attn_dim|txt_attn in DoubleStreamBlock| +|img_mlp_dim|img_mlp in DoubleStreamBlock| +|txt_mlp_dim|txt_mlp in DoubleStreamBlock| +|img_mod_dim|img_mod in DoubleStreamBlock| +|txt_mod_dim|txt_mod in DoubleStreamBlock| +|single_dim|linear1 and linear2 in SingleStreamBlock| +|single_mod_dim|modulation in SingleStreamBlock| + +example: +``` +--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +``` + +You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "in_dims=[4,2,2,2,4]" +``` + +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. + +### FLUX.1 OFT training + +You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. + +- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`. +- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc. +- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it. +- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`. +- `--network_args` specifies the hyperparameters of OFT. The following are valid: + - Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention. + +Currently, there is no environment to infer FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`). + +Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1. + +``` +--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3 +--network_args "enable_all_linear=True" --learning_rate 1e-5 +``` + +The training can be done with 16GB VRAM GPUs without `--enable_all_linear` option and with Adafactor optimizer. + ### Inference for FLUX.1 with LoRA model The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ea7df8b4d..a34cde1a8 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,6 +316,44 @@ def create_network( else: conv_alpha = float(conv_alpha) + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -339,6 +377,11 @@ def create_network( if train_t5xxl is not None: train_t5xxl = True if train_t5xxl == "True" else False + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -354,7 +397,9 @@ def create_network( train_blocks=train_blocks, split_qkv=split_qkv, train_t5xxl=train_t5xxl, - varbose=True, + type_dims=type_dims, + in_dims=in_dims, + verbose=verbose, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) @@ -462,7 +507,9 @@ def __init__( train_blocks: Optional[str] = None, split_qkv: bool = False, train_t5xxl: bool = False, - varbose: Optional[bool] = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + verbose: Optional[bool] = False, ) -> None: super().__init__() self.multiplier = multiplier @@ -478,12 +525,17 @@ def __init__( self.split_qkv = split_qkv self.train_t5xxl = train_t5xxl + self.type_dims = type_dims + self.in_dims = in_dims + self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info( @@ -502,7 +554,12 @@ def __init__( # create module instances def create_modules( - is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_FLUX @@ -513,16 +570,22 @@ def create_modules( loras = [] skipped = [] for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name + lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") + if filter is not None and not filter in lora_name: + continue + dim = None alpha = None @@ -534,8 +597,25 @@ def create_modules( else: # 通常、すべて対象とする if is_linear or is_conv2d_1x1: - dim = self.lora_dim + dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha + + if type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d + break + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha @@ -566,6 +646,9 @@ def create_modules( split_dims=split_dims, ) loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched return loras, skipped # create LoRA for text encoder @@ -594,10 +677,20 @@ def create_modules( self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: + if verbose and len(skipped) > 0: logger.warning( f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) From 6445bb2bc974cec51256ae38c1be0900e90e6f87 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:37:26 +0900 Subject: [PATCH 123/163] update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9a9794796..c94ea3598 100644 --- a/README.md +++ b/README.md @@ -213,10 +213,12 @@ When network_args is not specified, the default value (`network_dim`) is applied |single_dim|linear1 and linear2 in SingleStreamBlock| |single_mod_dim|modulation in SingleStreamBlock| +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + example: ``` --network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" -"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True" ``` You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. From 9f44ef133083c530874c6cf022a4de8fda3edae2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:52:23 +0900 Subject: [PATCH 124/163] add diffusers to FLUX.1 conversion script --- README.md | 19 ++- tools/convert_diffusers_to_flux.py | 223 +++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tools/convert_diffusers_to_flux.py diff --git a/README.md b/README.md index c94ea3598..7d6c336e6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 15, 2024: + +Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. + +The implementation is based on 2kpr's code. Thanks to 2kpr! + Sep 14, 2024: - You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. - OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. @@ -57,6 +63,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Convert FLUX LoRA](#convert-flux-lora) - [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) - [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) +- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1) ### FLUX.1 LoRA training @@ -355,7 +362,7 @@ If you use LoRA in the inference environment, converting it to AI-toolkit format Note that re-conversion will increase the size of LoRA. -CLIP-L LoRA is not supported. +CLIP-L/T5XXL LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint @@ -435,6 +442,16 @@ resolution = [512, 512] num_repeats = 1 ``` +### Convert Diffusers to FLUX.1 + +Script: `convert_diffusers_to_flux1.py` + +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. + +``` +python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py new file mode 100644 index 000000000..9d8f7c74b --- /dev/null +++ b/tools/convert_diffusers_to_flux.py @@ -0,0 +1,223 @@ +# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model. +# It is based on the implementation by 2kpr. Thanks to 2kpr! +# Major changes: +# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once. +# - Makes reverse map from diffusers map to avoid loading all tensors. +# - Removes dependency on .json file for weights mapping. +# - Adds support for custom memory efficient load and save functions. +# - Supports saving with different precision. +# - Supports .safetensors file as input. + +# Copyright 2024 2kpr. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import os +from pathlib import Path +import safetensors +from safetensors.torch import safe_open +import torch +from tqdm import tqdm + +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def convert(args): + # if diffusers_path is folder, get safetensors file + diffusers_path = Path(args.diffusers_path) + if diffusers_path.is_dir(): + diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + + flux_path = Path(args.save_to) + if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + + if not diffusers_path.exists(): + logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}") + return + + mem_eff_flag = args.mem_eff_load_save + save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None + + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for i in range(3): + # replace 00001 with 0000i + current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}")) + logger.info(f"Loading diffusers file: {current_diffusers_path}") + + open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt")) + with open_func(current_diffusers_path) as f: + for diffusers_key in tqdm(f.keys()): + if diffusers_key in diffusers_to_bfl_map: + tensor = f.get_tensor(diffusers_key).to("cpu") + if save_dtype is not None: + tensor = tensor.to(save_dtype) + + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + return + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + # save flux_sd to safetensors file + logger.info(f"Saving Flux safetensors file: {flux_path}") + if mem_eff_flag: + mem_eff_save_file(flux_sd, flux_path) + else: + safetensors.torch.save_file(flux_sd, flux_path) + + logger.info("Conversion completed.") + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--diffusers_path", + default=None, + type=str, + required=True, + help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file." + " / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス", + ) + parser.add_argument( + "--save_to", + default=None, + type=str, + required=True, + help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, default is same as loading precision" + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + convert(args) From be078bdaca41084a20edb952b98a82f3e05d2dad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:59:17 +0900 Subject: [PATCH 125/163] fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d6c336e6..f79fe21af 100644 --- a/README.md +++ b/README.md @@ -446,7 +446,7 @@ resolution = [512, 512] Script: `convert_diffusers_to_flux1.py` -Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `rmer` folder. ``` python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 From 96c677b4594ed6f28f3ef896f6deca7c3aced25d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 10:42:09 +0900 Subject: [PATCH 126/163] fix to work lienar/cosine lr scheduler closes #1602 ref #1393 --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 742d057e0..60afd4219 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4707,6 +4707,15 @@ def wrap_check_needless_num_warmup_steps(return_vals): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") @@ -5837,14 +5846,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # endregion From d8d15f1a7e09ca217930288b41bd239881126b93 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 23:14:09 +0900 Subject: [PATCH 127/163] add support for specifying blocks in FLUX.1 LoRA training --- README.md | 24 ++++++++++++- networks/lora_flux.py | 82 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f79fe21af..24217d8b7 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 16, 2024: + + Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. + Sep 15, 2024: Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. @@ -54,9 +58,12 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Distribution of timesteps](#distribution-of-timesteps) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) + - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) + - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) - [FLUX.1 OFT training](#flux1-oft-training) +- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -239,6 +246,21 @@ Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. +#### Specify blocks to train in FLUX.1 LoRA training + +You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" +``` + +``` +--network_args "train_double_block_indices=none" "train_single_block_indices=10-15" +``` + +If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a34cde1a8..f549ac18f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -24,6 +24,10 @@ logger = logging.getLogger(__name__) +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -354,6 +358,50 @@ def create_network( in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -399,6 +447,8 @@ def create_network( train_t5xxl=train_t5xxl, type_dims=type_dims, in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, verbose=verbose, ) @@ -509,6 +559,8 @@ def __init__( train_t5xxl: bool = False, type_dims: Optional[List[int]] = None, in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -527,6 +579,8 @@ def __init__( self.type_dims = type_dims self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -600,7 +654,7 @@ def create_modules( dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha - if type_dims is not None: + if is_flux and type_dims is not None: identifier = [ ("img_attn",), ("txt_attn",), @@ -613,9 +667,33 @@ def create_modules( ] for i, d in enumerate(type_dims): if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d + dim = d # may be 0 for skip break + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha From 0cbe95bcc7e88f518802f29fe2b99da806963267 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:21:28 +0900 Subject: [PATCH 128/163] fix text_encoder_lr to work with int closes #1608 --- networks/lora_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index f549ac18f..91e9cd77f 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -966,8 +966,8 @@ def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr # if float, use the same value for both text encoders if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] - elif isinstance(text_encoder_lr, float): - text_encoder_lr = [text_encoder_lr, text_encoder_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From a2ad7e5644f08141fe053a2b63446d70d777bdcf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:42:14 +0900 Subject: [PATCH 129/163] blocks_to_swap=0 means no swap --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 33481df8f..5d8326b1d 100644 --- a/flux_train.py +++ b/flux_train.py @@ -265,7 +265,7 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! From bbd160b4ca9293881c222f9b9e1d832af69699db Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 18 Sep 2024 07:55:04 +0900 Subject: [PATCH 130/163] sd3 schedule free opt (#1605) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * New ScheduleFree support for Flux (#1600) * init * use no schedule * fix typo * update for eval() * fix typo * update * Update train_util.py * Update requirements.txt * update sfwrapper WIP * no need to check schedulefree optimizer * remove debug print * comment out schedulefree wrapper * update readme --------- Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com> --- README.md | 8 +++ library/train_util.py | 152 ++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 1 + 3 files changed, 154 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 24217d8b7..dc9862927 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,14 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024: + +- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - `schedulefree` is added to the dependencies. Please update the library if necessary. + - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. + - Wrapper classes are not available for now. + - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. + Sep 16, 2024: Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. diff --git a/library/train_util.py b/library/train_util.py index 60afd4219..a54f23ff6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3303,6 +3303,20 @@ def int_or_float(value): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) + + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", @@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + if optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + if optimizer is None: # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") + + if "." not in case_sensitive_optimizer_type: # from torch.optim optimizer_module = torch.optim - else: - values = optimizer_type.split(".") + else: # from other library + values = case_sensitive_optimizer_type.split(".") optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] + case_sensitive_optimizer_type = values[-1] - optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ + # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree + if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + schedulefree_wrapper_kwargs = {} + if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: + for arg in args.schedulefree_wrapper_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + schedulefree_wrapper_kwargs[key] = value + + sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs) + sf_wrapper.train() # make optimizer as train mode + + # we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper + class OptimizerProxy(torch.optim.Optimizer): + def __init__(self, sf_wrapper): + self._sf_wrapper = sf_wrapper + + def __getattr__(self, name): + return getattr(self._sf_wrapper, name) + + # override properties + @property + def state(self): + return self._sf_wrapper.state + + @state.setter + def state(self, state): + self._sf_wrapper.state = state + + @property + def param_groups(self): + return self._sf_wrapper.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._sf_wrapper.param_groups = param_groups + + @property + def defaults(self): + return self._sf_wrapper.defaults + + @defaults.setter + def defaults(self, defaults): + self._sf_wrapper.defaults = defaults + + def add_param_group(self, param_group): + self._sf_wrapper.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self._sf_wrapper.load_state_dict(state_dict) + + def state_dict(self): + return self._sf_wrapper.state_dict() + + def zero_grad(self): + self._sf_wrapper.zero_grad() + + def step(self, closure=None): + self._sf_wrapper.step(closure) + + def train(self): + self._sf_wrapper.train() + + def eval(self): + self._sf_wrapper.eval() + + # isinstance チェックをパスするためのメソッド + def __instancecheck__(self, instance): + return isinstance(instance, (type(self), Optimizer)) + + optimizer = OptimizerProxy(sf_wrapper) + + logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) return optimizer_name, optimizer_args, optimizer +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler # Add some checking and features to the original function. @@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_warmup_steps: Optional[int] = ( diff --git a/requirements.txt b/requirements.txt index 9a4fa0c15..bab53f20f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 +schedulefree==1.2.7 tensorboard safetensors==0.4.4 # gradio==3.16.2 From e74502117bcf161ef5698fb0adba4f9fa0171b8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 08:04:32 +0900 Subject: [PATCH 131/163] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index dc9862927..034a260ff 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ The command to install PyTorch is as follows: Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - `schedulefree` is added to the dependencies. Please update the library if necessary. - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - Wrapper classes are not available for now. From 1286e00bb0fc34c296f24b7057777f1c37cf8e11 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 21:31:54 +0900 Subject: [PATCH 132/163] fix to call train/eval in schedulefree #1605 --- README.md | 3 +++ flux_train.py | 10 ++++++++++ library/train_util.py | 15 ++++++++++++++- train_network.py | 6 ++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 034a260ff..843ae181b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024 (update 1): +Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. + Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. diff --git a/flux_train.py b/flux_train.py index 5d8326b1d..bc4e62793 100644 --- a/flux_train.py +++ b/flux_train.py @@ -347,8 +347,13 @@ def train(args): logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -760,6 +765,7 @@ def optimizer_hook(parameter: torch.Tensor): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() flux_train_utils.sample_images( accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) @@ -778,6 +784,7 @@ def optimizer_hook(parameter: torch.Tensor): global_step, accelerator.unwrap_model(flux), ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -800,6 +807,7 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( @@ -816,12 +824,14 @@ def optimizer_hook(parameter: torch.Tensor): flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) + optimizer_train_fn() is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) diff --git a/library/train_util.py b/library/train_util.py index a54f23ff6..fe9deb940 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,6 +13,7 @@ import time from typing import ( Any, + Callable, Dict, List, NamedTuple, @@ -4715,8 +4716,20 @@ def __instancecheck__(self, instance): return optimizer_name, optimizer_args, optimizer +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: - return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper def get_dummy_scheduler(optimizer: Optimizer) -> Any: diff --git a/train_network.py b/train_network.py index 34385ae08..55faa143e 100644 --- a/train_network.py +++ b/train_network.py @@ -498,6 +498,7 @@ def train(self, args): # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -1199,6 +1200,7 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() self.sample_images( accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet ) @@ -1217,6 +1219,7 @@ def remove_model(old_ckpt_name): if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + optimizer_train_fn() current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) @@ -1243,6 +1246,7 @@ def remove_model(old_ckpt_name): accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 + optimizer_eval_fn() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: @@ -1258,6 +1262,7 @@ def remove_model(old_ckpt_name): train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() # end of epoch @@ -1268,6 +1273,7 @@ def remove_model(old_ckpt_name): network = accelerator.unwrap_model(network) accelerator.end_training() + optimizer_eval_fn() if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) From e7040669bc9a31706fe9fedec14978b05223f968 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:47:06 +0900 Subject: [PATCH 133/163] Bug fix: alpha_mask load --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a46d94877..5a8da90e1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2207,7 +2207,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph if alpha_mask: if "alpha_mask" not in npz: return False - if npz["alpha_mask"].shape[0:2] != reso: # HxW + if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso return False else: if "alpha_mask" in npz: From 9c757c2fba43d4e91d773cf6e9b7e2e8e3e8b376 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 19 Sep 2024 21:14:57 +0900 Subject: [PATCH 134/163] fix SDXL block index to match LBW --- networks/svd_merge_lora.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 0decd9048..b4b9e3bfd 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -184,18 +184,19 @@ def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: elif "mid_block_" in lora_name: block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block else: + # SDXL: some numbers are skipped if lora_name.startswith("lora_unet_"): name = lora_name[len("lora_unet_") :] if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts block_idx = 1 elif name.startswith("input_blocks_"): # 1-8 to 2-9 block_idx = 1 + int(name.split("_")[2]) - elif name.startswith("middle_block_"): # 10 - block_idx = 10 - elif name.startswith("output_blocks_"): # 0-8 to 11-19 - block_idx = 11 + int(name.split("_")[2]) - elif name.startswith("out_"): # 20, No LoRA in sd-scripts - block_idx = 20 + elif name.startswith("middle_block_"): # 13 + block_idx = 13 + elif name.startswith("output_blocks_"): # 0-8 to 14-22 + block_idx = 14 + int(name.split("_")[2]) + elif name.startswith("out_"): # 23, No LoRA in sd-scripts + block_idx = 23 return block_idx From 3957372ded6fda20553acaf169993a422b829bdc Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:30:03 -0700 Subject: [PATCH 135/163] Retain alpha in `pil_resize` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket. This codepath is entered on the last line, here: ``` def trim_and_resize_if_required( random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) ``` --- library/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index a0bb19650..2171c7190 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,13 +305,26 @@ def _convert_float8(byte_tensor, dtype_str, shape): raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + # Check if the image has an alpha channel + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False - # use Pillow resize + if has_alpha: + # Convert BGRA to RGBA + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + # Convert BGR to RGB + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Resize the image resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + # Convert RGBA to BGRA + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + # Convert RGB to BGR + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From de4bb657b089cc28f4127e891b927895892e20b5 Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:38:32 -0700 Subject: [PATCH 136/163] Update utils.py Cleanup --- library/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 2171c7190..8a0c782c0 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,25 +305,19 @@ def _convert_float8(byte_tensor, dtype_str, shape): raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - # Check if the image has an alpha channel has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: - # Convert BGRA to RGBA pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) else: - # Convert BGR to RGB pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Resize the image resized_pil = pil_image.resize(size, interpolation) # Convert back to cv2 format if has_alpha: - # Convert RGBA to BGRA resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) else: - # Convert RGB to BGR resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From 0535cd29b926530255d5400374813432ec52c3df Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Fri, 20 Sep 2024 10:05:22 +0800 Subject: [PATCH 137/163] fix: backward compatibility for text_encoder_lr --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 55faa143e..dfa51a9c8 100644 --- a/train_network.py +++ b/train_network.py @@ -471,7 +471,11 @@ def train(self, args): if support_multiple_lrs: text_encoder_lr = args.text_encoder_lr else: - text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: if support_multiple_lrs: results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) From 583d4a436c1cef57fce405d0167fb7ce575fc768 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 20 Sep 2024 22:22:24 +0900 Subject: [PATCH 138/163] add compatibility for int LR (D-Adaptation etc.) #1620 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index dfa51a9c8..b24f89b1e 100644 --- a/train_network.py +++ b/train_network.py @@ -472,7 +472,7 @@ def train(self, args): text_encoder_lr = args.text_encoder_lr else: # toml backward compatibility - if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): text_encoder_lr = args.text_encoder_lr else: text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] From e1f23af1bc733a1a89c35cf1be1301006c744b4a Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 21 Sep 2024 12:58:32 +0100 Subject: [PATCH 139/163] make timestep sampling behave in the standard way when huber loss is used --- library/train_util.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e1..72d2d8112 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. In the future there may be a smarter way + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timestep = timesteps.item() - if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = math.exp(-alpha * timestep) + huber_c = torch.exp(-alpha * timesteps) elif args.huber_schedule == "snr": - alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c elif args.huber_schedule == "constant": - huber_c = args.huber_c + huber_c = torch.full((b_size,), args.huber_c) else: raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - timesteps = timesteps.repeat(b_size).to(device) + huber_c = huber_c.to(device) elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - huber_c = 1 # may be anything, as it's not used + huber_c = None # may be anything, as it's not used else: raise NotImplementedError(f"Unknown loss type {args.loss_type}") - timesteps = timesteps.long() + timesteps = timesteps.long().to(device) return timesteps, huber_c @@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps, huber_c -# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 + model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] ): if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) From 29177d2f0389bd13e3f12c95d463fb0e1c58f9a1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 23 Sep 2024 21:14:03 +0900 Subject: [PATCH 140/163] retain alpha in pil_resize backport #1619 --- library/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index 5b7e657b2..49d46a546 100644 --- a/library/utils.py +++ b/library/utils.py @@ -83,13 +83,20 @@ def setup_logging(args=None, log_level=None, reset=False): def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False + + if has_alpha: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # use Pillow resize resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From ab7b23187062db86d34fc82db95f7266a68ab5c4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 19:38:52 +0800 Subject: [PATCH 141/163] init --- library/train_util.py | 21 ++++++++++++++++++--- requirements.txt | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e1..bdf7774e4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2994,7 +2994,7 @@ def int_or_float(value): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -4032,7 +4032,7 @@ def task(): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -4141,7 +4141,22 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - + elif optimizer_type == "Ademamix8bit".lower(): + logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.AdEMAMix8bit + except AttributeError: + raise AttributeError( + "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdemamix8bit".lower(): + logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdEMAMix8bit + except AttributeError: + raise AttributeError( + "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): diff --git a/requirements.txt b/requirements.txt index 15e6e58f1..e6e1bf6fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.1.1 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard From e74f58148c5994889463afa42bb6fc5d6447a75e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 25 Sep 2024 20:55:50 +0900 Subject: [PATCH 142/163] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index def528a22..9eabdaeef 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! + - Improvements in OFT (Orthogonal Finetuning) Implementation 1. Optimization of Calculation Order: - Changed the calculation order in the forward method from (Wx)R to W(xR). From 1beddd84e5c4db729a84356db227d981dc18cf8d Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 22:58:26 +0800 Subject: [PATCH 143/163] delete code for cleaning --- library/train_util.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bdf7774e4..c4845c54b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4141,22 +4141,7 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - elif optimizer_type == "Ademamix8bit".lower(): - logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.AdEMAMix8bit - except AttributeError: - raise AttributeError( - "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - elif optimizer_type == "PagedAdemamix8bit".lower(): - logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdEMAMix8bit - except AttributeError: - raise AttributeError( - "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): From 56a7bc171d48089fb50f8638537e42d07c579db3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:26:31 +0900 Subject: [PATCH 144/163] new block swap for FLUX.1 fine tuning --- README.md | 47 ++++++-- flux_train.py | 251 ++++++++++++++++++++++++++--------------- library/flux_models.py | 168 +++++++++++++++------------ 3 files changed, 297 insertions(+), 169 deletions(-) diff --git a/README.md b/README.md index ef691e918..7d623f900 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 26, 2024: +The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + Sep 18, 2024 (update 1): Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. @@ -307,6 +311,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_ The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! +__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__ + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -319,39 +325,62 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ---fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 +--fused_backward_pass --blocks_to_swap 8 --full_bf16 ``` (The command is multi-line for readability. Please combine it into one line.) -Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). `--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. -`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. All these options are experimental and may change in the future. The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. -Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. +Swap 8 blocks without cpu offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed. The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### How to use block swap + +There are two possible ways to use block swap. It is unknown which is better. + +1. Swap the minimum number of blocks that fit in VRAM with batch size 1 and shorten the training speed of one step. + + The above command example is for this usage. + +2. Swap many blocks to increase the batch size and shorten the training speed per data. + + For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + +#### Training with <24GB VRAM GPUs + +Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. + +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. + #### Key Features for FLUX.1 fine-tuning -1. Technical details of double/single block swap: +1. Technical details of block swap: - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. - Since the transfer between CPU and GPU takes time, the training will be slower. - - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. - - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + - `--blocks_to_swap` specify the number of blocks to swap. + - About 640MB of memory can be saved per block. + - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. + - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. + - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. + - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. + - After the backward pass, the blocks are back to their original locations. 2. Sample Image Generation: - Sample image generation during training is now supported. diff --git a/flux_train.py b/flux_train.py index bc4e62793..bf34208f1 100644 --- a/flux_train.py +++ b/flux_train.py @@ -11,10 +11,12 @@ # - Per-block fused optimizer instances import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os from multiprocessing import Value +import time from typing import List import toml @@ -265,14 +267,30 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! - logger.info( - f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" - ) - flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap) if not cache_latents: # load VAE here if not cached @@ -443,82 +461,120 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def get_block_unit(dbl_blocks, sgl_blocks, index: int): + if index < len(dbl_blocks): + return (dbl_blocks[index],) + else: + index -= len(dbl_blocks) + index *= 2 + return (sgl_blocks[index], sgl_blocks[index + 1]) + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + for block in blocks_to_cpu: + block = block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") + for block in blocks_to_cuda: + block = block.to(dvc, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") + return bidx_to_cpu, bidx_to_cuda + + blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) + blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + # print(f"Backward: Wait for block {block_idx}") + # start_time = time.perf_counter() + future = futures.pop(block_idx) + future.result() + # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + # torch.cuda.synchronize() + # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) - handled_double_block_indices = set() - handled_single_block_indices = set() + num_block_units = num_double_blocks + num_single_blocks // 2 + handled_unit_indices = set() + + n = 1 # only asyncronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: grad_hook = None - if double_blocks_to_swap: - if param_name.startswith("double_blocks"): - block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_double_block_indices - and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 - and block_idx < num_double_blocks - 1 - ): - # swap next (already backpropagated) block - handled_double_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) - - # create swap hook - def create_double_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # swap blocks if necessary - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") - - return __grad_hook - - grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) - if single_blocks_to_swap: - if param_name.startswith("single_blocks"): + if blocks_to_swap: + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if is_double or is_single: block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_single_block_indices - and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 - and block_idx < num_single_blocks - 1 - ): - handled_single_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) - # print(param_name, block_idx_cpu, block_idx_cuda) - - # create swap hook - def create_single_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # swap blocks if necessary - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") - - return __grad_hook - - grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 + if unit_idx not in handled_unit_indices: + # swap following (already backpropagated) block + handled_unit_indices.add(unit_idx) + + # if n blocks were already backpropagated + num_blocks_propagated = num_block_units - unit_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = unit_idx - 1 + + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # print(f"Backward: {uidx}, {swpng}, {wtng}") + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) + + return __grad_hook + + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + ) if grad_hook is None: @@ -547,10 +603,15 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) + num_block_units = num_double_blocks + num_single_blocks // 2 + + n = 1 # only asyncronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -571,18 +632,30 @@ def optimizer_hook(parameter: torch.Tensor): optimizers[i].zero_grad(set_to_none=True) # swap blocks if necessary - if btype == "double" and double_blocks_to_swap: - if bidx >= num_double_blocks - double_blocks_to_swap: - bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") - elif btype == "single" and single_blocks_to_swap: - if bidx >= num_single_blocks - single_blocks_to_swap: - bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): + unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 + num_blocks_propagated = num_block_units - unit_idx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = unit_idx - 1 + wait_blocks_move(block_idx_to_wait, futures) return optimizer_hook @@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser: help="skip latents validity check / latentsの正当性チェックをスキップする", ) parser.add_argument( - "--double_blocks_to_swap", + "--blocks_to_swap", type=int, default=None, help="[EXPERIMENTAL] " - "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) parser.add_argument( "--single_blocks_to_swap", type=int, default=None, - help="[EXPERIMENTAL] " - "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", ) parser.add_argument( "--cpu_offload_checkpointing", diff --git a/library/flux_models.py b/library/flux_models.py index b5726c298..a35dbc106 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,9 +2,12 @@ # license: Apache-2.0 License +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass import math -from typing import Optional +import os +import time +from typing import Dict, List, Optional from library.device_utils import init_ipex, clean_memory_on_device @@ -917,8 +920,10 @@ def __init__(self, params: FluxParams): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - self.double_blocks_to_swap = None - self.single_blocks_to_swap = None + self.blocks_to_swap = None + + self.thread_pool: Optional[ThreadPoolExecutor] = None + self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 @property def device(self): @@ -956,38 +961,52 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): - self.double_blocks_to_swap = double_blocks - self.single_blocks_to_swap = single_blocks + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 1 is enough + # n = 2 + # n = max(1, os.cpu_count() // 2) + self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu - if self.double_blocks_to_swap: + if self.blocks_to_swap: save_double_blocks = self.double_blocks - self.double_blocks = None - if self.single_blocks_to_swap: save_single_blocks = self.single_blocks + self.double_blocks = None self.single_blocks = None self.to(device) - if self.double_blocks_to_swap: + if self.blocks_to_swap: self.double_blocks = save_double_blocks - if self.single_blocks_to_swap: self.single_blocks = save_single_blocks + def get_block_unit(self, index: int): + if index < len(self.double_blocks): + return (self.double_blocks[index],) + else: + index -= len(self.double_blocks) + index *= 2 + return self.single_blocks[index], self.single_blocks[index + 1] + + def get_unit_index(self, is_double: bool, index: int): + if is_double: + return index + else: + return len(self.double_blocks) + index // 2 + def prepare_block_swap_before_forward(self): - # move last n blocks to cpu: they are on cuda - if self.double_blocks_to_swap: - for i in range(len(self.double_blocks) - self.double_blocks_to_swap): - self.double_blocks[i].to(self.device) - for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): - self.double_blocks[i].to("cpu") # , non_blocking=True) - if self.single_blocks_to_swap: - for i in range(len(self.single_blocks) - self.single_blocks_to_swap): - self.single_blocks[i].to(self.device) - for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): - self.single_blocks[i].to("cpu") # , non_blocking=True) + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None: + raise ValueError("Block swap is not enabled.") + for i in range(self.num_block_units - self.blocks_to_swap): + for b in self.get_block_unit(i): + b.to(self.device) + for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + for b in self.get_block_unit(i): + b.to("cpu") clean_memory_on_device(self.device) def forward( @@ -1017,69 +1036,73 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - if not self.double_blocks_to_swap: + if not self.blocks_to_swap: for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.double_blocks_to_swap): - block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") - - block = self.double_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved double block {block_idx} to cuda.") - - to_cpu_block_index = 0 + futures = {} + + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + for block in blocks_to_cpu: + block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + for block in blocks_to_cuda: + block.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) + blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + for block_idx, block in enumerate(self.double_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved double block {block_idx} to cuda.") + # print(f"Double block {block_idx}") + unit_idx = self.get_unit_index(is_double=True, index=block_idx) + wait_for_blocks_move(unit_idx, futures) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if moving: - self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved double block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 + if unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future - img = torch.cat((txt, img), 1) + img = torch.cat((txt, img), 1) - if not self.single_blocks_to_swap: - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.single_blocks_to_swap): - block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") - - block = self.single_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved single block {block_idx} to cuda.") - - to_cpu_block_index = 0 for block_idx, block in enumerate(self.single_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved single block {block_idx} to cuda.") + # print(f"Single block {block_idx}") + unit_idx = self.get_unit_index(is_double=False, index=block_idx) + if block_idx % 2 == 0: + wait_for_blocks_move(unit_idx, futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if moving: - self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved single block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 + if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] @@ -1088,6 +1111,7 @@ def forward( vec = vec.to(self.device) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img From da94fd934eb4951d1cb132abc9d2a355e44d7abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:27:48 +0900 Subject: [PATCH 145/163] fix typos --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index bf34208f1..022467ea7 100644 --- a/flux_train.py +++ b/flux_train.py @@ -516,7 +516,7 @@ def wait_blocks_move(block_idx, futures): num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = 2 # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) @@ -608,7 +608,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): num_single_blocks = 38 # len(flux.single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) futures = {} From bf91bea2e4363e5b3e0db11f0955ab93a19a0452 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 20:51:40 +0900 Subject: [PATCH 146/163] fix flip_aug, alpha_mask, random_crop issue in caching --- README.md | 2 ++ library/train_util.py | 44 +++++++++++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 9eabdaeef..b67a2c4e1 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. + - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! - Improvements in OFT (Orthogonal Finetuning) Implementation diff --git a/library/train_util.py b/library/train_util.py index 72d2d8112..a31d00c69 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -998,9 +998,26 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1021,28 +1038,31 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= vae_batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -2315,7 +2335,7 @@ def debug_dataset(train_dataset, show_input_ids=False): if "alpha_masks" in example and example["alpha_masks"] is not None: alpha_mask = example["alpha_masks"][j] logger.info(f"alpha mask size: {alpha_mask.size()}") - alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8) if os.name == "nt": cv2.imshow("alpha_mask", alpha_mask) @@ -5124,7 +5144,7 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") if args.loss_type == "huber" or args.loss_type == "smooth_l1": if args.huber_schedule == "exponential": From 392e8dedd84e469b125e2935e3ecf02e6270a5b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:14:11 +0900 Subject: [PATCH 147/163] fix flip_aug, alpha_mask, random_crop issue in caching in caching strategy --- library/train_util.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 319337a47..17dd447eb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -993,9 +993,26 @@ def new_cache_latents(self, model: Any, is_main_process: bool): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1016,20 +1033,23 @@ def new_cache_latents(self, model: Any, is_main_process: bool): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= caching_strategy.batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) # if cache to disk, don't cache latents in non-main process, set to info only if caching_strategy.cache_to_disk and not is_main_process: @@ -1041,9 +1061,8 @@ def new_cache_latents(self, model: Any, is_main_process: bool): # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From a94bc84dec8e85e8a71217b4d2570a52c6779b73 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:31 +0900 Subject: [PATCH 148/163] fix to work bitsandbytes optimizers with full path #1640 --- library/train_util.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b40945ab8..47c367683 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3014,7 +3014,11 @@ def int_or_float(value): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, " + "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, " + "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, " + "AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.", ) # backward compatibility @@ -4105,6 +4109,7 @@ def get_optimizer(args, trainable_params): lr = args.learning_rate optimizer = None + optimizer_class = None if optimizer_type == "Lion".lower(): try: @@ -4162,7 +4167,8 @@ def get_optimizer(args, trainable_params): "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") @@ -4338,6 +4344,7 @@ def get_optimizer(args, trainable_params): optimizer_class = getattr(optimizer_module, optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) From ce49ced699298aa885d9a64b969fe8c77f30893b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:40 +0900 Subject: [PATCH 149/163] update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b67a2c4e1..9f024c1c9 100644 --- a/README.md +++ b/README.md @@ -140,9 +140,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. - - transformers, accelerate and huggingface_hub are updated. + - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! + - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). + - Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! From a9aa52658a0d9ba7910a1d1983b650bc9de7153e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 17:12:56 +0900 Subject: [PATCH 150/163] fix sample generation is not working in FLUX1 fine tuning #1647 --- library/flux_models.py | 5 +++-- library/flux_train_utils.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index a35dbc106..0bc1c02b9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -999,8 +999,9 @@ def get_unit_index(self, is_double: bool, index: int): def prepare_block_swap_before_forward(self): # make: first n blocks are on cuda, and last n blocks are on cpu - if self.blocks_to_swap is None: - raise ValueError("Block swap is not enabled.") + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return for i in range(self.num_block_units - self.blocks_to_swap): for b in self.get_block_unit(i): b.to(self.device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f77d4b585..1d1eb9d24 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -313,6 +313,7 @@ def denoise( guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() pred = model( img=img, img_ids=img_ids, @@ -325,7 +326,8 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + + model.prepare_block_swap_before_forward() return img From 822fe578591e44ac949830e03a8841e222483052 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 20:57:27 +0900 Subject: [PATCH 151/163] add workaround for 'Some tensors share memory' error #1614 --- networks/convert_flux_lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index bd4c1cf78..fe6466ebc 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -412,6 +412,10 @@ def main(args): state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) elif args.src == "sd-scripts" and args.dst == "ai-toolkit": state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + + # eliminate 'shared tensors' + for k in list(state_dict.keys()): + state_dict[k] = state_dict[k].detach().clone() else: raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") From 1a0f5b0c389f4e9fab5edb06b36f203e8894d581 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 00:35:29 +0900 Subject: [PATCH 152/163] re-fix sample generation is not working in FLUX1 split mode #1647 --- flux_train_network.py | 3 +++ library/flux_train_utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index a6e57eede..65b121e7c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -300,6 +300,9 @@ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.Fl self.flux_lower = flux_lower self.target_device = device + def prepare_block_swap_before_forward(self): + pass + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d1eb9d24..b3c9184f2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -196,7 +196,6 @@ def sample_image_inference( tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: From fe2aa32484a948f16955909e64c21da7fe1e4e0c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:49:25 +0900 Subject: [PATCH 153/163] adjust min/max bucket reso divisible by reso steps #1632 --- README.md | 2 ++ docs/config_README-en.md | 2 ++ docs/config_README-ja.md | 2 ++ fine_tune.py | 2 ++ library/train_util.py | 40 ++++++++++++++++++++++++++++++++------ train_controlnet.py | 2 ++ train_db.py | 2 ++ train_network.py | 2 +- train_textual_inversion.py | 2 +- 9 files changed, 48 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9f024c1c9..de5cddb92 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) + - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 83bea329b..66a50dc09 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d * `batch_size` * This corresponds to the command-line argument `--train_batch_size`. +* `max_bucket_reso`, `min_bucket_reso` + * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`. These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index cc74c341b..0ed95e0eb 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `batch_size` * コマンドライン引数の `--train_batch_size` と同等です。 +* `max_bucket_reso`, `min_bucket_reso` + * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。 これらの設定はデータセットごとに固定です。 つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 diff --git a/fine_tune.py b/fine_tune.py index d865cd2de..b556672d2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,6 +91,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/library/train_util.py b/library/train_util.py index 47c367683..0cb6383a4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -653,6 +653,34 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' + def adjust_min_max_bucket_reso_by_steps( + self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int + ) -> Tuple[int, int]: + # make min/max bucket reso to be multiple of bucket_reso_steps + if min_bucket_reso % bucket_reso_steps != 0: + adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps + logger.warning( + f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}" + ) + min_bucket_reso = adjusted_min_bucket_reso + if max_bucket_reso % bucket_reso_steps != 0: + adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps + logger.warning( + f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}" + ) + max_bucket_reso = adjusted_max_bucket_reso + + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + return min_bucket_reso, max_bucket_reso + def set_seed(self, seed): self.seed = seed @@ -1533,12 +1561,9 @@ def __init__( self.enable_bucket = enable_bucket if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps @@ -1901,6 +1926,9 @@ def __init__( self.enable_bucket = enable_bucket if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a8..6938c4bcc 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -107,6 +107,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_db.py b/train_db.py index 39d8ea6ed..2c7f02582 100644 --- a/train_db.py +++ b/train_db.py @@ -93,6 +93,8 @@ def train(args): if args.no_token_padding: train_dataset_group.disable_token_padding() + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_network.py b/train_network.py index 7ba073855..044ec3aa8 100644 --- a/train_network.py +++ b/train_network.py @@ -95,7 +95,7 @@ def generate_step_logs( return logs def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c36..96e7bd509 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -99,7 +99,7 @@ def __init__(self): self.is_sdxl = False def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 1567549220b5936af0c534ca23656ecd2f4882f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:51:36 +0900 Subject: [PATCH 154/163] update help text #1632 --- library/train_util.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0cb6383a4..422dceca2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3865,8 +3865,20 @@ def add_dataset_arguments( action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--min_bucket_reso", + type=int, + default=256, + help="minimum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--max_bucket_reso", + type=int, + default=1024, + help="maximum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", + ) parser.add_argument( "--bucket_reso_steps", type=int, From e0c3630203776dc568c32d67806a0a9d443f5721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 29 Sep 2024 09:11:15 +0800 Subject: [PATCH 155/163] Support Sdxl Controlnet (#1648) * Create sdxl_train_controlnet.py * add fuse_background_pass * Update sdxl_train_controlnet.py * add fuse and fix error * update * Update sdxl_train_controlnet.py * Update sdxl_train_controlnet.py * Update sdxl_train_controlnet.py * update * Update sdxl_train_controlnet.py --- library/train_util.py | 2 +- sdxl_train_controlnet.py | 752 +++++++++++++++++++++++++++++++++++++++ train_controlnet.py | 33 +- 3 files changed, 779 insertions(+), 8 deletions(-) create mode 100644 sdxl_train_controlnet.py diff --git a/library/train_util.py b/library/train_util.py index e023f63a2..293fc05ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3581,7 +3581,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py new file mode 100644 index 000000000..00026d2cc --- /dev/null +++ b/sdxl_train_controlnet.py @@ -0,0 +1,752 @@ +import argparse +import math +import os +import random +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from diffusers.utils.torch_utils import is_compiled_module +from safetensors.torch import load_file +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, +) + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] + * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + ) + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + + # convert U-Net + with torch.no_grad(): + du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) + unet.to("cpu") + clean_memory_on_device(accelerator.device) + del unet + unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) + unet.load_state_dict(du_unet_sd) + + controlnet = ControlNetModel.from_unet(unet) + + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # モデルに xformers とか memory efficient attention を組み込む + # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if args.xformers: + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info( + f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" + ) + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + controlnet.to(weight_dtype) + unet.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + controlnet.to(weight_dtype) + unet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( + noise_scheduler + ) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + ( + "controlnet_train" + if args.log_tracker_name is None + else args.log_tracker_name + ), + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = ( + sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + ) + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file, sai_metadata) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = ( + batch["latents"] + .to(accelerator.device) + .to(dtype=weight_dtype) + ) + else: + # latentに変換 + latents = ( + vae.encode(batch["images"].to(dtype=vae_dtype)) + .latent_dist.sample() + .to(dtype=weight_dtype) + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print( + "NaN found in latents, replacing with zeros" + ) + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if ( + "text_encoder_outputs1_list" not in batch + or batch["text_encoder_outputs1_list"] is None + ): + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + ) + else: + encoder_hidden_states1 = ( + batch["text_encoder_outputs1_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + encoder_hidden_states2 = ( + batch["text_encoder_outputs2_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + pool2 = ( + batch["text_encoder_pool2_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings( + # orig_size, crop_size, target_size, accelerator.device + # ).to(weight_dtype) + + embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 + # concat embeddings + #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + vector_embedding_dict = { + "text_embeds": pool2, + "time_ids": embs + } + text_embedding = torch.cat( + [encoder_hidden_states1, encoder_hidden_states2], dim=2 + ).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = ( + train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + ) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=text_embedding, + added_cond_kwargs=vector_embedding_dict, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=text_embedding, + added_cond_kwargs=vector_embedding_dict, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False, + )[0] + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = train_util.conditional_loss( + noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, + ) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name,unwrap_model(controlnet)) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name,unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # end of epoch + + if is_main_process: + controlnet = unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model( + ckpt_name, controlnet, force_sync_upload=True + ) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_controlnet.py b/train_controlnet.py index c2945b083..8c7882c8f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -254,6 +254,7 @@ def __contains__(self, name): accelerator.wait_for_everyone() if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() controlnet.enable_gradient_checkpointing() # 学習に必要なクラスを準備する @@ -304,6 +305,20 @@ def __contains__(self, name): controlnet, optimizer, train_dataloader, lr_scheduler ) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + unet.requires_grad_(False) text_encoder.requires_grad_(False) unet.to(accelerator.device) @@ -497,13 +512,17 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From 8919b31145d38a2a790fae6e8e1c34c205c6794e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:07:34 +0900 Subject: [PATCH 156/163] use original ControlNet instead of Diffusers --- gen_img.py | 89 +++- library/sdxl_model_util.py | 2 +- library/sdxl_original_control_net.py | 272 ++++++++++++ library/sdxl_original_unet.py | 14 +- ...controlnet.py => sdxl_train_control_net.py | 390 ++++++++---------- 5 files changed, 528 insertions(+), 239 deletions(-) create mode 100644 library/sdxl_original_control_net.py rename sdxl_train_controlnet.py => sdxl_train_control_net.py (69%) diff --git a/gen_img.py b/gen_img.py index 59bcd5b09..70b3c81ff 100644 --- a/gen_img.py +++ b/gen_img.py @@ -43,8 +43,8 @@ ) from einops import rearrange from tqdm import tqdm -from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +from accelerate import init_empty_weights import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -58,6 +58,7 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL @@ -352,8 +353,8 @@ def __init__( self.token_replacements_list.append({}) # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] + self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない self.gradual_latent: GradualLatent = None @@ -542,7 +543,7 @@ def __call__( else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - if self.control_net_lllites: + if self.control_net_lllites or (self.control_nets and self.is_sdxl): # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] @@ -731,7 +732,12 @@ def __call__( num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if not self.is_sdxl: + guided_hints = original_control_net.get_guided_hints( + self.control_nets, num_latent_input, batch_size, clip_guide_images + ) + else: + clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1] each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) if self.control_net_lllites: @@ -793,7 +799,7 @@ def __call__( latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + # disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo if self.control_net_lllites: for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): if not enabled or ratio >= 1.0: @@ -802,9 +808,16 @@ def __call__( logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False + if self.control_nets and self.is_sdxl: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + each_control_net_enabled[j] = False # predict the noise residual - if self.control_nets and self.control_net_enabled: + if self.control_nets and self.control_net_enabled and not self.is_sdxl: if regional_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -823,6 +836,31 @@ def __call__( text_embeddings, text_emb_last, ).sample + elif self.control_nets: + input_resi_add_list = [] + mid_add_list = [] + for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled): + if not enbld: + continue + input_resi_add, mid_add = control_net( + latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images + ) + input_resi_add_list.append(input_resi_add) + mid_add_list.append(mid_add) + if len(input_resi_add_list) == 0: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + if len(input_resi_add_list) > 1: + # get mean of input_resi_add_list and mid_add_list + input_resi_add_mean = [] + for i in range(len(input_resi_add_list[0])): + input_resi_add_mean.append( + torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0)) + ) + input_resi_add = input_resi_add_mean + mid_add = torch.mean(torch.stack(mid_add_list), dim=0) + + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) else: @@ -1827,16 +1865,37 @@ def __getattr__(self, item): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] + control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + if not is_sdxl: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + else: + for i, model_file in enumerate(args.control_net_models): + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + logger.info(f"loading SDXL ControlNet: {model_file}") + from safetensors.torch import load_file + + state_dict = load_file(model_file) - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}") + with init_empty_weights(): + control_net = SdxlControlNet(multiplier=multiplier) + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_nets.append((control_net, ratio)) control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 4fad78a1c..0466c1fa5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -8,7 +8,7 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 000000000..3af45f4db --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier = multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a89..0aa07d0d6 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,7 @@ from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging @@ -1156,9 +1156,9 @@ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_ti self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_ratio = ds_ratio - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. """ _self = self.delegate @@ -1209,6 +1209,8 @@ def call_module(module, h, emb, context): hs.append(h) h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add for module in _self.output_blocks: # Deep Shrink @@ -1217,7 +1219,11 @@ def call_module(module, h, emb, context): # print("upsample", h.shape, hs[-1].shape) h = resize_like(h, hs[-1]) - h = torch.cat([h, hs.pop()], dim=1) + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) h = call_module(module, h, emb, context) # Deep Shrink: in case of depth 0 diff --git a/sdxl_train_controlnet.py b/sdxl_train_control_net.py similarity index 69% rename from sdxl_train_controlnet.py rename to sdxl_train_control_net.py index 00026d2cc..74dcff2af 100644 --- a/sdxl_train_controlnet.py +++ b/sdxl_train_control_net.py @@ -14,6 +14,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed +from accelerate import init_empty_weights from diffusers import DDPMScheduler, ControlNetModel from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import load_file @@ -23,6 +24,9 @@ sdxl_model_util, sdxl_original_unet, sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, ) import library.model_util as model_util @@ -41,6 +45,7 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet from library.utils import setup_logging, add_logging_arguments setup_logging() @@ -58,10 +63,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche } if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] - * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] return logs @@ -79,7 +81,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,17 +115,18 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) - ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) train_dataset_group.verify_bucket_reso_steps(32) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -162,86 +172,99 @@ def unwrap_model(model): unet, logit_scale, ckpt_info, - ) = sdxl_train_util.load_target_model( - args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype - ) + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + unet.to(accelerator.device) # reduce main memory usage + + # convert U-Net to Controlled U-Net + logger.info("convert U-Net to Controlled U-Net") + unet_sd = unet.state_dict() + with init_empty_weights(): + unet = SdxlControlledUNet() + unet.load_state_dict(unet_sd, strict=True, assign=True) + del unet_sd + + # make control net + logger.info("make ControlNet") + if args.controlnet_model_path: + with init_empty_weights(): + control_net = SdxlControlNet() + + logger.info(f"load ControlNet from {args.controlnet_model_path}") + filename = args.controlnet_model_path + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + info = control_net.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"ControlNet loaded from {filename}: {info}") + else: + control_net = SdxlControlNet() - # convert U-Net - with torch.no_grad(): - du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) - unet.to("cpu") - clean_memory_on_device(accelerator.device) - del unet - unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) - unet.load_state_dict(du_unet_sd) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) + logger.info("initialize ControlNet from U-Net") + info = control_net.init_from_unet(unet) + logger.info(f"ControlNet initialized from U-Net: {info}") # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # モデルに xformers とか memory efficient attention を組み込む # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) if args.xformers: - unet.enable_xformers_memory_efficient_attention() - controlnet.enable_xformers_memory_efficient_attention() + unet.set_use_memory_efficient_attention(True, False) + control_net.set_use_memory_efficient_attention(True, False) + elif args.sdpa: + unet.set_use_sdpa(True) + control_net.set_use_sdpa(True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - controlnet.enable_gradient_checkpointing() + control_net.enable_gradient_checkpointing() # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) + trainable_params = list(control_net.parameters()) + # for p in trainable_params: + # p.requires_grad = True logger.info(f"trainable params count: {len(trainable_params)}") - logger.info( - f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" - ) + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -257,7 +280,7 @@ def unwrap_model(model): # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) accelerator.print( f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" @@ -267,9 +290,7 @@ def unwrap_model(model): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args, optimizer, accelerator.num_processes - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -277,19 +298,17 @@ def unwrap_model(model): args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) + control_net.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) + control_net.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler + control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + control_net, optimizer, train_dataloader, lr_scheduler ) if args.fused_backward_pass: @@ -314,10 +333,8 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): text_encoder2.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() + unet.eval() + control_net.train() # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -362,26 +379,15 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( - noise_scheduler - ) + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) if accelerator.is_main_process: init_kwargs = {} @@ -390,11 +396,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - ( - "controlnet_train" - if args.log_tracker_name is None - else args.log_tracker_name - ), + ("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name), config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs, ) @@ -409,10 +411,8 @@ def save_model(ckpt_name, model, force_sync_upload=False): accelerator.print(f"\nsaving checkpoint: {ckpt_file}") sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = ( - sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" - ) - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + state_dict = model.state_dict() if save_dtype is not None: for key in list(state_dict.keys()): @@ -436,19 +436,19 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - # For --sample_at_first - sdxl_train_util.sample_images( - accelerator, - args, - 0, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) + # # For --sample_at_first + # sdxl_train_util.sample_images( + # accelerator, + # args, + # 0, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # unet, + # controlnet=control_net, + # ) # training loop for epoch in range(num_train_epochs): @@ -457,121 +457,63 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(controlnet): + with accelerator.accumulate(control_net): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = ( - batch["latents"] - .to(accelerator.device) - .to(dtype=weight_dtype) - ) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = ( - vae.encode(batch["images"].to(dtype=vae_dtype)) - .latent_dist.sample() - .to(dtype=weight_dtype) - ) + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): - accelerator.print( - "NaN found in latents, replacing with zeros" - ) + accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if ( - "text_encoder_outputs1_list" not in batch - or batch["text_encoder_outputs1_list"] is None - ): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = ( - train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - ) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = ( - batch["text_encoder_outputs1_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - encoder_hidden_states2 = ( - batch["text_encoder_outputs2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - pool2 = ( - batch["text_encoder_pool2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] crop_size = batch["crop_top_lefts"] target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings( - # orig_size, crop_size, target_size, accelerator.device - # ).to(weight_dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 # concat embeddings - #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - vector_embedding_dict = { - "text_embeds": pool2, - "time_ids": embs - } - text_embedding = torch.cat( - [encoder_hidden_states1, encoder_hidden_states2], dim=2 - ).to(weight_dtype) + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = ( - train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents ) controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - controlnet_cond=controlnet_image, - return_dict=False, + input_resi_add, mid_add = control_net( + noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - down_block_additional_residuals=[ - sample.to(dtype=weight_dtype) for sample in down_block_res_samples - ], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - return_dict=False, - )[0] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add) if args.v_parameterization: # v-parameterization training @@ -580,7 +522,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss( - noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) loss = loss.mean([1, 2, 3]) @@ -588,7 +530,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -601,7 +543,7 @@ def remove_model(old_ckpt_name): accelerator.backward(loss) if not args.fused_backward_pass: if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() + params_to_clip = control_net.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -616,25 +558,25 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 - sdxl_train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # unet, + # controlnet=control_net, + # ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name,unwrap_model(controlnet)) + save_model(ckpt_name, unwrap_model(control_net)) if args.save_state: train_util.save_and_remove_state_stepwise(args, accelerator, global_step) @@ -650,14 +592,14 @@ def remove_model(old_ckpt_name): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -668,7 +610,7 @@ def remove_model(old_ckpt_name): saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name,unwrap_model(controlnet)) + save_model(ckpt_name, unwrap_model(control_net)) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) if remove_epoch_no is not None: @@ -688,13 +630,13 @@ def remove_model(old_ckpt_name): [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet, - controlnet=controlnet, + controlnet=control_net, ) # end of epoch if is_main_process: - controlnet = unwrap_model(controlnet) + control_net = unwrap_model(control_net) accelerator.end_training() @@ -703,9 +645,7 @@ def remove_model(old_ckpt_name): if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model( - ckpt_name, controlnet, force_sync_upload=True - ) + save_model(ckpt_name, control_net, force_sync_upload=True) logger.info("model saved.") @@ -717,26 +657,38 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) - train_util.add_masked_loss_arguments(parser) + # train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_sd_saving_arguments(parser) + # train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_name_or_path", + "--controlnet_model_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) parser.add_argument( "--no_half_vae", action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - return parser From 0243c65877a7700ffab1e782690f26080a0deadc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:09:56 +0900 Subject: [PATCH 157/163] fix typo --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 70b3c81ff..421d5c0b9 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1890,7 +1890,7 @@ def __getattr__(self, item): state_dict = load_file(model_file) - logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}") + logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}") with init_empty_weights(): control_net = SdxlControlNet(multiplier=multiplier) control_net.load_state_dict(state_dict) From 012e7e63a5b1acdf69c72eee4cb330a5a6defc41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:18:16 +0900 Subject: [PATCH 158/163] fix to work linear/cosine scheduler closes #1651 ref #1393 --- library/train_util.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 422dceca2..27910dc90 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4496,6 +4496,15 @@ def wrap_check_needless_num_warmup_steps(return_vals): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") From 793999d116638548fc16579b712f44456ee3034e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 30 Sep 2024 23:39:32 +0900 Subject: [PATCH 159/163] sample generation in SDXL ControlNet training --- library/sdxl_lpw_stable_diffusion.py | 168 +++++++---------------- library/strategy_base.py | 192 ++++++++++++++++++++++++++- library/strategy_sdxl.py | 39 +++++- library/train_util.py | 35 +++-- sdxl_train_control_net.py | 55 ++++---- 5 files changed, 323 insertions(+), 166 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 03b182566..9196eb0f2 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -13,12 +13,20 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util, train_util +from library import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) try: @@ -537,7 +545,7 @@ def __init__( vae: AutoencoderKL, text_encoder: List[CLIPTextModel], tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], scheduler: SchedulerMixin, # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, @@ -594,74 +602,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -792,7 +732,7 @@ def __call__( max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, - controlnet=None, + controlnet: sdxl_original_control_net.SdxlControlNet = None, controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -896,32 +836,24 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] - - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, - ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights + ) + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None unet_dtype = self.unet.dtype dtype = unet_dtype @@ -970,23 +902,23 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) crop_size = torch.zeros_like(orig_size) target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) # make conditionings + text_pool = text_pool.to(device, dtype) if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): @@ -994,22 +926,14 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # FIXME SD1 ControlNet is not working # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97ef..10820afa1 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -1,6 +1,7 @@ # base class for platform strategies. this file defines the interface for strategies import os +import re from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -22,6 +23,24 @@ class TokenizeStrategy: _strategy = None # strategy instance: actual strategy class + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + @classmethod def set_strategy(cls, strategy): if cls._strategy is not None: @@ -54,7 +73,151 @@ def _load_tokenizer( def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: raise NotImplementedError - def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: """ for SD1.5/2.0/SDXL TODO support batch input @@ -62,7 +225,10 @@ def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Option if max_length is None: max_length = tokenizer.model_max_length - 2 - input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids if max_length > tokenizer.model_max_length: input_ids = input_ids.squeeze(0) @@ -101,6 +267,17 @@ def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Option iids_list.append(ids_chunk) input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights return input_ids @@ -126,6 +303,17 @@ def encode_tokens( :return: list of output embeddings for each architecture """ raise NotImplementedError + + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError class TextEncoderOutputsCachingStrategy: diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 3eb0ab6f6..b48e6d55a 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -37,6 +37,22 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), ) + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length) + tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), ( + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ) + class SdxlTextEncodingStrategy(TextEncodingStrategy): def __init__(self) -> None: @@ -98,7 +114,10 @@ def _get_hidden_states_sdxl( ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] - max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 input_ids1 = input_ids1.to(text_encoder1.device) @@ -172,6 +191,24 @@ def encode_tokens( ) return [hidden_states1, hidden_states2, pool2] + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens) + + # apply weights + if weights[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1] + + return [hidden_states1, hidden_states2, pool2] + class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" diff --git a/library/train_util.py b/library/train_util.py index 293fc05ad..b559616f2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,6 +74,7 @@ import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec @@ -3581,7 +3582,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"], + choices=[ + "eager", + "aot_eager", + "inductor", + "aot_ts_nvfuser", + "nvprims_nvfuser", + "cudagraphs", + "ofi", + "fx2trt", + "onnxrt", + "tensort", + "ipex", + "tvm", + ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") @@ -5850,8 +5864,8 @@ def sample_images_common( pipe_class, accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, + epoch: int, + steps: int, device, vae, tokenizer, @@ -5910,11 +5924,7 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) pipeline = pipe_class( text_encoder=text_encoder, @@ -5975,21 +5985,18 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - torch.set_rng_state(rng_state) if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - pipeline, + pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline], save_dir, prompt_dict, epoch, diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 74dcff2af..583a27dcc 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -83,6 +83,7 @@ def train(args): tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( @@ -436,19 +437,19 @@ def remove_model(old_ckpt_name): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - # # For --sample_at_first - # sdxl_train_util.sample_images( - # accelerator, - # args, - # 0, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], - # unet, - # controlnet=control_net, - # ) + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) # training loop for epoch in range(num_train_epochs): @@ -484,7 +485,7 @@ def remove_model(old_ckpt_name): input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] + tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2] ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) @@ -558,18 +559,18 @@ def remove_model(old_ckpt_name): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], - # unet, - # controlnet=control_net, - # ) + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -628,7 +629,7 @@ def remove_model(old_ckpt_name): accelerator.device, vae, [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], unet, controlnet=control_net, ) From c2440f9e53239e7e5dee426f611800d3e38a7f0e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 3 Oct 2024 21:32:21 +0900 Subject: [PATCH 160/163] fix cond image normlization, add independent LR for control --- library/sdxl_train_util.py | 3 ++- library/train_util.py | 20 +++++++++++++++++++- sdxl_train_control_net.py | 30 +++++++++++++++++++++++++----- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f009b5779..aaf77b8dd 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,6 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging setup_logging() @@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin def sample_images(*args, **kwargs): + from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/train_util.py b/library/train_util.py index b559616f2..07c253a0e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import subprocess from io import BytesIO import toml +# from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -912,6 +913,23 @@ def make_buckets(self): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + # # run in parallel + # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) + # with ThreadPoolExecutor(max_workers) as executor: + # futures = [] + # for info in tqdm(self.image_data.values(), desc="loading image sizes"): + # if info.image_size is None: + # def get_and_set_image_size(info): + # info.image_size = self.get_image_size(info.absolute_path) + # futures.append(executor.submit(get_and_set_image_size, info)) + # # consume futures to reduce memory usage and prevent Ctrl-C hang + # if len(futures) >= max_workers: + # for future in futures: + # future.result() + # futures = [] + # for future in futures: + # future.result() + if self.enable_bucket: logger.info("make buckets") else: @@ -1826,7 +1844,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] - for img_path in img_paths: + for img_path in tqdm(img_paths, desc="read caption"): cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 583a27dcc..b902cda69 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -253,11 +253,20 @@ def unwrap_model(model): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = list(control_net.parameters()) - # for p in trainable_params: - # p.requires_grad = True - logger.info(f"trainable params count: {len(trainable_params)}") - logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + trainable_params = [] + ctrlnet_params = [] + unet_params = [] + for name, param in control_net.named_parameters(): + if name.startswith("controlnet_"): + ctrlnet_params.append(param) + else: + unet_params.append(param) + trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) + trainable_params.append({"params": unet_params, "lr": args.learning_rate}) + all_params = ctrlnet_params + unet_params + + logger.info(f"trainable params count: {len(all_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -456,6 +465,8 @@ def remove_model(old_ckpt_name): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + control_net.train() + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(control_net): @@ -510,6 +521,9 @@ def remove_model(old_ckpt_name): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + # '-1 to +1' to '0 to 1' + controlnet_image = (controlnet_image + 1) / 2 + with accelerator.autocast(): input_resi_add, mid_add = control_net( noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image @@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--control_net_lr", + type=float, + default=1e-4, + help="learning rate for controlnet / controlnetの学習率", + ) return parser From ba08a898940c80a6551111fdd77b53c6d3a019ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 4 Oct 2024 20:35:16 +0900 Subject: [PATCH 161/163] call optimizer eval/train for sample_at_first, also set train after resuming closes #1667 --- flux_train.py | 2 ++ train_network.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/flux_train.py b/flux_train.py index 022467ea7..81c13e4cc 100644 --- a/flux_train.py +++ b/flux_train.py @@ -706,7 +706,9 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.unwrap_model(flux).prepare_block_swap_before_forward() # For --sample_at_first + optimizer_eval_fn() flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) diff --git a/train_network.py b/train_network.py index 7b2b76a1b..f0d397b9e 100644 --- a/train_network.py +++ b/train_network.py @@ -1042,7 +1042,9 @@ def remove_model(old_ckpt_name): text_encoder = None # For --sample_at_first + optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) From 83e3048cb089bf6726751609da26da751b8383ae Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 6 Oct 2024 21:32:21 +0900 Subject: [PATCH 162/163] load Diffusers format, check schnell/dev --- README.md | 4 + flux_minimal_inference.py | 15 +-- flux_train.py | 15 ++- flux_train_network.py | 17 ++- library/flux_utils.py | 178 +++++++++++++++++++++++++++-- tools/convert_diffusers_to_flux.py | 78 +------------ 6 files changed, 196 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 789fe514a..c567758a5 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 6, 2024: +- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. +- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. + Sep 26, 2024: The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 2f1b9a377..7ab224f1b 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -419,9 +419,6 @@ def encode(prpt: str): steps = args.steps guidance_scale = args.guidance - name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way - is_schnell = name == "schnell" - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -455,12 +452,8 @@ def is_fp8(dt): # if is_fp8(t5xxl_dtype): # t5xxl = accelerator.prepare(t5xxl) - t5xxl_max_length = 256 if is_schnell else 512 - tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) - encoding_strategy = strategy_flux.FluxTextEncodingStrategy() - # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) + is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype @@ -469,8 +462,12 @@ def is_fp8(dt): # if args.offload: # model = model.to("cpu") + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + # AE - ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) + ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device) ae.eval() # if is_fp8(ae_dtype): # ae = accelerator.prepare(ae) diff --git a/flux_train.py b/flux_train.py index 81c13e4cc..ecc87c0a8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -137,6 +137,7 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( @@ -144,9 +145,8 @@ def train(args): args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" t5xxl_max_token_length = ( - args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) ) strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) @@ -177,12 +177,11 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -196,7 +195,7 @@ def train(args): # prepare tokenize strategy if args.t5xxl_max_token_length is None: - if name == "schnell": + if is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 @@ -258,8 +257,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - flux = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) if args.gradient_checkpointing: @@ -294,7 +293,7 @@ def train(args): if not cache_latents: # load VAE here if not cached - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") ae.requires_grad_(False) ae.eval() ae.to(accelerator.device, dtype=weight_dtype) diff --git a/flux_train_network.py b/flux_train_network.py index 65b121e7c..5d14bd28e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -2,7 +2,7 @@ import copy import math import random -from typing import Any +from typing import Any, Optional import torch from accelerate import Accelerator @@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -57,19 +58,15 @@ def assert_extra_args(self, args, train_dataset_group): train_dataset_group.verify_bucket_reso_steps(32) # TODO check this - def get_flux_model_name(self, args): - return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" - def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models - name = self.get_flux_model_name(args) # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) if args.fp8_base: # check dtype of model @@ -100,7 +97,7 @@ def load_target_model(self, args, weight_dtype, accelerator): elif t5xxl.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 T5XXL model") - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model @@ -142,10 +139,10 @@ def prepare_split_model(self, model, weight_dtype, accelerator): return flux_lower def get_tokenize_strategy(self, args): - name = self.get_flux_model_name(args) + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) if args.t5xxl_max_token_length is None: - if name == "schnell": + if is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b0a41a8a..713814e28 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,9 +1,11 @@ import json -from typing import Optional, Union +import os +from typing import List, Optional, Tuple, Union import einops import torch from safetensors.torch import load_file +from safetensors import safe_open from accelerate import init_empty_weights from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config @@ -17,6 +19,8 @@ logger = logging.getLogger(__name__) MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" # temporary copy from sd3_utils TODO refactor @@ -39,10 +43,35 @@ def load_safetensors( return load_file(path) # prevent device invalid Error +def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: + # check the state dict: Diffusers or BFL, dev or schnell + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if "00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + return is_diffusers, is_schnell, ckpt_paths + + def load_flow_model( - name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False -) -> flux_models.Flux: - logger.info(f"Building Flux model {name}") + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> Tuple[bool, flux_models.Flux]: + is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params) if dtype is not None: @@ -50,18 +79,28 @@ def load_flow_model( # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd) + logger.info("Converted Diffusers to BFL") + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") - return model + return is_schnell, model def load_ae( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.AutoEncoder: logger.info("Building AutoEncoder") with torch.device("meta"): - ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) @@ -246,3 +285,126 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: """ x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map() + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py index 9d8f7c74b..65ba7321a 100644 --- a/tools/convert_diffusers_to_flux.py +++ b/tools/convert_diffusers_to_flux.py @@ -29,6 +29,7 @@ import torch from tqdm import tqdm +from library import flux_utils from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() @@ -36,65 +37,6 @@ logger = logging.getLogger(__name__) -NUM_DOUBLE_BLOCKS = 19 -NUM_SINGLE_BLOCKS = 38 - -BFL_TO_DIFFUSERS_MAP = { - "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], - "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], - "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], - "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], - "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], - "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], - "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], - "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], - "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], - "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], - "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], - "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], - "txt_in.weight": ["context_embedder.weight"], - "txt_in.bias": ["context_embedder.bias"], - "img_in.weight": ["x_embedder.weight"], - "img_in.bias": ["x_embedder.bias"], - "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], - "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], - "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], - "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], - "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], - "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], - "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], - "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], - "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], - "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], - "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], - "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], - "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], - "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], - "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], - "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], - "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], - "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], - "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], - "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], - "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], - "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], - "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], - "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], - "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], - "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], - "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], - "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], - "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().linear2.bias": ["proj_out.bias"], - "final_layer.linear.weight": ["proj_out.weight"], - "final_layer.linear.bias": ["proj_out.bias"], - "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], - "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], -} - def convert(args): # if diffusers_path is folder, get safetensors file @@ -114,23 +56,7 @@ def convert(args): save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None # make reverse map from diffusers map - diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("double_blocks."): - block_prefix = f"transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("single_blocks."): - block_prefix = f"single_transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): - for i, weight in enumerate(weights): - diffusers_to_bfl_map[weight] = (i, key) + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map() # iterate over three safetensors files to reduce memory usage flux_sd = {} From 886f75345c95cddec8752ffdd4e60a471ee75403 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 10 Oct 2024 08:27:15 +0900 Subject: [PATCH 163/163] support weighted captions for sdxl LoRA and fine tuning --- library/strategy_base.py | 5 ++++- library/strategy_sdxl.py | 3 ++- sdxl_train.py | 38 ++++++++++++++++++++------------------ sdxl_train_control_net.py | 7 ++----- train_network.py | 27 +++++++++++++++++---------- 5 files changed, 45 insertions(+), 35 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 10820afa1..7981bd0b9 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -74,6 +74,9 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: raise NotImplementedError def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ raise NotImplementedError def _get_weighted_input_ids( @@ -303,7 +306,7 @@ def encode_tokens( :return: list of output embeddings for each architecture """ raise NotImplementedError - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] ) -> List[torch.Tensor]: diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index b48e6d55a..6650e2b43 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -174,7 +174,8 @@ def encode_tokens( """ Args: tokenize_strategy: TokenizeStrategy - models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. + If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required tokens: List of tokens, for text_encoder1 and text_encoder2 """ if len(models) == 2: diff --git a/sdxl_train.py b/sdxl_train.py index 7291ddd2f..320169d77 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -104,8 +104,8 @@ def train(args): setup_logging(args, reset=True) assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + not args.weighted_captions or not args.cache_text_encoder_outputs + ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -660,22 +660,24 @@ def optimizer_hook(parameter: torch.Tensor): input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] - ) + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + input_ids_list, + weights_list, + ) + ) + else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + [input_ids1, input_ids2], + ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index b902cda69..f6cc5a4f9 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -12,24 +12,21 @@ init_ipex() -from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from accelerate import init_empty_weights -from diffusers import DDPMScheduler, ControlNetModel +from diffusers import DDPMScheduler from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import load_file from library import ( deepspeed_utils, sai_model_spec, sdxl_model_util, - sdxl_original_unet, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, ) -import library.model_util as model_util import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -264,7 +261,7 @@ def unwrap_model(model): trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) trainable_params.append({"params": unet_params, "lr": args.learning_rate}) all_params = ctrlnet_params + unet_params - + logger.info(f"trainable params count: {len(all_params)}") logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") diff --git a/train_network.py b/train_network.py index f0d397b9e..e48e6a070 100644 --- a/train_network.py +++ b/train_network.py @@ -1123,14 +1123,21 @@ def remove_model(old_ckpt_name): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # SD only - encoded_text_encoder_conds = get_weighted_text_embeddings( - tokenizers[0], - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + # # SD only + # encoded_text_encoder_conds = get_weighted_text_embeddings( + # tokenizers[0], + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] @@ -1139,8 +1146,8 @@ def remove_model(old_ckpt_name): self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) - if args.full_fp16: - encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: