Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
sdbds committed Oct 12, 2024
2 parents 1bb33b7 + 0d3058b commit b5e958b
Showing 19 changed files with 1,592 additions and 243 deletions.
21 changes: 21 additions & 0 deletions README.md
@@ -11,6 +11,27 @@ The command to install PyTorch is as follows:

### Recent Updates

Oct 12, 2024:

- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
- This should work with all training scripts, but it has not been fully verified.
- Set up multi-GPU training with `accelerate config`.
- Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly.
```
accelerate launch --rdzv_backend=c10d sdxl_train_network.py ...
```
- In multi-GPU training, the memory of the GPUs is not pooled. In other words, even with two 12GB VRAM GPUs, you cannot train a model that requires 24GB of VRAM; a training run that fits in 12GB of VRAM simply executes at (up to) twice the speed. See the example launch below.
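
For example, a two-GPU LoRA training run might be launched as in the following sketch; `--num_processes` is the standard `accelerate launch` option, and the trailing arguments are placeholders for your usual training settings:

```
accelerate config
accelerate launch --rdzv_backend=c10d --num_processes=2 sdxl_train_network.py --network_module networks.lora ...
```
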
Oct 11, 2024:
- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`.
- For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file).
- The learning rate for the copied part of the U-Net is specified with `--learning_rate`. The learning rate for the modules added by ControlNet is specified with `--control_net_lr`. The optimal values are still unknown, but try around `1e-5` for the U-Net and `1e-4` for the ControlNet (see the example command after this list).
- If you want to generate sample images, specify the control image as `--cn path/to/control/image`.
- The trained weights are automatically converted and saved in Diffusers format, so they should be usable in ComfyUI.
- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`.
- The default is `False`, which matches the previous behavior: the parentheses are treated as normal text.
- If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`.
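
As a concrete illustration of the options above, a launch command might look like the following sketch; the model, dataset, and output paths are placeholders, and only the flags discussed above are shown (for sample image generation, specify the control image with `--cn` as noted above):

```
accelerate launch sdxl_train_control_net.py \
  --pretrained_model_name_or_path sd_xl_base_1.0.safetensors \
  --dataset_config dataset_config.toml \
  --output_dir output \
  --learning_rate 1e-5 \
  --control_net_lr 1e-4 \
  --weighted_captions
```
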
Oct 6, 2024:
- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is now automatically detected as dev or schnell, so schnell models are loaded correctly. Note that LoRA training and fine-tuning with schnell models are unverified. See the detection sketch after this list.
- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format.
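
A sketch of how such dev/schnell auto-detection can work: dev checkpoints contain the guidance-embedding weights that schnell lacks, so the variant can be inferred from the checkpoint keys alone. The exact key prefix checked below is an assumption for illustration, not necessarily the one sd-scripts inspects:

```python
from safetensors import safe_open

def is_flux_dev(checkpoint_path: str) -> bool:
    # dev models embed the guidance scale ("guidance_in.*" weights); schnell does not
    with safe_open(checkpoint_path, framework="pt") as f:
        return any(key.startswith("guidance_in.") for key in f.keys())
```
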
2 changes: 1 addition & 1 deletion docs/train_lllite_README.md
@@ -185,7 +185,7 @@ for img_file in img_files:

### Creating a dataset configuration file

You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.

```toml
[general]
# ...
```
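
For reference, a minimal sketch of such a `.toml` file; the surrounding keys (`[[datasets]]`, `[[datasets.subsets]]`, `image_dir`, `caption_extension`) are the usual sd-scripts dataset-config options and are shown here only as an assumed context for `conditioning_data_dir`:

```toml
[general]
flip_aug = false

[[datasets]]
resolution = 1024

  [[datasets.subsets]]
  image_dir = "path/to/image/dir"
  caption_extension = ".txt"
  conditioning_data_dir = "path/to/conditioning/image/dir"
```
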
17 changes: 6 additions & 11 deletions fine_tune.py
@@ -366,22 +366,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
# TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
else:
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
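
The change above routes weighted captions through the tokenize/encode strategy objects instead of the old `get_weighted_text_embeddings` helper. For intuition, the usual weighted-caption technique scales each token's embedding by its prompt weight and then restores the original mean of the conditioning tensor; the following is a minimal illustrative sketch, not sd-scripts' exact implementation:

```python
import torch

def apply_caption_weights(hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, tokens, dim), weights: (batch, tokens)
    previous_mean = hidden_states.float().mean(dim=(-2, -1))
    hidden_states = hidden_states * weights.unsqueeze(-1)  # emphasize / de-emphasize tokens
    current_mean = hidden_states.float().mean(dim=(-2, -1))
    # rescale so the overall magnitude of the conditioning is unchanged
    return hidden_states * (previous_mean / current_mean)[:, None, None]

emb = torch.randn(2, 77, 768)
weights = torch.ones(2, 77)
weights[:, 5:10] = 1.4  # tokens written as "(some text:1.4)" receive weight 1.4
weighted = apply_caption_weights(emb, weights)
```
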
89 changes: 74 additions & 15 deletions gen_img.py
@@ -43,8 +43,8 @@
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from accelerate import init_empty_weights
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -58,6 +58,7 @@
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
@@ -352,8 +353,8 @@ def __init__(
self.token_replacements_list.append({})

# ControlNet
self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5
self.control_net_lllites: List[ControlNetLLLite] = []
self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
self.control_net_enabled = True  # if control_nets is empty, ControlNet does nothing whether this is True or False

self.gradual_latent: GradualLatent = None
@@ -542,7 +543,7 @@ def __call__(
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

if self.control_net_lllites:
if self.control_net_lllites or (self.control_nets and self.is_sdxl):
# reuse the guide image as the ControlNet hint; for (SD 1.5) ControlNet this is done on the ControlNet side
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
@@ -731,7 +732,12 @@ def __call__(
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
if not self.is_sdxl:
guided_hints = original_control_net.get_guided_hints(
self.control_nets, num_latent_input, batch_size, clip_guide_images
)
else:
clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1]
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)

if self.control_net_lllites:
@@ -793,7 +799,7 @@ def __call__(
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo
# disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo
if self.control_net_lllites:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
@@ -802,9 +808,16 @@
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
if self.control_nets and self.is_sdxl:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
each_control_net_enabled[j] = False

# predict the noise residual
if self.control_nets and self.control_net_enabled:
if self.control_nets and self.control_net_enabled and not self.is_sdxl:
if regional_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
@@ -823,6 +836,31 @@
text_embeddings,
text_emb_last,
).sample
elif self.control_nets:
input_resi_add_list = []
mid_add_list = []
for (control_net, _), enabled in zip(self.control_nets, each_control_net_enabled):
if not enabled:
continue
input_resi_add, mid_add = control_net(
latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images
)
input_resi_add_list.append(input_resi_add)
mid_add_list.append(mid_add)
if len(input_resi_add_list) == 0:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
if len(input_resi_add_list) > 1:
# get mean of input_resi_add_list and mid_add_list
input_resi_add_mean = []
for k in range(len(input_resi_add_list[0])):  # k indexes U-Net blocks; avoids shadowing the timestep index i
input_resi_add_mean.append(
torch.mean(torch.stack([input_resi_add_list[j][k] for j in range(len(input_resi_add_list))]), dim=0)
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)

noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
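
When several SDXL ControlNets are active at once, the loop above combines their per-block residuals (`input_resi_add`) and mid-block residuals (`mid_add`) by element-wise mean before passing them to the U-Net. An equivalent compact sketch of that combination step:

```python
import torch

def mean_block_residuals(residual_lists: list[list[torch.Tensor]]) -> list[torch.Tensor]:
    # residual_lists[n][b] is ControlNet n's residual for U-Net block b;
    # stack each block's residuals across ControlNets and average them
    return [torch.mean(torch.stack(per_block), dim=0) for per_block in zip(*residual_lists)]

a = [torch.ones(1, 320, 64, 64), torch.ones(1, 640, 32, 32)]
b = [torch.zeros(1, 320, 64, 64), torch.zeros(1, 640, 32, 32)]
means = mean_block_residuals([a, b])  # two tensors, each filled with 0.5
```
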
@@ -1827,16 +1865,37 @@ def __getattr__(self, item):
upscaler.to(dtype).to(device)

# ControlNet handling
control_nets: List[ControlNetInfo] = []
control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
if args.control_net_models:
if not is_sdxl:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]

ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
else:
for i, model_file in enumerate(args.control_net_models):
multiplier = (
1.0
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
else args.control_net_multipliers[i]
)
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]

logger.info(f"loading SDXL ControlNet: {model_file}")
from safetensors.torch import load_file

state_dict = load_file(model_file)

logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}")
with init_empty_weights():
control_net = SdxlControlNet(multiplier=multiplier)
control_net.load_state_dict(state_dict)
control_net.to(dtype).to(device)
control_nets.append((control_net, ratio))

control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
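
A note on the `init_empty_weights` pattern used above: the module is instantiated on the meta device so that no memory is allocated up front, and the checkpoint tensors are then materialized into it. A minimal self-contained sketch of the pattern (with a plain `Linear` standing in for the ControlNet; `assign=True` requires PyTorch 2.1+):

```python
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)  # parameters live on the "meta" device; no RAM is allocated

state_dict = {"weight": torch.randn(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state_dict, assign=True)  # replace the meta tensors with real ones
print(layer.weight.device)  # cpu
```
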
*(Diffs for the remaining changed files are not shown here.)*
