diff --git a/README-ja.md b/README-ja.md index 29c33a659..f70f882d7 100644 --- a/README-ja.md +++ b/README-ja.md @@ -1,7 +1,3 @@ -SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 - -SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 - ## リポジトリについて Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 @@ -21,6 +17,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど * [データセット設定](./docs/config_README-ja.md) +* [SDXL学習](./docs/train_SDXL-en.md) (英語版) * [DreamBoothの学習について](./docs/train_db_README-ja.md) * [fine-tuningのガイド](./docs/fine_tune_README_ja.md): * [LoRAの学習について](./docs/train_network_README-ja.md) @@ -44,9 +41,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 - -以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 +スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) @@ -59,20 +54,20 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` コマンドプロンプトでも同一です。 -(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) +注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 -accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) +この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。 -※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。 +accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) ```txt - This machine @@ -87,41 +82,6 @@ accelerate configの質問には以下のように答えてください。(bf1 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) -### オプション:`bitsandbytes`(8bit optimizer)を使う - -`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。 - -Windowsでは0.35.0または0.41.1を推奨します。 - -- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。 -- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。 - -注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659 - -以下の手順に従い、`bitsandbytes`をインストールしてください。 - -### 0.35.0を使う場合 - -PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。 - -```powershell -cd sd-scripts -.\venv\Scripts\activate -pip install bitsandbytes==0.35.0 - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py -``` - -### 0.41.1を使う場合 - -jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。 - -```powershell -python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui -``` - ## アップグレード 新しいリリースがあった場合、以下のコマンドで更新できます。 @@ -151,4 +111,3 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause - diff --git a/README.md b/README.md index e5f8e7cec..34b11c70a 100644 --- a/README.md +++ b/README.md @@ -1,176 +1,3 @@ -# Training Stable Cascade Stage C - -This is an experimental feature. There may be bugs. - -__Feb 25, 2024 Update:__ Fixed a bug that the LoRA weights trained can be loaded in ComfyUI. If you still have a problem, please let me know. - -__Feb 25, 2024 Update:__ Fixed a bug that Stage C training with mixed precision behaves the same as `--full_bf16` (fp16) regardless of `--full_bf16` (fp16) specified. - -This is because the Stage C weights were loaded in bf16/fp16. With this fix, the memory usage without `--full_bf16` (fp16) specified will increase, so you may need to specify `--full_bf16` (fp16) as needed. - -__Feb 22, 2024 Update:__ Fixed a bug that LoRA is not applied to some modules (to_q/k/v and to_out) in Attention. Also, the model structure of Stage C has been changed, and you can choose xformers and SDPA (SDPA was used before). Please specify `--sdpa` or `--xformers` option. - -__Feb 20, 2024 Update:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents became invalid (the saturation of the training results decreases). If you have cached `_sc_latents.npz` files with `--cache_latents_to_disk`, please delete them before training. - -## Usage - -Training is run with `stable_cascade_train_stage_c.py`. - -The main options are the same as `sdxl_train.py`. The following options have been added. - -- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights. -- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights. -- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used. -- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`. -- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training. -- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) . If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight. - -The learning rate is set to 1e-4 in the official settings. - -The first time, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From the next time, specify `--text_model_checkpoint_path` to load the saved weights. - -Sample image generation during training is done with Perviewer. Perviewer is a simple decoder that converts EfficientNetEncoder latents to images. - -Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary). - -Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight. - -`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) to reduce memory consumption can be used as is. - -A scale of about 4 is suitable for sample image generation. - -Since the official settings use `bf16` for training, training with `fp16` may be unstable. - -The code for training the Text Encoder is also written, but it is untested. - -### Command line sample - -```batch -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight -``` - -### About the dataset for fine tuning - -If the latents cache files for SD/SDXL exist (extension `*.npz`), it will be read and an error will occur during training. Please move them to another location in advance. - -After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added). - -## LoRA training - -`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added. - -__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__ - -There is no compatibility with the official LoRA, and the implementation of Text Encoder embedding training (Pivotal Tuning) in the official implementation is not implemented here. - -Text Encoder LoRA training is implemented, but untested. - -## Image generation - -Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage. - -When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`. - -The following prompt options are available. - - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. - * `--t` Specifies the t_start of the generation. - * `--f` Specifies the shift of the generation. - -# Stable Cascade Stage C の学習 - -実験的機能です。不具合があるかもしれません。 - -__2024/2/25 追記:__ 学習される LoRA の重みが ComfyUI で読み込めるよう修正しました。依然として不具合がある場合にはご連絡ください。 - -__2024/2/25 追記:__ Mixed precision 時のStage C の学習が、 `--full_bf16` (fp16) の指定に関わらず `--full_bf16` (fp16) 指定時と同じ動作となる(と思われる)不具合を修正しました。 - -Stage C の重みを bf16/fp16 で読み込んでいたためです。この修正により `--full_bf16` (fp16) 未指定時のメモリ使用量が増えますので、必要に応じて `--full_bf16` (fp16) を指定してください。 - -__2024/2/22 追記:__ LoRA が一部のモジュール(Attention の to_q/k/v および to_out)に適用されない不具合を修正しました。また Stage C のモデル構造を変更し xformers と SDPA を選べるようになりました(今までは SDPA が使用されていました)。`--sdpa` または `--xformers` オプションを指定してください。 - -__2024/2/20 追記:__ EfficientNetEncoder の前処理に不具合があり、latents が不正になっていました(学習結果の彩度が低下する現象が起きます)。`--cache_latents_to_disk` でキャッシュした `_sc_latents.npz` がある場合、いったん削除してから学習してください。 - -## 使い方 - -学習は `stable_cascade_train_stage_c.py` で行います。 - -主なオプションは `sdxl_train.py` と同様です。以下のオプションが追加されています。 - -- `--effnet_checkpoint_path` : EfficientNetEncoder の重みのパスを指定します。 -- `--stage_c_checkpoint_path` : Stage C の重みのパスを指定します。 -- `--text_model_checkpoint_path` : Text Encoder の重みのパスを指定します。省略時は Hugging Face のモデルを使用します。 -- `--save_text_model` : `--text_model_checkpoint_path` にHugging Face からダウンロードしたモデルを保存します。 -- `--previewer_checkpoint_path` : Previewer の重みのパスを指定します。学習中のサンプル画像生成に使用します。 -- `--adaptive_loss_weight` : [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) を用います。省略時は P2LossWeight が使用されます。公式では Adaptive Loss Weight が使用されているようです。 - -学習率は、公式の設定では 1e-4 のようです。 - -初回は `--text_model_checkpoint_path` と `--save_text_model` を指定して、Text Encoder の重みを保存すると良いでしょう。次からは `--text_model_checkpoint_path` を指定して、保存した重みを読み込むことができます。 - -学習中のサンプル画像生成は Perviewer で行われます。Previewer は EfficientNetEncoder の latents を画像に変換する簡易的な decoder です。 - -SDXL の向けの一部のオプションは単に無視されるか、エラーになります(特に `--noise_offset` などのノイズ関係)。`--vae_batch_size` および `--no_half_vae` はそのまま EfficientNetEncoder に適用されます(mixed precision に `bf16` 指定時は `--no_half_vae` は不要のようです)。 - -latents および Text Encoder 出力キャッシュのためのオプションはそのまま使用できますが、EfficientNetEncoder は VAE よりもかなり軽量のため、メモリが特に厳しい場合以外はキャッシュを使用する必要はないかもしれません。 - -メモリ消費を抑えるための `--gradient_checkpointing` 、`--full_bf16`、`--full_fp16`(未テスト)はそのまま使用できます。 - -サンプル画像生成時の Scale には 4 程度が適しているようです。 - -公式の設定では学習に `bf16` を用いているため、`fp16` での学習は不安定かもしれません。 - -Text Encoder 学習のコードも書いてありますが、未テストです。 - -### コマンドラインのサンプル - -[Command-line-sample](#command-line-sample)を参照してください。 - - -### fine tuning方式のデータセットについて - -SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が存在するとそれを読み込んでしまい学習時にエラーになります。あらかじめ他の場所に退避しておいてください。 - -その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。 - - -## LoRA 等の学習 - -LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。 - -__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__ - -公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習(Pivotal Tuning)も実装されていません。 - -Text Encoder の LoRA 学習は実装してありますが、未テストです。 - -## 画像生成 - -最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。 - -LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。 - -プロンプトオプションとして以下が使用できます。 - - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. - * `--t` Specifies the t_start of the generation. - * `--f` Specifies the shift of the generation. - - ---- - -__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). - This repository contains training, generation and utility scripts for Stable Diffusion. [__Change History__](#change-history) is moved to the bottom of the page. @@ -191,9 +18,9 @@ This repository contains the scripts for: ## About requirements.txt -These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.) +The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below. -The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work. +The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work. ## Links to usage documentation @@ -203,11 +30,13 @@ Most of the documents are written in Japanese. * [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc... * [Chinese version](./docs/train_README-zh.md) +* [SDXL training](./docs/train_SDXL-en.md) (English version) * [Dataset config](./docs/config_README-ja.md) + * [English version](./docs/config_README-en.md) * [DreamBooth training guide](./docs/train_db_README-ja.md) * [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md): -* [training LoRA](./docs/train_network_README-ja.md) -* [training Textual Inversion](./docs/train_ti_README-ja.md) +* [Training LoRA](./docs/train_network_README-ja.md) +* [Training Textual Inversion](./docs/train_ti_README-ja.md) * [Image generation](./docs/gen_img_README-ja.md) * note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad) @@ -235,14 +64,18 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` -__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section. +If `python -m venv` shows only `python`, change `python` to `py`. + +__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually. + +This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`. [[datasets]] No.2 ┘ +``` + +The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`. + +The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding. + +Additionally, the available options may vary depending on the method that the learning approach supports. + +* Options specific to the DreamBooth method +* Options specific to the fine-tuning method +* Options available when using the caption dropout technique + +When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both. +When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset. +In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets. + +In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem. + +Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs. + +### Common options for all learning methods + +These are options that can be specified regardless of the learning method. + +#### Data set specific options + +These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`. + + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * This corresponds to the command-line argument `--train_batch_size`. + +These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. + +#### Options for Subsets + +These options are related to subset configuration. + +| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o | +| `caption_suffix` | `", from side"` | o | o | o | +| `caption_separator` | (not specified) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | + +* `num_repeats` + * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. +* `caption_prefix`, `caption_suffix` + * Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`. +* `caption_separator` + * Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set. +* `keep_tokens_separator` + * Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc. +* `secondary_separator` + * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together. +* `enable_wildcard` + * Enables wildcard notation. This will be explained later. + +### DreamBooth-specific options + +DreamBooth-specific options only exist as subsets-specific options. + +#### Subset-specific options + +Options related to the configuration of DreamBooth subsets. + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `'C:\hoge'` | - | - | o (required) | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `"sks girl"` | - | - | o | +| `cache_info` | `false` | o | o | o | +| `is_reg` | `false` | - | - | o | + +Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`. + +* `image_dir` + * Specifies the path to the image directory. This is a required option. + * Images must be placed directly under the directory. +* `class_tokens` + * Sets the class tokens. + * Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur. +* `cache_info` + * Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`. + * Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more. +* `is_reg` + * Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization. + +### Fine-tuning method specific options + +The options for the fine-tuning method only exist for subset-specific options. + +#### Subset-specific options + +These options are related to the configuration of the fine-tuning method's subsets. + +| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `'C:\hoge'` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) | + +* `image_dir` + * Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so. + * The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file. + * The images must be placed directly under the directory. +* `metadata_file` + * Specify the path to the metadata file used for the subset. This is a required option. + * It is equivalent to the command-line argument `--in_json`. + * Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets. + +### Options available when caption dropout method can be used + +The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified. + +#### Subset-specific options + +Options related to the setting of subsets that caption dropout can be used for. + +| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## Behavior when there are duplicate subsets + +In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored. + +However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions. + +```toml +# If data sets exist separately, they are not considered duplicates and are both used for training. + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## Command Line Argument and Configuration File + +There are options in the configuration file that have overlapping roles with command line argument options. + +The following command line argument options are ignored if a configuration file is passed: + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file. + +| Command Line Argument Option | Prioritized Configuration File Option | +| ------------------------------- | ------------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## Error Guide + +Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem. + +As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug. + +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name. + * The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting. +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful. +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it. + +## Miscellaneous + +### Multi-line captions + +By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption. + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +It can be combined with wildcard notation. + +In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format. + +The tags in the metadata (`tags`) are added to each line of the caption. + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` + +In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc. + +### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc. + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). + +### Example of caption, enable_wildcard notation: `enable_wildcard = true` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). + +### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. + diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 69a03f6cf..cc74c341b 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -1,5 +1,3 @@ -For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. - `--dataset_config` で渡すことができる設定ファイルに関する説明です。 ## 概要 @@ -140,12 +138,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `shuffle_caption` | `true` | o | o | o | | `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | | `caption_suffix` | `“, from side”` | o | o | o | +| `caption_separator` | (通常は設定しません) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 * `caption_prefix`, `caption_suffix` * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。 +* `caption_separator` + * タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。 + +* `keep_tokens_separator` + * キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。 + +* `secondary_separator` + * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。 + +* `enable_wildcard` + * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 + ### DreamBooth 方式専用のオプション DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 @@ -159,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。 | `image_dir` | `‘C:\hoge’` | - | - | o(必須) | | `caption_extension` | `".txt"` | o | o | o | | `class_tokens` | `“sks girl”` | - | - | o | +| `cache_info` | `false` | o | o | o | | `is_reg` | `false` | - | - | o | まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。 @@ -169,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。 * `class_tokens` * クラストークンを設定します。 * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。 +* `cache_info` + * 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。 + * キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。 * `is_reg` * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。 @@ -280,4 +298,89 @@ resolution = 768 * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。 * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。 +## その他 + +### 複数行キャプション + +`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +ワイルドカード記法と組み合わせることも可能です。 + +メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。 + +メタデータのタグ (`tags`) は、キャプションの各行に追加されます。 + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` +この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。 + +### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等 + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 + +### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 + +### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 diff --git a/docs/train_SDXL-en.md b/docs/train_SDXL-en.md new file mode 100644 index 000000000..a4c55b3fd --- /dev/null +++ b/docs/train_SDXL-en.md @@ -0,0 +1,84 @@ +## SDXL training + +The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL. + +### Training scripts for SDXL + +- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. + - `--full_bf16` option is added. Thanks to KohakuBlueleaf! + - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. + - The full bfloat16 training might be unstable. Please use it at your own risk. + - The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`. + - 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`. +- `prepare_buckets_latents.py` now supports SDXL fine-tuning. + +- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. + +- Both scripts has following additional options: + - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. + - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. + +- `--weighted_captions` option is not supported yet for both scripts. + +- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. + - `--cache_text_encoder_outputs` is not supported. + - There are two options for captions: + 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. + 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. + - See below for the format of the embeddings. + +- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. + +### Utility scripts for SDXL + +- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance. + - The options are almost the same as `sdxl_train.py'. See the help message for the usage. + - Please launch the script as follows: + `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...` + - This script should work with multi-GPU, but it is not tested in my environment. + +- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance. + - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage. + +- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage. + +### Tips for SDXL training + +- The default resolution of SDXL is 1024x1024. +- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__: + - Train U-Net only. + - Use gradient checkpointing. + - Use `--cache_text_encoder_outputs` option and caching latents. + - Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work. +- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended: + - Train U-Net only. + - Use gradient checkpointing. + - Use `--cache_text_encoder_outputs` option and caching latents. + - Use one of 8bit optimizers or Adafactor optimizer. + - Use lower dim (4 to 8 for 8GB GPU). +- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. +- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. +- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. + +Example of the optimizer settings for Adafactor with the fixed learning rate: +```toml +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +lr_scheduler = "constant_with_warmup" +lr_warmup_steps = 100 +learning_rate = 4e-7 # SDXL original learning rate +``` + +### Format of Textual Inversion embeddings for SDXL + +```python +from safetensors.torch import save_file + +state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768} +save_file(state_dict, file) +``` + +### ControlNet-LLLite + +ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. + diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md index dbdc1fea2..1f6a78d5c 100644 --- a/docs/train_lllite_README-ja.md +++ b/docs/train_lllite_README-ja.md @@ -21,9 +21,13 @@ ComfyUIのカスタムノードを用意しています。: https://github.com/k ## モデルの学習 ### データセットの準備 -通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。 -たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。 +(finetuning 方式の dataset はサポートしていません。) + +conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 + +たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。 ```toml [[datasets.subsets]] diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index 04dc12da2..a05f87f5f 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1 ### Preparing the dataset -In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. +In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. + +(We do not support the finetuning method dataset.) ```toml [[datasets.subsets]] diff --git a/fine_tune.py b/fine_tune.py index 875a91951..a0350ce18 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,9 @@ from tqdm import tqdm import torch +from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -42,6 +44,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -108,6 +111,7 @@ def train(args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) @@ -158,7 +162,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習を準備する if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() with torch.no_grad(): @@ -191,7 +195,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if not cache_latents: vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) for m in training_models: m.requires_grad_(True) @@ -246,13 +250,23 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): unet.to(weight_dtype) text_encoder.to(weight_dtype) - # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler ) + training_models = [ds_model] else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -311,13 +325,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + with accelerator.accumulate(*training_models): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype) latents = latents * 0.18215 b_size = latents.shape[0] @@ -457,7 +471,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -477,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) @@ -492,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7d192cb26..13b69ffd3 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -134,8 +134,9 @@ def forward(self, image, caption): def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): image_embeds = self.visual_encoder(image) - if not sample: - image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + # recent version of transformers seems to do repeat_interleave automatically + # if not sample: + # image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 60765b863..89f717473 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -6,75 +6,95 @@ import library.train_util as train_util import os from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - logger.info(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") - else: - logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - logger.info("merge caption texts to metadata json.") - for image_path in tqdm(image_paths): - caption_path = image_path.with_suffix(args.caption_extension) - caption = caption_path.read_text(encoding='utf-8').strip() + logger.info("merge caption texts to metadata json.") + for image_path in tqdm(image_paths): + caption_path = image_path.with_suffix(args.caption_extension) + caption = caption_path.read_text(encoding="utf-8").strip() - if not os.path.exists(caption_path): - caption_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['caption'] = caption - if args.debug: - logger.info(f"{image_key} {caption}") + metadata[image_key]["caption"] = caption + if args.debug: + logger.info(f"{image_key} {caption}") - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - logger.info("done!") + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子" + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 9ef8f14b0..ce22d990e 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -6,70 +6,88 @@ import library.train_util as train_util import os from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - logger.info(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") - else: - logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - logger.info("merge tags to metadata json.") - for image_path in tqdm(image_paths): - tags_path = image_path.with_suffix(args.caption_extension) - tags = tags_path.read_text(encoding='utf-8').strip() + logger.info("merge tags to metadata json.") + for image_path in tqdm(image_paths): + tags_path = image_path.with_suffix(args.caption_extension) + tags = tags_path.read_text(encoding="utf-8").strip() - if not os.path.exists(tags_path): - tags_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['tags'] = tags - if args.debug: - logger.info(f"{image_key} {tags}") + metadata[image_key]["tags"] = tags + if args.debug: + logger.info(f"{image_key} {tags}") - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") - logger.info("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--caption_extension", type=str, default=".txt", - help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode, print tags") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument( + "--caption_extension", + type=str, + default=".txt", + help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子", + ) + parser.add_argument("--debug", action="store_true", help="debug mode, print tags") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index b56d921a3..c19ae9749 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -12,8 +12,10 @@ import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # from wd14 tagger @@ -79,34 +81,42 @@ def collate_fn_remove_corrupted(batch): def main(args): + # model location is model_dir + repo_id + # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash + model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: + if not os.path.exists(model_location) or args.force_download: + os.makedirs(args.model_dir, exist_ok=True) logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: + files = ["selected_tags.csv"] files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(model_location, SUB_DIR), + force_download=True, + force_filename=file, + ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) + hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) else: logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: + import torch import onnx import onnxruntime as ort - onnx_path = f"{args.model_dir}/model.onnx" + onnx_path = f"{model_location}/model.onnx" logger.info("Running wd14 tagger with onnx") logger.info(f"loading onnx model: {onnx_path}") @@ -123,7 +133,7 @@ def main(args): except: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str: + if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0: # some rebatch model may use 'N' as dynamic axes logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" @@ -132,21 +142,32 @@ def main(args): del model - ort_sess = ort.InferenceSession( - onnx_path, - providers=["CUDAExecutionProvider"] - if "CUDAExecutionProvider" in ort.get_available_providers() - else ["CPUExecutionProvider"], - ) + if "OpenVINOExecutionProvider" in ort.get_available_providers(): + # requires provider options for gpu support + # fp16 causes nonsense outputs + ort_sess = ort.InferenceSession( + onnx_path, + providers=(["OpenVINOExecutionProvider"]), + provider_options=[{'device_type' : "GPU_FP32"}], + ) + else: + ort_sess = ort.InferenceSession( + onnx_path, + providers=( + ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else + ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else + ["CPUExecutionProvider"] + ), + ) else: from tensorflow.keras.models import load_model - model = load_model(f"{args.model_dir}") + model = load_model(f"{model_location}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) l = [row for row in reader] header = l[0] # tag_id,name,category,count @@ -172,8 +193,8 @@ def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - if len(imgs) < args.batch_size: - imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + # if len(imgs) < args.batch_size: + # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -314,7 +335,9 @@ def setup_parser() -> argparse.ArgumentParser: help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", ) parser.add_argument( - "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + "--force_download", + action="store_true", + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( @@ -329,8 +352,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", ) - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument( + "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子" + ) + parser.add_argument( + "--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値" + ) parser.add_argument( "--general_threshold", type=float, @@ -343,7 +370,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", ) - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + parser.add_argument( + "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" + ) parser.add_argument( "--remove_underscore", action="store_true", @@ -356,9 +385,13 @@ def setup_parser() -> argparse.ArgumentParser: default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument( + "--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する" + ) parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") - parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" + ) parser.add_argument( "--caption_separator", type=str, diff --git a/library/config_util.py b/library/config_util.py index eb652ecf3..d75d03b03 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -80,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams): is_reg: bool = False class_tokens: Optional[str] = None caption_extension: str = ".caption" + cache_info: bool = False @dataclass @@ -91,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams): class ControlNetSubsetParams(BaseSubsetParams): conditioning_data_dir: str = None caption_extension: str = ".caption" + cache_info: bool = False @dataclass @@ -200,6 +207,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] DB_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, "class_tokens": str, + "cache_info": bool, } DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, @@ -212,6 +220,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, + "cache_info": bool, } CN_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, @@ -248,9 +257,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] } def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert ( - support_dreambooth or support_finetuning or support_controlnet - ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -317,7 +327,10 @@ def validate_flex_dataset(dataset_config: dict): self.dataset_schema = validate_flex_dataset elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema + if support_controlnet: + self.dataset_schema = self.cn_dataset_schema + else: + self.dataset_schema = self.db_dataset_schema elif support_finetuning: self.dataset_schema = self.ft_dataset_schema elif support_controlnet: @@ -362,7 +375,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -547,11 +562,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - logger.info(f'{info}') + logger.info(f"{info}") # make buckets first because it determines the length of dataset # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): logger.info(f"[Dataset {i}]") dataset.make_buckets() @@ -638,13 +653,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -671,13 +690,13 @@ def load_user_config(file: str) -> dict: train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -686,10 +705,10 @@ def load_user_config(file: str) -> dict: logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index a56474622..406e0e36e 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,11 +3,14 @@ import random import re from typing import List, Optional, Union -from .utils import setup_logging +from .utils import setup_logging + setup_logging() -import logging +import logging + logger = logging.getLogger(__name__) + def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return @@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: - snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) + snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) else: snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) loss = loss * snr_weight @@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los loss = loss + loss / scale * v_pred_like_loss return loss + def apply_debiased_estimation(loss, timesteps, noise_scheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1/torch.sqrt(snr_t) + weight = 1 / torch.sqrt(snr_t) loss = weight * loss return loss + # TODO train_utilと分散しているのでどちらかに寄せる @@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise +def apply_masked_loss(loss, batch): + # mask image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + return loss + + """ ########################################## # Perlin Noise diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py new file mode 100644 index 000000000..99a7b2b3b --- /dev/null +++ b/library/deepspeed_utils.py @@ -0,0 +1,139 @@ +import os +import argparse +import torch +from accelerate import DeepSpeedPlugin, Accelerator + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_deepspeed_arguments(parser: argparse.ArgumentParser): + # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed + parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") + parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.") + parser.add_argument( + "--offload_optimizer_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.", + ) + parser.add_argument( + "--offload_optimizer_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--zero3_init_flag", + action="store_true", + help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." + "Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--zero3_save_16bit_model", + action="store_true", + help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--fp16_master_weights_and_gradients", + action="store_true", + help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.", + ) + + +def prepare_deepspeed_args(args: argparse.Namespace): + if not args.deepspeed: + return + + # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + args.max_data_loader_n_workers = 1 + + +def prepare_deepspeed_plugin(args: argparse.Namespace): + if not args.deepspeed: + return None + + try: + import deepspeed + except ImportError as e: + logger.error( + "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed" + ) + exit(1) + + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=args.zero_stage, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_clipping=args.max_grad_norm, + offload_optimizer_device=args.offload_optimizer_device, + offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, + offload_param_device=args.offload_param_device, + offload_param_nvme_path=args.offload_param_nvme_path, + zero3_init_flag=args.zero3_init_flag, + zero3_save_16bit_model=args.zero3_save_16bit_model, + ) + deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + deepspeed_plugin.deepspeed_config["train_batch_size"] = ( + args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) + ) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) + if args.mixed_precision.lower() == "fp16": + deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. + if args.full_fp16 or args.fp16_master_weights_and_gradients: + if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: + deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True + logger.info("[DeepSpeed] full fp16 enable.") + else: + logger.info( + "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage." + ) + + if args.offload_optimizer_device is not None: + logger.info("[DeepSpeed] start to manually build cpu_adam.") + deepspeed.ops.op_builder.CPUAdamBuilder().load() + logger.info("[DeepSpeed] building cpu_adam done.") + + return deepspeed_plugin + + +# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model. +def prepare_deepspeed_model(args: argparse.Namespace, **models): + # remove None from models + models = {k: v for k, v in models.items() if v is not None} + + class DeepSpeedWrapper(torch.nn.Module): + def __init__(self, **kw_models) -> None: + super().__init__() + self.models = torch.nn.ModuleDict() + + for key, model in kw_models.items(): + if isinstance(model, list): + model = torch.nn.ModuleList(model) + assert isinstance( + model, torch.nn.Module + ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + + def get_models(self): + return self.models + + ds_model = DeepSpeedWrapper(**models) + return ds_model diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 8253c5b17..d989ad53d 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None): mat2[start_idx:end_idx], out=out ) + torch.xpu.synchronize(input.device) else: return original_torch_bmm(input, mat2, out=out) - torch.xpu.synchronize(input.device) return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): if query.device.type != "xpu": - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) # Slice SDPA @@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal + dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( @@ -161,7 +161,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo key[start_idx:end_idx, start_idx_2:end_idx_2], value[start_idx:end_idx, start_idx_2:end_idx_2], attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal + dropout_p=dropout_p, is_causal=is_causal, **kwargs ) else: hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( @@ -169,9 +169,9 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo key[start_idx:end_idx], value[start_idx:end_idx], attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal + dropout_p=dropout_p, is_causal=is_causal, **kwargs ) + torch.xpu.synchronize(query.device) else: - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) - torch.xpu.synchronize(query.device) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) return hidden_states diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index b1b9ccf0e..65089f39e 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -12,7 +12,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: - logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") + print("IPEX backend doesn't support DataParallel on multiple XPU devices") return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument @@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non original_interpolate = torch.nn.functional.interpolate @wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None: + if antialias or align_corners is not None or mode == 'bicubic': return_device = tensor.device return_dtype = tensor.dtype return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, @@ -216,7 +216,9 @@ def torch_empty(*args, device=None, **kwargs): original_torch_randn = torch.randn @wraps(torch.randn) -def torch_randn(*args, device=None, **kwargs): +def torch_randn(*args, device=None, dtype=None, **kwargs): + if dtype == bytes: + dtype = None if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) else: @@ -256,11 +258,11 @@ def torch_Generator(device=None): original_torch_load = torch.load @wraps(torch.load) -def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): +def torch_load(f, map_location=None, *args, **kwargs): if check_device(map_location): - return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs) else: - return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + return original_torch_load(f, map_location=map_location, *args, **kwargs) # Hijack Functions: diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1932bf881..a29013e34 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -24,7 +24,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): - # load models for each process model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: diff --git a/library/train_util.py b/library/train_util.py index 8cc5f3b6a..49f9f5bf3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -63,12 +63,14 @@ from huggingface_hub import hf_hub_download import numpy as np from PIL import Image +import imagesize import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +import library.deepspeed_utils as deepspeed_utils from library.utils import setup_logging setup_logging() @@ -410,6 +412,7 @@ def __init__( is_reg: bool, class_tokens: Optional[str], caption_extension: str, + cache_info: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -458,6 +461,7 @@ def __init__( self.caption_extension = caption_extension if self.caption_extension and not self.caption_extension.startswith("."): self.caption_extension = "." + self.caption_extension + self.cache_info = cache_info def __eq__(self, other) -> bool: if not isinstance(other, DreamBoothSubset): @@ -527,6 +531,7 @@ def __init__( image_dir: str, conditioning_data_dir: str, caption_extension: str, + cache_info: bool, num_repeats, shuffle_caption, caption_separator, @@ -574,6 +579,7 @@ def __init__( self.caption_extension = caption_extension if self.caption_extension and not self.caption_extension.startswith("."): self.caption_extension = "." + self.caption_extension + self.cache_info = cache_info def __eq__(self, other) -> bool: if not isinstance(other, ControlNetSubset): @@ -694,6 +700,10 @@ def process_caption(self, subset: BaseSubset, caption): else: # process wildcards if subset.enable_wildcard: + # if caption is multiline, random choice one line + if "\n" in caption: + caption = random.choice(caption.split("\n")) + # wildcard is like '{aaa|bbb|ccc...}' # escape the curly braces like {{ or }} replacer1 = "⦅" @@ -712,6 +722,9 @@ def replace_wildcard(match): # unescape the curly braces caption = caption.replace(replacer1, "{").replace(replacer2, "}") + else: + # if caption is multiline, use the first line + caption = caption.split("\n")[0] if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: fixed_tokens = [] @@ -1078,8 +1091,7 @@ def cache_text_encoder_outputs( ) def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size + return imagesize.get(image_path) def load_image_with_face_info(self, subset: BaseSubset, image_path: str): img = load_image(image_path) @@ -1410,6 +1422,8 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): + IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + def __init__( self, subsets: Sequence[DreamBoothSubset], @@ -1453,7 +1467,7 @@ def __init__( self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension): + def read_caption(img_path, caption_extension, enable_wildcard): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name @@ -1472,7 +1486,10 @@ def read_caption(img_path, caption_extension): logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() break return caption @@ -1481,26 +1498,54 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(f"not directory: {subset.image_dir}") return [], [] - img_paths = glob_images(subset.image_dir, "*") - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: + info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) + use_cached_info_for_subset = subset.cache_info + if use_cached_info_for_subset: + logger.info( + f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}" + ) + if not os.path.isfile(info_cache_file): logger.warning( - f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + f"image info file not found. You can ignore this warning if this is the first time to use this subset" + + " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}" ) - captions.append("") - missing_captions.append(img_path) - else: - if cap_for_img is None: - captions.append(subset.class_tokens) + use_cached_info_for_subset = False + + if use_cached_info_for_subset: + # json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...} + with open(info_cache_file, "r", encoding="utf-8") as f: + metas = json.load(f) + img_paths = list(metas.keys()) + sizes = [meta["resolution"] for meta in metas.values()] + + # we may need to check image size and existence of image files, but it takes time, so user should check it before training + else: + img_paths = glob_images(subset.image_dir, "*") + sizes = [None] * len(img_paths) + + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + + if use_cached_info_for_subset: + captions = [meta["caption"] for meta in metas.values()] + missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""] + else: + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + if cap_for_img is None and subset.class_tokens is None: + logger.warning( + f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) + captions.append("") missing_captions.append(img_path) else: - captions.append(cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 @@ -1517,12 +1562,24 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(missing_caption + f"... and {remaining_missing_captions} more") break logger.warning(missing_caption) - return img_paths, captions + + if not use_cached_info_for_subset and subset.cache_info: + logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}") + sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")] + matas = {} + for img_path, caption, size in zip(img_paths, captions, sizes): + matas[img_path] = {"caption": caption, "resolution": list(size)} + with open(info_cache_file, "w", encoding="utf-8") as f: + json.dump(matas, f, ensure_ascii=False, indent=2) + logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}") + + # if sizes are not set, image size will be read in make_buckets + return img_paths, captions, sizes logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 - reg_infos: List[ImageInfo] = [] + reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: if subset.num_repeats < 1: logger.warning( @@ -1536,7 +1593,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): ) continue - img_paths, captions = load_dreambooth_dir(subset) + img_paths, captions, sizes = load_dreambooth_dir(subset) if len(img_paths) < 1: logger.warning( f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" @@ -1548,10 +1605,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): else: num_train_images += subset.num_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): + for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if size is not None: + info.image_size = size if subset.is_reg: - reg_infos.append(info) + reg_infos.append((info, subset)) else: self.register_image(info, subset) @@ -1572,7 +1631,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): n = 0 first_loop = True while n < num_train_images: - for info in reg_infos: + for info, subset in reg_infos: if first_loop: self.register_image(info, subset) n += info.num_repeats @@ -1667,10 +1726,24 @@ def __init__( caption = img_md.get("caption") tags = img_md.get("tags") if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ", " + tags - tags_list.append(tags) + caption = tags # could be multiline + tags = None + + if subset.enable_wildcard: + # tags must be single line + if tags is not None: + tags = tags.replace("\n", subset.caption_separator) + + # add tags to each line of caption + if caption is not None and tags is not None: + caption = "\n".join( + [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""] + ) + else: + # use as is + if tags is not None and len(tags) > 0: + caption = caption + subset.caption_separator + tags + tags_list.append(tags) if caption is None: caption = "" @@ -1818,11 +1891,15 @@ def __init__( db_subsets = [] for subset in subsets: + assert ( + not subset.random_crop + ), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません" db_subset = DreamBoothSubset( subset.image_dir, False, None, - subset.caption_extension, + subset.caption_extension, + subset.cache_info, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -1868,7 +1945,7 @@ def __init__( # assert all conditioning data exists missing_imgs = [] - cond_imgs_with_img = set() + cond_imgs_with_pair = set() for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] subset = None @@ -1882,23 +1959,29 @@ def __init__( logger.warning(f"not directory: {subset.conditioning_data_dir}") continue - img_basename = os.path.basename(info.absolute_path) - ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) - if not os.path.exists(ctrl_img_path): + img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0] + ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename) + if len(ctrl_img_path) < 1: missing_imgs.append(img_basename) + continue + ctrl_img_path = ctrl_img_path[0] + ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path info.cond_img_path = ctrl_img_path - cond_imgs_with_img.add(ctrl_img_path) + cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive extra_imgs = [] for subset in subsets: conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") - extra_imgs.extend( - [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] - ) + conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path + extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -2977,7 +3060,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", + ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する", ) parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -3060,6 +3148,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") + parser.add_argument( "--ddp_timeout", type=int, @@ -3122,12 +3211,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument( "--noise_offset", type=float, default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", ) + parser.add_argument( + "--noise_offset_random_strength", + action="store_true", + help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。", + ) parser.add_argument( "--multires_noise_iterations", type=int, @@ -3141,6 +3236,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", ) + parser.add_argument( + "--ip_noise_gamma_random_strength", + action="store_true", + help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -3284,6 +3385,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) +def add_masked_loss_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) + + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable @@ -3343,6 +3458,18 @@ def verify_training_args(args: argparse.Namespace): + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) + if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0: + logger.warning( + "sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_epochs = None + + if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0: + logger.warning( + "sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_steps = None + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool @@ -3351,6 +3478,12 @@ def add_dataset_arguments( parser.add_argument( "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" ) + parser.add_argument( + "--cache_info", + action="store_true", + help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth" + + " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効", + ) parser.add_argument( "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" ) @@ -3579,7 +3712,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar exit(1) logger.info(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: + with open(config_path, "r", encoding="utf-8") as f: config_dict = toml.load(f) # combine all sections into one @@ -4088,6 +4221,10 @@ def load_tokenizer(args: argparse.Namespace): def prepare_accelerator(args: argparse.Namespace): + """ + this function also prepares deepspeed plugin + """ + if args.logging_dir is None: logging_dir = None else: @@ -4133,6 +4270,8 @@ def prepare_accelerator(args: argparse.Namespace): ), ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -4140,6 +4279,7 @@ def prepare_accelerator(args: argparse.Namespace): project_dir=logging_dir, kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, + deepspeed_plugin=deepspeed_plugin, ) print("accelerator device:", accelerator.device) return accelerator @@ -4210,7 +4350,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): - # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") @@ -4221,7 +4360,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2, ) - # work on low-ram device if args.lowram: text_encoder.to(accelerator.device) @@ -4230,7 +4368,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() - return text_encoder, vae, unet, load_stable_diffusion_format @@ -4749,7 +4886,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) if args.multires_noise_iterations: noise = custom_train_functions.pyramid_noise_like( noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount @@ -4766,7 +4907,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) diff --git a/requirements.txt b/requirements.txt index 9a1a6b670..7a2846ad1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,10 @@ ftfy==6.1.1 opencv-python==4.7.0.68 einops==0.7.0 pytorch-lightning==1.9.0 -# bitsandbytes==0.39.1 -tensorboard==2.14.1 +bitsandbytes==0.43.0 +prodigyopt==1.0 +lion-pytorch==0.0.6 +tensorboard safetensors==0.4.2 # gradio==3.16.2 altair==4.2.2 @@ -15,6 +17,8 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 +# for Image utils +imagesize==1.4.1 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 @@ -22,13 +26,16 @@ huggingface-hub==0.20.1 # for WD14 captioning (tensorflow) # tensorflow==2.10.1 # for WD14 captioning (onnx) -# onnx==1.14.1 -# onnxruntime-gpu==1.16.0 -# onnxruntime==1.16.0 +# onnx==1.15.0 +# onnxruntime-gpu==1.17.1 +# onnxruntime==1.17.1 +# for cuda 12.1(default 11.8) +# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ + # this is for onnx: protobuf==3.20.3 # open clip for SDXL -open-clip-torch==2.20.0 +# open-clip-torch==2.20.0 # For logging rich==13.7.0 # for kohya_ss library diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 084735665..a1e93b7f0 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -11,20 +11,24 @@ import torch from library.device_utils import init_ipex, get_preferred_device + init_ipex() from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler from PIL import Image -import open_clip + +# import open_clip from safetensors.torch import load_file from library import model_util, sdxl_model_util import networks.lora as lora from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい @@ -154,12 +158,13 @@ def get_timestep_embedding(x, outdim): text_model2.eval() unet.set_use_memory_efficient_attention(True, False) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(True) # Tokenizers tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name) - tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + # tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name) # LoRA for weights_file in args.lora_weights: @@ -192,7 +197,9 @@ def generate_image(prompt, prompt2, negative_prompt, seed=None): emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) # logger.info("emb1", emb1.shape) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) - uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right + uc_vector = c_vector.clone().to( + DEVICE, dtype=DTYPE + ) # ちょっとここ正しいかどうかわからない I'm not sure if this is right # crossattn @@ -215,13 +222,22 @@ def call_text_encoder(text, text2): # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい # text encoder 2 - with torch.no_grad(): - tokens = tokenizer2(text2).to(DEVICE) + # tokens = tokenizer2(text2).to(DEVICE) + tokens = tokenizer2( + text, + truncation=True, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(DEVICE) + with torch.no_grad(): enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] # logger.info("hidden_states2", text_embedding2_penu.shape) - text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion + text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2) diff --git a/sdxl_train.py b/sdxl_train.py index e0df263d6..816598e04 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,11 +11,13 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + + init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import sdxl_model_util +from library import deepspeed_utils, sdxl_model_util import library.train_util as train_util @@ -39,6 +41,7 @@ scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -97,6 +100,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) assert ( @@ -124,7 +128,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -398,18 +402,33 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder1.to(weight_dtype) text_encoder2.to(weight_dtype) - # acceleratorがなんかよろしくやってくれるらしい - if train_unet: - unet = accelerator.prepare(unet) + # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer if train_text_encoder1: - # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) text_encoder1.text_model.final_layer_norm.requires_grad_(False) - text_encoder1 = accelerator.prepare(text_encoder1) - if train_text_encoder2: - text_encoder2 = accelerator.prepare(text_encoder2) - optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoder1 if train_text_encoder1 else None, + text_encoder2=text_encoder2 if train_text_encoder2 else None, + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_unet: + unet = accelerator.prepare(unet) + if train_text_encoder1: + text_encoder1 = accelerator.prepare(text_encoder1) + if train_text_encoder2: + text_encoder2 = accelerator.prepare(text_encoder2) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -424,6 +443,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする @@ -576,9 +597,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss + or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -712,7 +736,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state: # and is_main_process: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -744,6 +768,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) @@ -779,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) - return parser diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 1e5f92349..9eaaa19f2 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -22,7 +22,7 @@ import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util import library.model_util as model_util import library.train_util as train_util @@ -394,10 +394,10 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(unet): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -549,7 +549,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: @@ -566,6 +566,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index dac56eedd..e55a58896 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -18,7 +18,7 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util import library.model_util as model_util import library.train_util as train_util @@ -361,10 +361,10 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -534,6 +534,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index b9a948bb2..257d181ad 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -7,7 +7,6 @@ from library.device_utils import init_ipex init_ipex() -import open_clip from library import sdxl_model_util, sdxl_train_util, train_util import train_textual_inversion diff --git a/train_controlnet.py b/train_controlnet.py index dc73a91c8..0cb0405fd 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -11,6 +11,7 @@ from tqdm import tqdm import torch +from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -396,7 +397,7 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(controlnet): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() @@ -565,7 +566,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく @@ -584,6 +585,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) diff --git a/train_db.py b/train_db.py index 8d36097a5..0a152f224 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,10 @@ from tqdm import tqdm import torch +from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + + init_ipex() from accelerate.utils import set_seed @@ -32,6 +35,7 @@ apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, + apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -46,6 +50,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -57,7 +62,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -219,12 +224,25 @@ def train(args): text_encoder.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler ) + training_models = [ds_model] + else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + training_models = [unet, text_encoder] + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + training_models = [unet] if not train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error @@ -296,12 +314,14 @@ def train(args): if not args.gradient_checkpointing: text_encoder.train(False) text_encoder.requires_grad_(False) + if len(training_models) == 2: + training_models = training_models[0] # remove text_encoder from training_models - with accelerator.accumulate(unet): + with accelerator.accumulate(*training_models): with torch.no_grad(): # latentに変換 if cache_latents: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 @@ -339,6 +359,8 @@ def train(args): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -444,7 +466,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -464,6 +486,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) diff --git a/train_network.py b/train_network.py index e5b26d8a2..ed569aea6 100644 --- a/train_network.py +++ b/train_network.py @@ -13,18 +13,15 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() -from torch.nn.parallel import DistributedDataParallel as DDP +init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import model_util +from library import deepspeed_utils, model_util import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, -) +from library.train_util import DreamBoothDataset import library.config_util as config_util from library.config_util import ( ConfigSanitizer, @@ -39,6 +36,7 @@ scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -141,6 +139,7 @@ def train(self, args): training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -157,7 +156,7 @@ def train(self, args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if use_user_config: logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -413,20 +412,36 @@ def train(self, args): t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good - if train_unet: - unet = accelerator.prepare(unet) + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoders[0] if train_text_encoder else None, + text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + network=network, + ) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_model = ds_model else: - unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator - if train_text_encoder: - if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + if train_unet: + unet = accelerator.prepare(unet) else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] - else: - pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator + if train_text_encoder: + if len(text_encoders) > 1: + text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + else: + text_encoder = accelerator.prepare(text_encoder) + text_encoders = [text_encoder] + else: + pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) + training_model = network if args.gradient_checkpointing: # according to TI example in Diffusers, train is required @@ -758,21 +773,21 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(network): + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor + latents = latents * self.vae_scale_factor # get multiplier for each sample if network_has_multiplier: @@ -834,6 +849,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -940,7 +957,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: @@ -957,6 +974,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index df1d8485a..e7083596f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,12 +8,14 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + + init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import model_util +from library import deepspeed_utils, model_util import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -29,6 +31,7 @@ scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, + apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments @@ -268,7 +271,7 @@ def train(self, args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False)) if args.dataset_config is not None: accelerator.print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -558,10 +561,10 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(text_encoders[0]): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) latents = latents * self.vae_scale_factor # Get the text embedding for conditioning @@ -586,6 +589,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -732,7 +737,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: @@ -749,6 +754,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser, False) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 695fad2a8..861d48d1d 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,7 +8,9 @@ from tqdm import tqdm import torch +from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate.utils import set_seed @@ -31,6 +33,7 @@ apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, + apply_masked_loss, ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -200,7 +203,7 @@ def train(args): logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -439,7 +442,7 @@ def remove_model(old_ckpt_name): with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() @@ -471,6 +474,8 @@ def remove_model(old_ckpt_name): target = noise loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -586,7 +591,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() @@ -662,6 +667,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser, False)