Merge branch 'sd3' of https://github.com/sdbds/sd-scripts into qinglong
# Conflicts:
#	flux_train.py
#	library/train_util.py
#	requirements.txt
#	train_network.py
sdbds committed Sep 21, 2024
2 parents b514704 + 95ff9db commit 93a8d0e
Showing 14 changed files with 1,786 additions and 167 deletions.
159 changes: 156 additions & 3 deletions README.md
@@ -11,6 +11,32 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 18, 2024 (update 1):
Fixed an issue where `train()`/`eval()` were not called properly when using a schedule-free optimizer. For now, schedule-free optimizers can be used in FLUX.1 LoRA training and fine-tuning.

Sep 18, 2024:

- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
- Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free).
- `schedulefree` is added to the dependencies. Please update the library if necessary.
- AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. A brief usage sketch is shown after this list.
- Wrapper classes are not available for now.
- These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch.
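
For reference, here is a minimal usage sketch of how a schedule-free optimizer is driven; the `optimizer.train()`/`optimizer.eval()` calls are what the fix above refers to. This assumes the `schedulefree` package API and is not taken from the training scripts:

```
import torch
import schedulefree  # pip install schedulefree

model = torch.nn.Linear(16, 4)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # must be in train mode before optimizer.step()
for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()

optimizer.eval()  # switch to eval mode before validation, sampling, or saving
```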

Sep 16, 2024:

Added `train_double_block_indices` and `train_single_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details.

Sep 15, 2024:

Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported.

The implementation is based on 2kpr's code. Thanks to 2kpr!

Sep 14, 2024:
- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details.
- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details.

Sep 11, 2024:
Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev!

@@ -44,14 +70,19 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `

- [FLUX.1 LoRA training](#flux1-lora-training)
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
- [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model)
- [Distribution of timesteps](#distribution-of-timesteps)
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
- [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1)
- [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training)
- [FLUX.1 OFT training](#flux1-oft-training)
- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model)
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
- [Convert FLUX LoRA](#convert-flux-lora)
- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint)
- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training)
- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1)

### FLUX.1 LoRA training

@@ -191,6 +222,79 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/

The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large.

#### Specify rank for each layer in FLUX.1

You can set the rank for each layer in FLUX.1 with the following network_args. Specifying `0` means LoRA is not applied to that layer.

When network_args is not specified, the default value (`network_dim`) is applied, same as before.

|network_args|target layer|
|---|---|
|img_attn_dim|img_attn in DoubleStreamBlock|
|txt_attn_dim|txt_attn in DoubleStreamBlock|
|img_mlp_dim|img_mlp in DoubleStreamBlock|
|txt_mlp_dim|txt_mlp in DoubleStreamBlock|
|img_mod_dim|img_mod in DoubleStreamBlock|
|txt_mod_dim|txt_mod in DoubleStreamBlock|
|single_dim|linear1 and linear2 in SingleStreamBlock|
|single_mod_dim|modulation in SingleStreamBlock|

`"verbose=True"` is also available for debugging. It shows the rank of each layer.

example:
```
--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2"
"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True"
```

You can apply LoRA to the conditioning layers of FLUX.1 by specifying `in_dims` in network_args. The value must be exactly 5 numbers in `[]`, given as a comma-separated list.

example:
```
--network_args "in_dims=[4,2,2,2,4]"
```

Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`.

If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. A small sketch of this mapping is shown below.
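
As a reading aid, here is a hypothetical helper showing how the five values map onto the conditioning layers; it is not part of the training scripts:

```
import ast

# order of the conditioning layers as listed above
LAYER_NAMES = ["img_in", "time_in", "vector_in", "guidance_in", "txt_in"]

def parse_in_dims(arg: str) -> dict:
    dims = ast.literal_eval(arg)  # "[4,0,0,0,4]" -> [4, 0, 0, 0, 4]
    assert len(dims) == len(LAYER_NAMES), "in_dims must contain exactly 5 numbers"
    return dict(zip(LAYER_NAMES, dims))  # rank 0 means LoRA is not applied

print(parse_in_dims("[4,0,0,0,4]"))
# {'img_in': 4, 'time_in': 0, 'vector_in': 0, 'guidance_in': 0, 'txt_in': 4}
```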

#### Specify blocks to train in FLUX.1 LoRA training

You can specify which blocks to train in FLUX.1 LoRA training with `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based, and the default (when omitted) is to train all blocks. Indices are given as a comma-separated list of integers and inclusive ranges, such as `0,1,5,8` or `0,1,4-5,7`. There are 19 double blocks and 38 single blocks, so the valid ranges are 0-18 and 0-37, respectively. `all` trains all blocks and `none` trains no blocks. A parsing sketch follows the examples below.

example:
```
--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37"
```

```
--network_args "train_double_block_indices=none" "train_single_block_indices=10-15"
```

If you specify only one of `train_double_block_indices` or `train_single_block_indices`, the blocks of the other type are trained as usual.
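
For illustration only, here is a hypothetical parser that expands an index specification of this form; the training script's own parsing may differ:

```
def parse_block_indices(spec: str, num_blocks: int) -> list:
    if spec == "all":
        return list(range(num_blocks))
    if spec == "none":
        return []
    indices = []
    for part in spec.split(","):
        if "-" in part:
            start, end = (int(v) for v in part.split("-"))
            indices.extend(range(start, end + 1))  # ranges are inclusive
        else:
            indices.append(int(part))
    assert all(0 <= i < num_blocks for i in indices), "block index out of range"
    return sorted(set(indices))

print(parse_block_indices("0,1,8-12,18", 19))  # [0, 1, 8, 9, 10, 11, 12, 18]
```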

### FLUX.1 OFT training

You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points differ.

- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`.
- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. The output dimension of each OFT target layer must be divisible by `--network_dim` (an error occurs if it is not), so values such as 64, 128, 256, 512, and 1024 are valid.
- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-4 to 1e-2. The default when omitted is 1, which is too large, so be sure to specify it.
- CLIP-L/T5XXL training is not supported. Specify `--network_train_unet_only`.
- `--network_args` specifies the hyperparameters of OFT. The following are valid:
- Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention.

Currently, no external inference environment supports FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify the OFT model with `--lora`).

A sample command is below. It works on 24GB VRAM GPUs with a batch size of 1.

```
--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3
--network_args "enable_all_linear=True" --learning_rate 1e-5
```

Training is also possible on 16GB VRAM GPUs without the `enable_all_linear` option and with the Adafactor optimizer.

### Inference for FLUX.1 with LoRA model

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
@@ -292,7 +396,7 @@ If you use LoRA in the inference environment, converting it to AI-toolkit format

Note that re-conversion will increase the size of LoRA.

CLIP-L LoRA is not supported.
CLIP-L/T5XXL LoRA is not supported.

### Merge LoRA to FLUX.1 checkpoint

@@ -372,6 +476,16 @@ resolution = [512, 512]
num_repeats = 1
```

### Convert Diffusers to FLUX.1

Script: `tools/convert_diffusers_to_flux.py`

Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The Diffusers folder is the parent folder of the `transformer` folder.

```
python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16
```

## SD3 training

SD3 training is done with `sd3_train.py`.
@@ -567,7 +681,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- transformers, accelerate and huggingface_hub are updated.
- If you encounter any issues, please report them.

- Improvements in OFT (Orthogonal Finetuning) Implementation (a minimal sketch of the new computation order appears after the notes below):
1. Optimization of Calculation Order:
- Changed the calculation order in the forward method from (Wx)R to W(xR).
- This has improved computational efficiency and processing speed.
2. Correction of Bias Application:
- In the previous implementation, R was incorrectly applied to the bias.
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
3. Efficiency Enhancement in Matrix Operations:
- Introduced einsum in both the forward and merge_to methods.
- This has optimized matrix operations, resulting in further speed improvements.
4. Proper Handling of Data Types:
- Improved to use torch.float32 during calculations and convert results back to the original data type.
- This maintains precision while ensuring compatibility with the original model.
5. Unified Processing for Conv2d and Linear Layers:
- Implemented a consistent method for applying OFT to both layer types.
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.

- Additional Information
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.

* Performance Improvement: Training speed has been improved by approximately 30%.

* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
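
Purely as an illustration of the W(xR) ordering with the bias handled through `F.linear` and float32 intermediate math, here is a minimal sketch; the function name, shapes, and block placement are assumptions, and it is not the repository's `networks/oft_flux.py` code:

```
import torch
import torch.nn.functional as F

def oft_linear_forward(x, org_weight, org_bias, blocks):
    # blocks: (num_blocks, block_size, block_size), one orthogonal matrix per block,
    # with num_blocks * block_size equal to the layer's input dimension
    org_dtype = x.dtype
    num_blocks, block_size, _ = blocks.shape
    # compute in float32 and cast back, as described above
    x32 = x.to(torch.float32)
    x_blocks = x32.reshape(*x.shape[:-1], num_blocks, block_size)
    x_rot = torch.einsum("...nb,nbc->...nc", x_blocks, blocks.to(torch.float32))
    x_rot = x_rot.reshape(*x.shape)
    # the original layer runs once via F.linear, so the bias is added exactly once
    out = F.linear(x_rot, org_weight.to(torch.float32),
                   None if org_bias is None else org_bias.to(torch.float32))
    return out.to(org_dtype)
```

Because the rotation is folded into the input before the single `F.linear` call, the bias itself is never rotated, which matches the bias handling described in point 2.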

- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just a step count. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. A small conversion sketch is shown below.
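
As a reading aid only, here is a hypothetical conversion from a ratio-style value to a step count; the actual resolution happens inside the training scripts and may differ:

```
def resolve_warmup_steps(value: str, max_train_steps: int) -> int:
    if value.endswith("%"):
        return int(max_train_steps * float(value[:-1]) / 100.0)
    number = float(value)
    if number < 1.0:  # values below 1 are treated as a ratio of total steps
        return int(max_train_steps * number)
    return int(number)  # plain step count

print(resolve_warmup_steps("0.1", 3000))  # 300
print(resolve_warmup_steps("10%", 3000))  # 300
print(resolve_warmup_steps("100", 3000))  # 100
```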

@@ -704,6 +842,21 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821.

- Added a `--f` option to the prompt options of `gen_imgs.py` to specify the file name for saving. The same script now also supports LoRA weights with Diffusers-based keys.


### Sep 13, 2024 / 2024-09-13:

- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
- `sdxl_merge_lora.py` also supports LBW.
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
- These will be included in the next release.


### Jun 23, 2024 / 2024-06-23:

- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
20 changes: 16 additions & 4 deletions flux_minimal_inference.py
@@ -14,9 +14,11 @@
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux

init_ipex()

@@ -405,7 +407,7 @@ def encode(prpt: str):
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)",
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
@@ -482,9 +484,19 @@ def is_fp8(dt):
else:
multiplier = 1.0

lora_model, weights_sd = lora_flux.create_network_from_weights(
multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True
)
weights_sd = load_file(weights_file)
is_lora = is_oft = False
for key in weights_sd.keys():
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if is_lora or is_oft:
break

module = lora_flux if is_lora else oft_flux
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
12 changes: 11 additions & 1 deletion flux_train.py
@@ -272,7 +272,7 @@ def unwrap_model(model):

flux.requires_grad_(True)

is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None
is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
@@ -354,8 +354,13 @@ def unwrap_model(model):

logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")

if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize, model=flux)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
@@ -782,6 +787,7 @@ def optimizer_hook(parameter: torch.Tensor):
progress_bar.update(1)
global_step += 1

optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
@@ -800,6 +806,7 @@ def optimizer_hook(parameter: torch.Tensor):
global_step,
unwrap_model(flux),
)
optimizer_train_fn()

current_loss = loss.detach().item()  # this is the mean, so batch size should not matter
if len(accelerator.trackers) > 0:
@@ -822,6 +829,7 @@ def optimizer_hook(parameter: torch.Tensor):

accelerator.wait_for_everyone()

optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
@@ -838,12 +846,14 @@ def optimizer_hook(parameter: torch.Tensor):
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
optimizer_train_fn()

is_main_process = accelerator.is_main_process
# if is_main_process:
flux = unwrap_model(flux)

accelerator.end_training()
optimizer_eval_fn()

if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
3 changes: 2 additions & 1 deletion gen_img.py
@@ -86,7 +86,8 @@
"""


def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")
