diff --git a/README.md b/README.md index e2c606708..c9a13afc1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,99 @@ +# Training Stable Cascade Stage C + +This is an experimental feature. There may be bugs. + +## Usage + +Training is run with `stable_cascade_train_stage_c.py`. + +The main options are the same as `sdxl_train.py`. The following options have been added. + +- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights. +- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights. +- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used. +- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`. +- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training. +- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py). If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight. + +The learning rate is set to 1e-4 in the official settings. + +For the first run, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From then on, specify `--text_model_checkpoint_path` to load the saved weights. + +Sample image generation during training is done with the Previewer. The Previewer is a simple decoder that converts EfficientNetEncoder latents to images. + +Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary). + +Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight. + +`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) can be used as is to reduce memory consumption. + +A scale of about 4 is suitable for sample image generation. + +Since the official settings use `bf16` for training, training with `fp16` may be unstable. + +The code for training the Text Encoder is also written, but it is untested. + +### Command line sample + +```batch +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight +``` + +### About the dataset for fine tuning + +If latents cache files for SD/SDXL exist (extension `*.npz`), they will be read and will cause an error during training. Please move them to another location in advance.
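If the dataset contains many images, a small helper script can move the old caches aside in one pass. This is a minimal sketch, not part of the repository: the directory paths are placeholders, and it assumes the Stable Cascade caches use the `_sc_latents.npz` suffix described in the next paragraph.

```python
# Minimal sketch (not repository code): move SD/SDXL latents caches (*.npz)
# out of the dataset directory so Stable Cascade training does not pick them up.
# Caches that already use the Stable Cascade suffix (*_sc_latents.npz) are kept.
from pathlib import Path
import shutil

dataset_dir = Path("../dataset/images")          # placeholder: your training image directory
backup_dir = dataset_dir.parent / "npz_backup"   # placeholder: any location outside the dataset

for npz in dataset_dir.rglob("*.npz"):
    if npz.name.endswith("_sc_latents.npz"):
        continue  # keep Stable Cascade caches
    dest = backup_dir / npz.relative_to(dataset_dir)
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(str(npz), str(dest))
    print(f"moved {npz} -> {dest}")
```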
+ +After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added). + + +# Stable Cascade Stage C の学習 + +実験的機能です。不具合があるかもしれません。 + +## 使い方 + +学習は `stable_cascade_train_stage_c.py` で行います。 + +主なオプションは `sdxl_train.py` と同様です。以下のオプションが追加されています。 + +- `--effnet_checkpoint_path` : EfficientNetEncoder の重みのパスを指定します。 +- `--stage_c_checkpoint_path` : Stage C の重みのパスを指定します。 +- `--text_model_checkpoint_path` : Text Encoder の重みのパスを指定します。省略時は Hugging Face のモデルを使用します。 +- `--save_text_model` : `--text_model_checkpoint_path` に Hugging Face からダウンロードしたモデルを保存します。 +- `--previewer_checkpoint_path` : Previewer の重みのパスを指定します。学習中のサンプル画像生成に使用します。 +- `--adaptive_loss_weight` : [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) を用います。省略時は P2LossWeight が使用されます。公式では Adaptive Loss Weight が使用されているようです。 + +学習率は、公式の設定では 1e-4 のようです。 + +初回は `--text_model_checkpoint_path` と `--save_text_model` を指定して、Text Encoder の重みを保存すると良いでしょう。次からは `--text_model_checkpoint_path` を指定して、保存した重みを読み込むことができます。 + +学習中のサンプル画像生成は Previewer で行われます。Previewer は EfficientNetEncoder の latents を画像に変換する簡易的な decoder です。 + +SDXL 向けの一部のオプションは単に無視されるか、エラーになります(特に `--noise_offset` などのノイズ関係)。`--vae_batch_size` および `--no_half_vae` はそのまま EfficientNetEncoder に適用されます(mixed precision に `bf16` 指定時は `--no_half_vae` は不要のようです)。 + +latents および Text Encoder 出力キャッシュのためのオプションはそのまま使用できますが、EfficientNetEncoder は VAE よりもかなり軽量のため、メモリが特に厳しい場合以外はキャッシュを使用する必要はないかもしれません。 + +メモリ消費を抑えるための `--gradient_checkpointing` 、`--full_bf16`、`--full_fp16`(未テスト)はそのまま使用できます。 + +サンプル画像生成時の Scale には 4 程度が適しているようです。 + +公式の設定では学習に `bf16` を用いているため、`fp16` での学習は不安定かもしれません。 + +Text Encoder 学習のコードも書いてありますが、未テストです。 + +### コマンドラインのサンプル + +[Command-line-sample](#command-line-sample)を参照してください。 + + +### fine tuning方式のデータセットについて + +SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が存在するとそれを読み込んでしまい学習時にエラーになります。あらかじめ他の場所に退避しておいてください。 + +その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。 + +--- + __SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). This repository contains training, generation and utility scripts for Stable Diffusion. @@ -249,6 +345,45 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Working in progress + +- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu! + - The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library. + - If `rich` is not installed, the log output will be the same as before. + - The following options are available in each training script: + - The `--console_log_simple` option can be used to switch to the previous log output. + - The `--console_log_level` option can be used to specify the log level. The default is `INFO`. + - The `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console). +- Sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54!
+- Support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If an mps device is available instead of CUDA, it is used automatically. +- An option `--highvram` is added to the training scripts to disable the optimizations intended for environments with little VRAM. Specifying it when there is enough VRAM makes the operation faster. + - Currently, only the latents caching step is optimized. +- IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0! +- Fixed a bug where `svd_merge_lora.py` crashed in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev! +- A common image generation script, `gen_img.py`, for SD 1/2 and SDXL is added. The basic functions are the same as the existing SD 1/2 and SDXL scripts, but some new features are added. + - External scripts that generate prompts are supported. They can be called with the `--from_module` option. (The documentation will be added later.) + - The normalization method after prompt weighting can be specified with the `--emb_normalize_mode` option: `original` is the original method, `abs` normalizes by the mean of the absolute values, and `none` applies no normalization. +- The Gradual Latent Hires fix is added to each generation script. See [here](./docs/gen_img_README-ja.md#about-gradual-latent) for details. + +- ログ出力が改善されました。 PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) shirayu 氏に感謝します。 + - デフォルトでログが成形されます。`rich` ライブラリが必要なため、[Upgrade](#upgrade) を参照し更新をお願いします。 + - `rich` がインストールされていない場合は、従来のログ出力になります。 + - 各学習スクリプトでは以下のオプションが有効です。 + - `--console_log_simple` オプションで従来のログ出力に切り替えられます。 + - `--console_log_level` でログレベルを指定できます。デフォルトは `INFO` です。 + - `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。 +- 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。 +- mps デバイスのサポートが改善されました。 PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) akx 氏に感謝します。CUDA ではなく mps が存在する場合には自動的に mps デバイスを使用します。 +- 学習スクリプトに VRAMが少ない環境向け最適化を無効にするオプション `--highvram` を追加しました。VRAM に余裕がある場合に指定すると動作が高速化されます。 + - 現在は latents のキャッシュ部分のみ高速化されます。 +- IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。 +- `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。 +- SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。 + - プロンプトを動的に生成する外部スクリプトをサポートしました。 `--from_module` で呼び出せます。(ドキュメントはのちほど追加します) + - プロンプト重みづけ後の正規化方法を `--emb_normalize_mode` で指定できます。`original` は元の方法、`abs` は絶対値の平均値で正規化、`none` は正規化を行いません。 +- Gradual Latent Hires fix を各生成スクリプトに追加しました。詳細は [こちら](./docs/gen_img_README-ja.md#about-gradual-latent)。 + + ### Jan 27, 2024 / 2024/1/27: v0.8.3 - Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
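The file-by-file diffs that follow apply the logging change from this changelog entry mechanically, replacing `print()` calls with a module-level logger. Below is a minimal sketch of the pattern as it appears in those diffs; the exact behavior of `setup_logging` lives in `library/utils.py`, and the comment about the `--console_log_*` options is an inference from the changelog, not verified here.

```python
# Pattern used across the scripts below: configure logging once at import time,
# then log through a module-level logger instead of print().
from library.utils import setup_logging

setup_logging()  # default formatting (uses `rich` when it is installed)
import logging

logger = logging.getLogger(__name__)


def train(args):
    # re-initialize with the parsed arguments, presumably so the
    # --console_log_simple / --console_log_level / --console_log_file options take effect
    setup_logging(args, reset=True)
    logger.info("prepare accelerator")          # was: print("prepare accelerator")
    logger.warning("existing captions will be overwritten")
```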
diff --git a/XTI_hijack.py b/XTI_hijack.py index 1dbc263ac..93bc1c0b1 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,7 +1,7 @@ import torch -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex init_ipex() + from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/docs/gen_img_README-ja.md b/docs/gen_img_README-ja.md index cf35f1df7..8f4442d00 100644 --- a/docs/gen_img_README-ja.md +++ b/docs/gen_img_README-ja.md @@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt - `--network_show_meta` : 追加ネットワークのメタデータを表示します。 + +--- + +# About Gradual Latent + +Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options. + +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. +- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. +- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. +- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. + +Each option can also be specified with the prompt options `--glt`, `--glr`, `--gls`, and `--gle`. + +__Please specify `euler_a` for the sampler.__ Because the sampler source code is modified, it will not work with other samplers. + +It is more effective with SD 1.5; the effect is quite subtle with SDXL.
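To make the interplay of these options concrete, the snippet below is an illustration only (not code from the repository) of how the latent size ratio would progress under the documented defaults once denoising reaches the timestep given by `--gradual_latent_timesteps`; the exact step at which the first increase happens inside the sampler is an assumption.

```python
# Illustration only: ratio schedule implied by the documented defaults.
ratio = 0.5            # --gradual_latent_ratio
ratio_step = 0.125     # --gradual_latent_ratio_step
every_n_steps = 3      # --gradual_latent_ratio_every_n_steps

for step in range(1, 16):
    if step % every_n_steps == 0 and ratio < 1.0:
        ratio = min(1.0, ratio + ratio_step)
    print(f"step {step:2d}: latent ratio {ratio:.3f}")
# -> 0.5 grows to 0.625, 0.75, 0.875, 1.0, i.e. the latent reaches full size
#    after four increases (12 steps with the defaults).
```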
+ +# Gradual Latent について + +latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、`sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。 + +- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。 +- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 +- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 +- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 + +それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 + +サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 + +SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。 + diff --git a/fine_tune.py b/fine_tune.py index 982dc8aec..8df896b43 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -2,22 +2,27 @@ # XXX dropped option: hypernetwork training import argparse -import gc import math import os from multiprocessing import Value import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -37,6 +42,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -49,11 +55,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -86,7 +92,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option.
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -97,7 +103,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -158,9 +164,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -223,7 +227,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -287,7 +293,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -461,12 +467,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -475,7 +482,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--learning_rate_te", diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7851fb08b..7d192cb26 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -21,6 +21,10 @@ import os from urllib.parse import urlparse from timm.models.hub import download_cached_file +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class BLIP_Base(nn.Module): def __init__(self, @@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename): del state_dict[key] msg = model.load_state_dict(state_dict,strict=False) - print('load checkpoint from %s'%url_or_filename) + 
logger.info('load checkpoint from %s'%url_or_filename) return model,msg diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 68839eccc..5aeb17425 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -8,6 +8,10 @@ import re from tqdm import tqdm +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') @@ -36,13 +40,13 @@ def clean_tags(image_key, tags): tokens = tags.split(", rating") if len(tokens) == 1: # WD14 taggerのときはこちらになるのでメッセージは出さない - # print("no rating:") - # print(f"{image_key} {tags}") + # logger.info("no rating:") + # logger.info(f"{image_key} {tags}") pass else: if len(tokens) > 2: - print("multiple ratings:") - print(f"{image_key} {tags}") + logger.info("multiple ratings:") + logger.info(f"{image_key} {tags}") tags = tokens[0] tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 @@ -124,43 +128,43 @@ def clean_caption(caption): def main(args): if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.load(f) else: - print("no metadata / メタデータファイルがありません") + logger.error("no metadata / メタデータファイルがありません") return - print("cleaning captions and tags.") + logger.info("cleaning captions and tags.") image_keys = list(metadata.keys()) for image_key in tqdm(image_keys): tags = metadata[image_key].get('tags') if tags is None: - print(f"image does not have tags / メタデータにタグがありません: {image_key}") + logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") else: org = tags tags = clean_tags(image_key, tags) metadata[image_key]['tags'] = tags if args.debug and org != tags: - print("FROM: " + org) - print("TO: " + tags) + logger.info("FROM: " + org) + logger.info("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: - print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") + logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: org = caption caption = clean_caption(caption) metadata[image_key]['caption'] = caption if args.debug and org != caption: - print("FROM: " + org) - print("TO: " + caption) + logger.info("FROM: " + org) + logger.info("TO: " + caption) # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding='utf-8') as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -178,10 +182,10 @@ def setup_parser() -> argparse.ArgumentParser: args, unknown = parser.parse_known_args() if len(unknown) == 1: - print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") - print("All captions and tags in the metadata are processed.") - print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") - print("メタデータ内のすべてのキャプションとタグが処理されます。") + logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. 
Please specify two arguments: in_json and out_json.") + logger.warning("All captions and tags in the metadata are processed.") + logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") + logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") args.in_json = args.out_json args.out_json = unknown[0] elif len(unknown) > 0: diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 074576bc2..489bdbcce 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -9,14 +9,22 @@ from PIL import Image from tqdm import tqdm import numpy as np + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms from torchvision.transforms.functional import InterpolationMode sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_SIZE = 384 @@ -47,7 +55,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor = IMAGE_TRANSFORM(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -74,21 +82,21 @@ def main(args): args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path cwd = os.getcwd() - print("Current Working Directory is: ", cwd) + logger.info(f"Current Working Directory is: {cwd}") os.chdir("finetune") if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): args.caption_weights = os.path.join("..", args.caption_weights) - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") - print(f"loading BLIP caption: {args.caption_weights}") + logger.info(f"loading BLIP caption: {args.caption_weights}") model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) - print("BLIP loaded") + logger.info("BLIP loaded") # captioningする def run_batch(path_imgs): @@ -108,7 +116,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f'{image_path} {caption}') # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -138,7 +146,7 @@ def run_batch(path_imgs): raw_image = raw_image.convert("RGB") img_tensor = IMAGE_TRANSFORM(raw_image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, img_tensor)) @@ -148,7 +156,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/make_captions_by_git.py 
b/finetune/make_captions_by_git.py index b3c5cc423..edeebadf3 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -5,12 +5,19 @@ from pathlib import Path from PIL import Image from tqdm import tqdm + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from transformers import AutoProcessor, AutoModelForCausalLM from transformers.generation.utils import GenerationMixin import library.train_util as train_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -35,8 +42,8 @@ def remove_words(captions, debug): for pat in PATTERN_REPLACE: cap = pat.sub("", cap) if debug and cap != caption: - print(caption) - print(cap) + logger.info(caption) + logger.info(cap) removed_caps.append(cap) return removed_caps @@ -70,16 +77,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs) GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch """ - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") + logger.info(f"loading GIT: {args.model_id}") git_processor = AutoProcessor.from_pretrained(args.model_id) git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") + logger.info("GIT loaded") # captioningする def run_batch(path_imgs): @@ -97,7 +104,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f"{image_path} {caption}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -126,7 +133,7 @@ def run_batch(path_imgs): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -137,7 +144,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 241f6f902..60765b863 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -5,26 +5,30 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" train_data_dir_path = Path(args.train_data_dir) image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: - 
print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} - print("merge caption texts to metadata json.") + logger.info("merge caption texts to metadata json.") for image_path in tqdm(image_paths): caption_path = image_path.with_suffix(args.caption_extension) caption = caption_path.read_text(encoding='utf-8').strip() @@ -38,12 +42,12 @@ def main(args): metadata[image_key]['caption'] = caption if args.debug: - print(image_key, caption) + logger.info(f"{image_key} {caption}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index db1bff6da..9ef8f14b0 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -5,26 +5,30 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" train_data_dir_path = Path(args.train_data_dir) image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") metadata = {} - print("merge tags to metadata json.") + logger.info("merge tags to metadata json.") for image_path in tqdm(image_paths): tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() @@ -38,13 +42,13 @@ def main(args): metadata[image_key]['tags'] = tags if args.debug: - print(image_key, tags) + logger.info(f"{image_key} {tags}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3b..c4b2b1fec 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -8,13 +8,25 @@ import numpy as np from PIL 
import Image import cv2 + import torch +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + from torchvision import transforms import library.model_util as model_util +import library.stable_cascade_utils as sc_utils import library.train_util as train_util +from library.utils import setup_logging + +setup_logging() +import logging -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +logger = logging.getLogger(__name__) + +DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( [ @@ -34,7 +46,7 @@ def collate_fn_remove_corrupted(batch): return batch -def get_npz_filename(data_dir, image_key, is_full_path, recursive): +def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade): if is_full_path: base_name = os.path.splitext(os.path.basename(image_key))[0] relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) @@ -42,31 +54,32 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive): base_name = image_key relative_path = "" + ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX if recursive and relative_path: - return os.path.join(data_dir, relative_path, base_name) + ".npz" + return os.path.join(data_dir, relative_path, base_name) + ext else: - return os.path.join(data_dir, base_name) + ".npz" + return os.path.join(data_dir, base_name) + ext def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") if args.bucket_reso_steps % 32 > 0: - print( + logger.warning( f"WARNING: bucket_reso_steps is not divisible by 32. 
It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding="utf-8") as f: metadata = json.load(f) else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") + logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}") return weight_dtype = torch.float32 @@ -75,13 +88,20 @@ def main(args): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - vae = model_util.load_vae(args.model_name_or_path, weight_dtype) + if not args.stable_cascade: + vae = model_util.load_vae(args.model_name_or_path, weight_dtype) + divisor = 8 + else: + vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE) + divisor = 32 vae.eval() vae.to(DEVICE, dtype=weight_dtype) # bucketのサイズを計算する max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + assert ( + len(max_reso) == 2 + ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" bucket_manager = train_util.BucketManager( args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps @@ -89,7 +109,7 @@ def main(args): if not args.bucket_no_upscale: bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -130,7 +150,7 @@ def process_batch(is_last): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] @@ -146,6 +166,10 @@ def process_batch(is_last): # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) + # 追加情報を記録 + metadata[image_key]["original_size"] = (image.width, image.height) + metadata[image_key]["train_resized_size"] = resized_size + if not args.bucket_no_upscale: # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する assert ( @@ -160,9 +184,9 @@ def process_batch(is_last): ), f"internal error resized size is small: {resized_size}, {reso}" # 既に存在するファイルがあればshape等を確認して同じならskipする - npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive) + npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade) if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug): + if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor): continue # バッチへ追加 @@ -183,15 +207,15 @@ def process_batch(is_last): for i, reso in enumerate(bucket_manager.resos): count = bucket_counts.get(reso, 0) if count > 0: - 
print(f"bucket {i} {reso}: {count}") + logger.info(f"bucket {i} {reso}: {count}") img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") + logger.info(f"mean ar error: {np.mean(img_ar_errors)}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding="utf-8") as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -200,7 +224,14 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument( + "--stable_cascade", + action="store_true", + help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する", + ) + parser.add_argument( + "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", @@ -223,10 +254,16 @@ def setup_parser() -> argparse.ArgumentParser: help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument( "--full_path", @@ -234,7 +271,9 @@ def setup_parser() -> argparse.ArgumentParser: help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", ) parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + "--flip_aug", + action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する", ) parser.add_argument( "--skip_existing", diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index fbf328e83..b56d921a3 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,6 +11,10 @@ from tqdm import tqdm import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # from wd14 tagger IMAGE_SIZE = 448 @@ -58,7 +62,7 @@ def __getitem__(self, idx): image = preprocess_image(image) tensor = torch.tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -79,7 +83,7 @@ def main(args): # depreacatedの警告が出るけどなくなったらその時 # 
https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: files += FILES_ONNX @@ -95,7 +99,7 @@ def main(args): force_filename=file, ) else: - print("using existing wd14 tagger model") + logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: @@ -103,8 +107,8 @@ def main(args): import onnxruntime as ort onnx_path = f"{args.model_dir}/model.onnx" - print("Running wd14 tagger with onnx") - print(f"loading onnx model: {onnx_path}") + logger.info("Running wd14 tagger with onnx") + logger.info(f"loading onnx model: {onnx_path}") if not os.path.exists(onnx_path): raise Exception( @@ -121,7 +125,7 @@ def main(args): if args.batch_size != batch_size and type(batch_size) != str: # some rebatch model may use 'N' as dynamic axes - print( + logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size @@ -156,7 +160,7 @@ def main(args): train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") tag_freq = {} @@ -237,7 +241,10 @@ def run_batch(path_imgs): with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -269,7 +276,7 @@ def run_batch(path_imgs): image = image.convert("RGB") image = preprocess_image(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -284,11 +291,11 @@ def run_batch(path_imgs): if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - print("\nTag frequencies:") + print("Tag frequencies:") for tag, freq in sorted_tags: print(f"{tag}: {freq}") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/gen_img.py b/gen_img.py new file mode 100644 index 000000000..a24220a0a --- /dev/null +++ b/gen_img.py @@ -0,0 +1,3326 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import importlib.util +import sys +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch + +from library.ipex_interop import init_ipex + +init_ipex() + +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + 
KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? 
+ vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return 
forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + is_sdxl, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: InferSdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.is_sdxl = is_sdxl + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: Union[InferUNet2DConditionModel, InferSdxlUNet2DConditionModel] = unet + self.scheduler = scheduler + self.safety_checker = None + + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 + self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + self.gradual_latent: GradualLatent = None + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + def set_control_net_lllites(self, ctrl_net_lllites): + self.control_net_lllites = ctrl_net_lllites + + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + print("gradual_latent is disabled") + self.gradual_latent = None + else: + print(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = 
None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + emb_normalize_mode: str = "original", + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + regional_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + self.is_sdxl, + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + emb_normalize_mode=emb_normalize_mode, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + self.is_sdxl, + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + emb_normalize_mode=emb_normalize_mode, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_net_lllites: + # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + if self.is_sdxl: + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if regional_network: + # use last pool 
for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and self.clip_vision_model is not None: + print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model( + pixel_values=pixel_values, output_hidden_states=True, return_dict=True + ) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
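+ # note: the spatial dims below are height // 8 and width // 8 because the VAE downscales by a factor of 8,
+ # so e.g. a 1024x1024 request produces a (batch_size, self.unet.in_channels, 128, 128) latent tensor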
+ latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only 
used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + if self.control_net_lllites: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_net_lllites: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_net_lllites: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_net_lllites) + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + print("gradual_latent is not supported for this scheduler. Ignoring.") + print(self.scheduler.__class__.__name__) + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable ControlNet-LLLite if ratio is set. 
ControlNet is disabled in ControlNetInfo + if self.control_net_lllites: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + if self.control_nets and self.control_net_enabled: + if regional_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + else: + text_emb_last = text_embeddings + + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + text_emb_last, + ).sample + elif self.is_sdxl: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return latents + + latents = 1 / (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) 
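+ # quick reference for the emphasis syntax parsed below (see parse_prompt_attention):
+ # nested round brackets multiply, so "((word))" weights word by 1.1 * 1.1 = 1.21 and
+ # "((word:1.3))" gives 1.3 * 1.1 = 1.43, while "[word]" divides the weight by 1.1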
+ + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. 
+ """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + is_sdxl: bool, + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + is_sdxl: bool, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip: int = 1, + token_replacer=None, + device=None, + emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". 
each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? 
+ # →全体でいいんじゃないかな + + if (not skip_parsing) and (not skip_weighting): + if emb_normalize_mode == "abs": + previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + elif emb_normalize_mode == "none": + text_embeddings *= prompt_weights.unsqueeze(-1) + if uncond_prompt is not None: + uncond_embeddings *= uncond_weights.unsqueeze(-1) + + else: # "original" + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 
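+ # (the transpose above uses the identity permutation (0, 1, 2, 3), so it appears to be a no-op;
+ # at this point the mask has shape (1, 4, h // 8, w // 8), matching the latent layout)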
+ mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, 
torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + raw_prompt: str + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +class ListPrompter: + def __init__(self, prompts: List[str]): + self.prompts = prompts + self.index = 0 + + def shuffle(self): + random.shuffle(self.prompts) + + def __len__(self): + return len(self.prompts) + + def __call__(self, *args, **kwargs): + if self.index >= len(self.prompts): + self.index = 0 # reset + return None + + prompt = self.prompts[self.index] + self.index += 1 + return prompt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + # モデルを読み込む + if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt + use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + + # SDXLかどうかを判定する + is_sdxl = args.sdxl + if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する + if use_stable_diffusion_format: + # if file size > 5.5GB, sdxl + is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3 + else: + # if `text_encoder_2` subdirectory exists, sdxl + is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) + print(f"SDXL: {is_sdxl}") + + if is_sdxl: + if args.clip_skip is None: + args.clip_skip = 2 + + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + text_encoders = [text_encoder1, text_encoder2] + else: + if args.clip_skip is None: + args.clip_skip = 2 if args.v2 else 1 + + if use_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + print("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe + + # Diffusers U-Net to original U-Net + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, 
+ unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) + text_encoders = [text_encoder] + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + print("loading tokenizer") + if is_sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteSchedulerGL + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index 
= 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + print("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + vae.eval() + + for text_encoder in text_encoders: + text_encoder.to(dtype).to(device) + text_encoder.eval() + unet.to(dtype).to(device) + unet.eval() + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 + print(f"network_merge: {network_merge}") + + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value 
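+ # each --network_args entry is parsed above as a ";"-separated list of key=value pairs,
+ # e.g. a hypothetical "key1=1;key2=foo" becomes net_kwargs = {"key1": "1", "key2": "foo"} (values remain strings)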
+ + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs + ) + if network is None: + return + + mergeable = network.is_mergeable() + if network_merge and not mergeable: + print("network is not mergiable. ignore merge option.") + + if not mergeable or i >= network_merge: + # not merging + network.apply_to(text_encoders, unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + print(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + print("backup original weights") + network.backup_weights() + + networks.append(network) + network_default_muls.append(network_mul) + else: + network.merge_to(text_encoders, unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + print("import upscaler module:", args.highres_fix_upscaler) + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + print("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + print(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not 
args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net_lllite = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net_lllite.apply_to() + control_net_lllite.load_state_dict(state_dict) + control_net_lllite.to(dtype).to(device) + control_net_lllite.set_batch_cond_only(False, False) + control_net_lllites.append((control_net_lllite, ratio)) + assert ( + len(control_nets) == 0 or len(control_net_lllites) == 0 + ), "ControlNet and ControlNet-LLLite cannot be used at the same time" + + if args.opt_channels_last: + print(f"set optimizing: channels last") + for text_encoder in text_encoders: + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + + for cn in control_net_lllites: + cn.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + is_sdxl, + device, + vae, + text_encoders, + tokenizers, + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + pipe.set_control_net_lllites(control_net_lllites) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + if is_sdxl: + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + else: + embeds1 = next(iter(data.values())) + embeds2 = None + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizers[0].add_tokens(token_strings) + num_added_tokens2 = tokenizers[1].add_tokens(token_strings) if is_sdxl else 0 + assert num_added_tokens1 == num_vectors_per_token and ( + num_added_tokens2 == 0 or num_added_tokens2 == num_vectors_per_token + ), ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + 
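+ # e.g. a hypothetical embedding file "myembed.safetensors" holding 3 vectors registers the tokens
+ # "myembed", "myembed1", "myembed2"; when num_vectors_per_token > 1, "myembed" in a prompt is
+ # expanded to all of these ids below via add_token_replacement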
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) + token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert not is_sdxl or ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizers[0]) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizers[0])}" + assert ( + not is_sdxl or len(tokenizers[1]) - 1 == token_ids2[-1] + ), f"token ids 2 is not end of tokenize: {len(tokenizers[1])}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + if is_sdxl: + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + if is_sdxl: + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoders[0].resize_token_embeddings(len(tokenizers[0])) + token_embeds1 = text_encoders[0].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + + if is_sdxl: + text_encoders[1].resize_token_embeddings(len(tokenizers[1])) + token_embeds2 = text_encoders[1].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + prompt_list = None + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + prompter = ListPrompter(prompt_list) + + elif args.from_module is not None: + + def load_module_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + print(f"reading prompts from module: {args.from_module}") + prompt_module = load_module_from_path("prompt_module", args.from_module) + + prompter = prompt_module.get_prompter(args, pipe, networks) + + elif args.prompt is not None: + prompter = ListPrompter([args.prompt]) + + else: + prompter = None # interactive mode + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if 
args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + print(f"CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and prompter is None and not args.interactive: + print("get prompts from images' metadata") + prompt_list = [] + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + prompter = ListPrompter(prompt_list) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + print(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + print(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + print("use mask as region") + + size = None + for i, network in enumerate(networks): + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: + np_mask = np.array(mask_images[0]) + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print( + f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) + guide_images = None + else: + guide_images = None + + # 新しい乱数生成器を作成する + if args.seed is not None: + if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1: + # 引数のseedをそのまま使う + def fixed_seed(*args, **kwargs): + return args.seed + + seed_random = SimpleNamespace(randint=fixed_seed) + else: + seed_random = random.Random(args.seed) + else: + seed_random = random.Random() + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 if is_sdxl else 512 + if args.H is None: + args.H = 1024 if is_sdxl else 512 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + if args.iter_same_seed: + iter_seed = seed_random.randint(0, 2**32 - 1) + else: + iter_seed = None + + # shuffle prompt list + if args.shuffle_prompts: + prompter.shuffle() + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = 
images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + raw_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? 
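+ # a guide image was supplied (CLIP image guidance or ControlNet hint), so collect one per batch entry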
+ guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets or control_net_lllites: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + # TODO バッチから ds_ratio を取り出すべき + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + print("pre-calculation... 
done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + emb_normalize_mode=args.emb_normalize_mode, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + if is_sdxl: + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while True: + if args.interactive: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step) + if raw_prompt is None: + break + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = 
handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + length = len(prompter) if hasattr(prompter, "__len__") else 0 + print(f"prompt {prompt_index+1}/{length}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + print(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + print(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + print(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + print(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + print(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + print(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength 
= float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 
解析エラー: {parg}") + print(ex) + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if args.iter_same_seed: + seed = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = seed_random.randint(0, 2**32 - 1) + if args.interactive: + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets or control_net_lllites: # 複数件の場合あり + c = max(len(control_nets), len(control_net_lllites)) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? 
+ process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument( + "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" + ) + parser.add_argument( + "--v1", action="store_true", help="load Stable Diffusion v1.x model / Stable Diffusion 1.xのモデルを読み込む" + ) + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", + ) + parser.add_argument( + "--from_module", + type=str, + default=None, + help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む", + ) + parser.add_argument( + "--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数" + ) + parser.add_argument( + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + 
parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument( + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument( + "--shuffle_prompts", + action="store_true", + help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", + action="store_true", + help="set channels last option to 
model / モデルにchannels lastを指定し最適化する", + ) + parser.add_argument( + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument( + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", + ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="layer number from bottom to use in CLIP, default is 1 for SD1/2, 2 for SDXL " + + "/ CLIPの後ろからn層目の出力を使う(デフォルトはSD1/2の場合1、SDXLの場合2)", + ) + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--emb_normalize_mode", + type=str, + default="original", + choices=["original", "none", "abs"], + help="embedding normalization mode / embeddingの正規化モード", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + 
action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", + ) + + parser.add_argument( + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", + ) + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", + ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a207ad5a1..2c5f84a93 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -64,10 +64,9 @@ import diffusers import numpy as np -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() import torchvision @@ -102,8 +101,15 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -139,12 +145,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -152,7 +158,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -168,7 +174,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -224,7 +230,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -280,7 +286,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -449,6 +455,8 @@ def __init__( self.control_nets: List[ControlNetInfo] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent: GradualLatent = None + # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids @@ -479,6 +487,14 @@ def add_token_replacement_XTI(self, 
target_token_id, rep_token_ids): def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + print("gradual_latent is disabled") + self.gradual_latent = None + else: + print(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): @@ -684,7 +700,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -766,11 +782,11 @@ def __call__( clip_text_input = prompt_tokens if clip_text_input.shape[1] > self.tokenizer.model_max_length: # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) + logger.info(f"trim text input {clip_text_input.shape}") clip_text_input = torch.cat( [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 ) - print("trimmed", clip_text_input.shape) + logger.info(f"trimmed {clip_text_input.shape}") for i, clip_prompt in enumerate(clip_prompts): if clip_prompt is not None: # clip_promptがあれば上書きする @@ -888,8 +904,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -953,7 +968,49 @@ def __call__( else: text_emb_last = text_embeddings + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + print("gradual_latent is not supported for this scheduler. 
Ignoring.") + print(self.scheduler.__class__.__name__) + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1047,8 +1104,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -1535,7 +1591,9 @@ def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred image_embeddings = self.vgg16_feat_model(image)["feat"] # バッチサイズが複数だと正しく動くかわからない - loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + loss = ( + (image_embeddings - guide_embeddings) ** 2 + ).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので grads = -torch.autograd.grad(loss, latents)[0] if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -1699,7 +1757,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1729,7 +1787,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -2041,7 +2099,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -2111,7 +2169,7 @@ def replacer(): # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -2126,6 +2184,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -2158,9 +2217,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -2170,10 +2229,10 @@ def main(args): use_stable_diffusion_format = os.path.isfile(args.ckpt) if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -2196,21 +2255,21 @@ def main(args): # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # # 置換するCLIPを読み込む # if args.replace_clip_l14_336: # text_encoder = load_clip_l14_336(dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") + # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") + logger.info("prepare clip model") clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) else: clip_model = None if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") + logger.info("prepare resnet model") vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) else: vgg16_model = None @@ -2222,7 +2281,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if use_stable_diffusion_format: tokenizer = train_util.load_tokenizer(args) @@ -2245,7 +2304,7 @@ def main(args): scheduler_cls = EulerDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_euler_discrete elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = 
EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": scheduler_cls = DPMSolverMultistepScheduler @@ -2281,7 +2340,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}") if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -2290,7 +2349,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -2321,11 +2380,11 @@ def __getattr__(self, item): # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") + logger.info("set clip_sample to True") scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: @@ -2378,7 +2437,7 @@ def __getattr__(self, item): network_merge = 0 for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -2396,7 +2455,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -2404,7 +2463,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs @@ -2414,20 +2473,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. 
ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -2441,7 +2500,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -2450,7 +2509,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -2467,7 +2526,7 @@ def __getattr__(self, item): control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) unet.to(memory_format=torch.channels_last) @@ -2499,7 +2558,7 @@ def __getattr__(self, item): args.vgg16_guidance_layer, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -2508,6 +2567,29 @@ def __getattr__(self, item): if args.ds_depth_1 is not None: unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # Extended Textual Inversion および Textual Inversionを処理する if args.XTI_embeddings: diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI @@ -2529,7 +2611,9 @@ def __getattr__(self, item): embeds = next(iter(data.values())) if type(embeds) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") + raise ValueError( + f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}" + ) num_vectors_per_token = embeds.size()[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0] @@ -2542,7 +2626,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). 
please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") assert ( min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 ), f"token ids is not ordered" @@ -2601,7 +2685,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") # if num_vectors_per_token > 1: pipe.add_token_replacement(token_ids[0], token_ids) @@ -2626,10 +2710,10 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -2655,7 +2739,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2671,24 +2755,24 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") + logger.info("get prompts from images' meta data") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -2717,17 +2801,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") 
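With the `--from_file` change above, blank lines and lines beginning with `#` are now skipped when the prompt file is read, so the file can carry comments. A minimal sketch of the same filtering, assuming a hypothetical `prompts.txt`:

```python
# Sketch only: mirrors the --from_file filtering above; "prompts.txt" is a hypothetical example file.
with open("prompts.txt", "r", encoding="utf-8") as f:
    lines = f.read().splitlines()

# keep non-empty lines that do not start with "#" (comment lines are ignored)
prompt_list = [d for d in lines if len(d.strip()) > 0 and d[0] != "#"]
print(f"{len(prompt_list)} prompts loaded")
```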
+ logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2752,14 +2836,16 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.info( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -2785,7 +2871,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # shuffle prompt list @@ -2801,7 +2887,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: width_1st = int(ext.width * args.highres_fix_scale + 0.5) @@ -2827,7 +2913,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2873,13 +2959,14 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), ) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2910,11 +2997,16 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2978,7 +3070,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... 
done") images = pipe( prompts, @@ -3006,8 +3098,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -3023,6 +3115,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) if args.use_original_file_name and init_images is not None: if type(init_images) is list: @@ -3045,7 +3139,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.info( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -3058,7 +3154,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -3099,40 +3196,48 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ds_timesteps_2 = args.ds_timesteps_2 ds_ratio = args.ds_ratio + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -3141,25 +3246,25 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # 
strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -3167,47 +3272,89 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {gl_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is
not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.info(f"Exception in parsing / 解析エラー: {parg}") + logger.info(ex) # override Deep Shrink if ds_depth_1 is not None: @@ -3215,6 +3362,31 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ds_depth_1 = args.ds_depth_1 or 3 unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + print(unsharp_params) + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -3225,7 +3397,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.info("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seed = iter_seed @@ -3235,7 +3407,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -3251,7 +3423,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.info( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -3267,9 +3439,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): guide_image = guide_images[global_step % len(guide_images)] elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: if prev_image is None: - print("Generate 1st image without guide image.") + logger.info("Generate 1st image without guide image.") else: - print("Use previous image as guide image.") + logger.info("Use previous image as guide image.") guide_image = prev_image if regional_network: @@ -3282,7 +3454,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -3311,22 +3485,31 @@ def process_batch(batch: 
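The per-prompt Gradual Latent options above (`glt`, `glr`, `gle`, `gls`, `glsn`, `glus`) reuse the same comma-separated unsharp-mask format as `--gradual_latent_unsharp_params`. A standalone sketch of how that string is interpreted; `parse_unsharp_params` is a hypothetical helper, not part of the script:

```python
# Sketch only: mirrors the "ksize,sigma,strength,target-x" parsing above.
def parse_unsharp_params(value: str):
    params = value.split(",")
    ksize, sigma, strength = [float(v) for v in params[:3]]
    ksize = int(ksize)  # kernel size is used as an integer
    target_x = True if len(params) < 4 else bool(int(params[3]))  # 4th value: 1 means True
    return ksize, sigma, strength, target_x

print(parse_unsharp_params("3,0.5,0.5,1"))  # -> (3, 0.5, 0.5, True)
print(parse_unsharp_params("3,1.0,1.0,0"))  # -> (3, 1.0, 1.0, False)
```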
List[BatchData], highres_fix, highres_1st=False): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + add_logging_arguments(parser) + + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -3338,7 +3521,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -3392,9 +3577,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -3430,25 +3620,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 
追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -3470,7 +3681,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -3511,7 +3724,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -3520,7 +3736,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -3528,7 +3746,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres 
fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -3543,14 +3764,21 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", ) parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") parser.add_argument( @@ -3588,6 +3816,45 @@ def setup_parser() -> argparse.ArgumentParser: "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" ) + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + return parser @@ -3595,4 +3862,5 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/library/config_util.py b/library/config_util.py index a98c2b90d..fc4b36175 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -40,7 +40,10 @@ ControlNetDataset, DatasetGroup, ) - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def add_config_arguments(parser: argparse.ArgumentParser): parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") @@ -345,7 +348,7 @@ def sanitize_user_config(self, user_config: dict) -> dict: return self.user_config_validator(user_config) except MultipleInvalid: # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") raise # NOTE: In nature, argument parser result is not needed to be sanitize @@ -355,7 +358,7 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -538,13 +541,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu " ", ) - print(info) + logger.info(f'{info}') # make buckets first because it determines the length of dataset # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) @@ -557,7 +560,7 @@ def extract_dreambooth_params(name: str) -> Tuple[int, str]: try: n_repeats = int(tokens[0]) except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") return 0, "" caption_by_folder = "_".join(tokens[1:]) return n_repeats, caption_by_folder @@ -629,17 +632,13 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - print( - f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" - ) + logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - print( - f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" - ) + logger.error(f"Error on parsing TOML config file. Please check the format. 
/ TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -665,23 +664,26 @@ def load_user_config(file: str) -> dict: argparse_namespace = parser.parse_args(remain) train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + logger.info("[argparse_namespace]") + logger.info(f'{vars(argparse_namespace)}') user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + logger.info("") + logger.info("[user_config]") + logger.info(f'{user_config}') sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout ) sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f'{sanitized_user_config}') blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + logger.info("") + logger.info("[blueprint]") + logger.info(f'{blueprint}') diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index e0a026dae..a56474622 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,7 +3,10 @@ import random import re from typing import List, Optional, Union - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): @@ -21,7 +24,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR - print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") def enforce_zero_terminal_snr(betas): # Convert betas to alphas_bar_sqrt @@ -49,8 +52,8 @@ def enforce_zero_terminal_snr(betas): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - # print("original:", noise_scheduler.betas) - # print("fixed:", betas) + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") noise_scheduler.betas = betas noise_scheduler.alphas = alphas @@ -79,13 +82,13 @@ def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info - # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") return scale def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): scale = get_snr_scale(timesteps, noise_scheduler) - # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss @@ -268,7 +271,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): tokens.append(text_token) weights.append(text_weight) if truncated: - print("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights diff --git a/library/device_utils.py b/library/device_utils.py new file mode 100644 index 000000000..8823c5d9a --- /dev/null +++ b/library/device_utils.py @@ -0,0 +1,84 @@ +import functools +import gc + +import torch + +try: + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +try: + HAS_MPS = torch.backends.mps.is_available() +except Exception: + HAS_MPS = False + +try: + import intel_extension_for_pytorch as ipex # noqa + + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + + +def clean_memory(): + gc.collect() + if HAS_CUDA: + torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() + if HAS_MPS: + torch.mps.empty_cache() + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +@functools.lru_cache(maxsize=None) +def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ + if HAS_CUDA: + device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") + elif HAS_MPS: + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"get_preferred_device() -> {device}") + return device + + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. 
+ """ + try: + if HAS_XPU: + from library.ipex import ipex_init + + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 376fdb1e6..57b19d982 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -4,7 +4,10 @@ import argparse import os from library.utils import fire_in_thread - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( @@ -33,9 +36,9 @@ def upload( try: api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので - print("===========================================") - print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) @@ -56,9 +59,9 @@ def uploader(): path_in_repo=path_in_repo, ) except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので - print("===========================================") - print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 333504935..972a3bf63 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -9,162 +9,171 @@ def ipex_init(): # pylint: disable=too-many-statements try: - # Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = 
torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = 
torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - # Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = 
torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - # RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - # AMP: - torch.cuda.amp = torch.xpu.amp - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: 
False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 - # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype(): - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py index e98807a84..8253c5b17 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, 
mat2, *, out=None): ) else: return original_torch_bmm(input, mat2, out=out) + torch.xpu.synchronize(input.device) return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo ) else: return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + torch.xpu.synchronize(query.device) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 47b0375ae..732a18568 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -149,6 +149,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] @@ -283,6 +284,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, hidden_states[start_idx:end_idx] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index b6d246dd2..b1b9ccf0e 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,17 +1,22 @@ -import contextlib +import os +from functools import wraps +from contextlib import nullcontext import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np + +device_supports_fp64 = torch.xpu.has_fp64_dtype() # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: - print("IPEX backend doesn't support DataParallel on multiple XPU devices") + logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + return nullcontext() @property def is_cuda(self): @@ -25,15 +30,17 @@ def return_xpu(device): # Autocast -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: - return original_autocast(*args, **kwargs) + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) # Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or 
align_corners is not None: return_device = tensor.device @@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) def from_numpy(ndarray): if ndarray.dtype == float: return original_from_numpy(ndarray.astype('float32')) else: return original_from_numpy(ndarray) -if torch.xpu.has_fp64_dtype(): +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -66,20 +87,25 @@ def from_numpy(ndarray): # Data Type Errors: +@wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) return original_torch_bmm(input, mat2, out=out) +@wraps(torch.nn.functional.scaled_dot_product_attention) def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): # A1111 BF16 original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1 # Training original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) def functional_linear(input, weight, bias=None): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None): return original_functional_linear(input, weight, bias=bias) original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -115,6 +144,7 
@@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, # A1111 Embedding BF16 original_torch_cat = torch.cat +@wraps(torch.cat) def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) @@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs): # SwinIR BF16: original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) def functional_pad(input, pad, mode='constant', value=None): if mode == 'reflect' and input.dtype == torch.bfloat16: return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) @@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor -def torch_tensor(*args, device=None, **kwargs): +@wraps(torch.tensor) +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): if check_device(device): - return original_torch_tensor(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_tensor(*args, device=device, **kwargs) + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) def Tensor_to(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_to(self, return_xpu(device), *args, **kwargs) @@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs): return original_Tensor_to(self, device, *args, **kwargs) original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) def Tensor_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) @@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs): return original_Tensor_cuda(self, device, *args, **kwargs) original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): if check_device(device): return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) @@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs): return original_UntypedStorage_init(*args, device=device, **kwargs) original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) def UntypedStorage_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) @@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs): return original_UntypedStorage_cuda(self, device, *args, **kwargs) original_torch_empty = torch.empty +@wraps(torch.empty) def torch_empty(*args, device=None, **kwargs): if check_device(device): return original_torch_empty(*args, device=return_xpu(device), **kwargs) @@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs): return original_torch_empty(*args, device=device, **kwargs) original_torch_randn = torch.randn +@wraps(torch.randn) def 
torch_randn(*args, device=None, **kwargs): if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs): return original_torch_randn(*args, device=device, **kwargs) original_torch_ones = torch.ones +@wraps(torch.ones) def torch_ones(*args, device=None, **kwargs): if check_device(device): return original_torch_ones(*args, device=return_xpu(device), **kwargs) @@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs): return original_torch_ones(*args, device=device, **kwargs) original_torch_zeros = torch.zeros +@wraps(torch.zeros) def torch_zeros(*args, device=None, **kwargs): if check_device(device): return original_torch_zeros(*args, device=return_xpu(device), **kwargs) @@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs): return original_torch_zeros(*args, device=device, **kwargs) original_torch_linspace = torch.linspace +@wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): if check_device(device): return original_torch_linspace(*args, device=return_xpu(device), **kwargs) @@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs): return original_torch_linspace(*args, device=device, **kwargs) original_torch_Generator = torch.Generator +@wraps(torch.Generator) def torch_Generator(device=None): if check_device(device): return original_torch_Generator(return_xpu(device)) @@ -208,12 +255,14 @@ def torch_Generator(device=None): return original_torch_Generator(device) original_torch_load = torch.load +@wraps(torch.load) def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): if check_device(map_location): return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) else: return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + # Hijack Functions: def ipex_hijacks(): torch.tensor = torch_tensor @@ -232,7 +281,7 @@ def ipex_hijacks(): torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda - torch.autocast = ipex_autocast + torch.amp.autocast_mode.autocast.__init__ = autocast_init torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.group_norm = functional_group_norm @@ -244,5 +293,6 @@ def ipex_hijacks(): torch.bmm = torch_bmm torch.cat = torch_cat - if not torch.xpu.has_fp64_dtype(): + if not device_supports_fp64: torch.from_numpy = from_numpy + torch.as_tensor = as_tensor diff --git a/library/ipex_interop.py b/library/ipex_interop.py deleted file mode 100644 index 6fe320c57..000000000 --- a/library/ipex_interop.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - - -def init_ipex(): - """ - Try to import `intel_extension_for_pytorch`, and apply - the hijacks using `library.ipex.ipex_init`. - - If IPEX is not installed, this function does nothing. 
-    """
-    try:
-        import intel_extension_for_pytorch as ipex # noqa
-    except ImportError:
-        return
-
-    try:
-        from library.ipex import ipex_init
-
-        if torch.xpu.is_available():
-            is_initialized, error_message = ipex_init()
-            if not is_initialized:
-                print("failed to initialize ipex:", error_message)
-    except Exception as e:
-        print("failed to initialize ipex:", e)
diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py
index 3963e9b15..5717233d4 100644
--- a/library/lpw_stable_diffusion.py
+++ b/library/lpw_stable_diffusion.py
@@ -17,7 +17,6 @@
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
-
 try:
     from diffusers.utils import PIL_INTERPOLATION
 except ImportError:
@@ -626,7 +625,7 @@ def check_inputs(self, prompt, height, width, strength, callback_steps):
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
         if height % 8 != 0 or width % 8 != 0:
-            print(height, width)
+            logger.info(f'{height} {width}')
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
         if (callback_steps is None) or (
diff --git a/library/model_util.py b/library/model_util.py
index 4361b4994..be410a026 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -3,16 +3,20 @@
 import math
 import os
-import torch
-
-from library.ipex_interop import init_ipex
+import torch
+from library.device_utils import init_ipex
 init_ipex()
+
 import diffusers
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
 from safetensors.torch import load_file, save_file
 from library.original_unet import UNet2DConditionModel
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 # DiffUsers版StableDiffusionのモデルパラメータ
 NUM_TRAIN_TIMESTEPS = 1000
@@ -944,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict):
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
-                # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
+                # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
                 new_state_dict[k] = reshape_weight_for_sd(v)
     return new_state_dict
@@ -1002,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
     unet = UNet2DConditionModel(**unet_config).to(device)
     info = unet.load_state_dict(converted_unet_checkpoint)
-    print("loading u-net:", info)
+    logger.info(f"loading u-net: {info}")
     # Convert the VAE model.
vae_config = create_vae_diffusers_config() @@ -1010,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt vae = AutoencoderKL(**vae_config).to(device) info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) + logger.info(f"loading vae: {info}") # convert text_model if v2: @@ -1044,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt # logging.set_verbosity_error() # don't show annoying warning # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) # logging.set_verbosity_warning() - # print(f"config: {text_model.config}") + # logger.info(f"config: {text_model.config}") cfg = CLIPTextConfig( vocab_size=49408, hidden_size=768, @@ -1067,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt ) text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + logger.info(f"loading text encoder: {info}") return text_model, vae, unet @@ -1142,7 +1146,7 @@ def convert_key(key): # 最後の層などを捏造するか if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") + logger.info("make dummy weights for resblock.23, text_projection and logit scale.") keys = list(new_sd.keys()) for key in keys: if key.startswith("transformer.resblocks.22."): @@ -1261,14 +1265,14 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") + logger.info(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): # Diffusers local/remote try: vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") + logger.error(f"exception occurs in loading vae: {e}") + logger.error("retry with subfolder='vae'") vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae @@ -1340,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64) if __name__ == "__main__": resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) + logger.info(f"{len(resos)}") + logger.info(f"{resos}") aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) + logger.info(f"{aspect_ratios}") ars = set() for ar in aspect_ratios: if ar in ars: - print("error! duplicate ar:", ar) + logger.error(f"error! 
duplicate ar: {ar}") ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py index 030c5c9ec..e944ff22b 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -113,6 +113,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -1380,7 +1384,7 @@ def __init__( ): super().__init__() assert sample_size is not None, "sample_size must be specified" - print( + logger.info( f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" ) @@ -1514,7 +1518,7 @@ def set_use_sdpa(self, sdpa: bool) -> None: def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1709,14 +1713,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 472686ba4..bf546a1b1 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -5,6 +5,12 @@ import os from typing import List, Optional, Tuple, Union import safetensors +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) r""" # Metadata Example @@ -51,11 +57,13 @@ ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_STABLE_CASCADE = "stable-cascade" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade" IMPL_DIFFUSERS = "diffusers" PRED_TYPE_EPSILON = "epsilon" @@ -109,6 +117,7 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + stable_cascade: Optional[bool] = None, ): # if state_dict is None, hash is not calculated @@ -120,7 +129,9 @@ def build_metadata( # hash = precalculate_safetensors_hashes(state_dict) # metadata["modelspec.hash_sha256"] = hash - if sdxl: + if stable_cascade: + arch = ARCH_STABLE_CASCADE + elif sdxl: arch = ARCH_SD_XL_V1_BASE elif v2: if v_parameterization: @@ -138,9 +149,11 @@ def build_metadata( metadata["modelspec.architecture"] = arch if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not 
textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if stable_cascade: + impl = IMPL_STABILITY_AI_STABLE_CASCADE + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: @@ -231,8 +244,8 @@ def build_metadata( # # assert all values are filled # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): - print(f"Internal error: some metadata values are None: {metadata}") - + logger.error(f"Internal error: some metadata values are None: {metadata}") + return metadata @@ -246,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 08b90c393..f03f1bae5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -7,7 +7,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" @@ -131,7 +134,7 @@ def convert_key(key): # temporary workaround for text_projection.weight.weight for Playground-v2 if "text_projection.weight.weight" in new_sd: - print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] del new_sd["text_projection.weight.weight"] @@ -186,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty checkpoint = None # U-Net - print("building U-Net") + logger.info("building U-Net") with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - print("loading U-Net from checkpoint") + logger.info("loading U-Net from checkpoint") unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) - print("U-Net: ", info) + logger.info(f"U-Net: {info}") # Text Encoders - print("building text encoders") + logger.info("building text encoders") # Text Encoder 1 is same to Stability AI's SDXL text_model1_cfg = CLIPTextConfig( @@ -252,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty with init_empty_weights(): text_model2 = CLIPTextModelWithProjection(text_model2_cfg) - print("loading text encoders from checkpoint") + logger.info("loading text encoders from checkpoint") te1_sd = {} te2_sd = {} for k in list(state_dict.keys()): @@ -266,22 +269,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd.pop("text_model.embeddings.position_ids") info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 - print("text encoder 1:", info1) + logger.info(f"text encoder 1: {info1}") converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, 
max_length=77) info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 - print("text encoder 2:", info2) + logger.info(f"text encoder 2: {info2}") # prepare vae - print("building VAE") + logger.info("building VAE") vae_config = model_util.create_vae_diffusers_config() with init_empty_weights(): vae = AutoencoderKL(**vae_config) - print("loading VAE from checkpoint") + logger.info("loading VAE from checkpoint") converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) - print("VAE:", info) + logger.info(f"VAE: {info}") ckpt_info = (epoch, global_step) if epoch is not None else None return text_model1, text_model2, vae, unet, logit_scale, ckpt_info diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index babda8ec5..673cf9f65 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) IN_CHANNELS: int = 4 OUT_CHANNELS: int = 4 @@ -332,7 +335,7 @@ def forward_body(self, x, emb): def forward(self, x, emb): if self.training and self.gradient_checkpointing: - # print("ResnetBlock2D: gradient_checkpointing") + # logger.info("ResnetBlock2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -366,7 +369,7 @@ def forward_body(self, hidden_states): def forward(self, hidden_states): if self.training and self.gradient_checkpointing: - # print("Downsample2D: gradient_checkpointing") + # logger.info("Downsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -653,7 +656,7 @@ def forward_body(self, hidden_states, context=None, timestep=None): def forward(self, hidden_states, context=None, timestep=None): if self.training and self.gradient_checkpointing: - # print("BasicTransformerBlock: checkpointing") + # logger.info("BasicTransformerBlock: checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -796,7 +799,7 @@ def forward_body(self, hidden_states, output_size=None): def forward(self, hidden_states, output_size=None): if self.training and self.gradient_checkpointing: - # print("Upsample2D: gradient_checkpointing") + # logger.info("Upsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -1046,7 +1049,7 @@ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> N for block in blocks: for module in block: if hasattr(module, "set_use_memory_efficient_attention"): - # print(module.__class__.__name__) + # logger.info(module.__class__.__name__) module.set_use_memory_efficient_attention(xformers, mem_eff) def set_use_sdpa(self, sdpa: bool) -> None: @@ -1061,7 +1064,7 @@ def set_gradient_checkpointing(self, value=False): for block in blocks: for module in block.modules(): if hasattr(module, "gradient_checkpointing"): - # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + # logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1083,7 +1086,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def call_module(module, h, emb, context): x = h for layer in module: - # 
print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) if isinstance(layer, ResnetBlock2D): x = layer(x, emb) elif isinstance(layer, Transformer2DModel): @@ -1135,14 +1138,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 @@ -1229,7 +1232,7 @@ def call_module(module, h, emb, context): if __name__ == "__main__": import time - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModel() unet.to("cuda") @@ -1238,7 +1241,7 @@ def call_module(module, h, emb, context): unet.train() # 使用メモリ量確認用の疑似学習ループ - print("preparing optimizer") + logger.info("preparing optimizer") # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working @@ -1253,12 +1256,12 @@ def call_module(module, h, emb, context): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") if step == 1: time_start = time.perf_counter() @@ -1278,4 +1281,4 @@ def call_module(module, h, emb, context): optimizer.zero_grad(set_to_none=True) time_end = time.perf_counter() - print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5ad748d15..1932bf881 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -1,14 +1,21 @@ import argparse -import gc import math import os from typing import Optional + import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -21,7 +28,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") ( load_stable_diffusion_format, @@ -47,8 +54,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - 
torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info @@ -62,7 +68,7 @@ def _load_target_model( load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") ( text_encoder1, text_encoder2, @@ -76,7 +82,7 @@ def _load_target_model( from diffusers import StableDiffusionXLPipeline variant = "fp16" if weight_dtype == torch.float16 else None - print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: pipe = StableDiffusionXLPipeline.from_pretrained( @@ -84,12 +90,12 @@ def _load_target_model( ) except EnvironmentError as ex: if variant is not None: - print("try to load fp32 model") + logger.info("try to load fp32 model") pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) else: raise ex except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -112,7 +118,7 @@ def _load_target_model( with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") logit_scale = None ckpt_info = None @@ -120,13 +126,13 @@ def _load_target_model( # VAEを読み込む if vae_path is not None: vae = model_util.load_vae(vae_path, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info def load_tokenizers(args: argparse.Namespace): - print("prepare tokenizers") + logger.info("prepare tokenizers") original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] tokeniers = [] @@ -135,14 +141,14 @@ def load_tokenizers(args: argparse.Namespace): if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) if tokenizer is None: tokenizer = CLIPTokenizer.from_pretrained(original_path) if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) if i == 1: @@ -151,7 +157,7 @@ def load_tokenizers(args: argparse.Namespace): tokeniers.append(tokenizer) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") return tokeniers @@ -332,23 +338,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be 
enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: - print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: - print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") # if args.multires_noise_iterations: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" # ) # else: # if args.noise_offset is None: # args.noise_offset = DEFAULT_NOISE_OFFSET # elif args.noise_offset != DEFAULT_NOISE_OFFSET: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" # ) - # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( not hasattr(args, "weighted_captions") or not args.weighted_captions @@ -357,7 +363,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True - print( + logger.warning( "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" ) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 5c4e056d3..ea7653429 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -26,7 +26,10 @@ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoder_kl import AutoencoderKLOutput - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def slice_h(x, num_slices): # slice with pad 1 both sides: to eliminate side effect of padding of conv2d @@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): # sliced_tensor = torch.chunk(x, num_div, dim=1) # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) - # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) # normed_tensor = [] # for i in range(num_div): # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) @@ -243,7 +246,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす - # print(f"initial divisor: {div}") + # logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -253,11 +256,11 @@ def forward(*args, **kwargs): for i, down_block in enumerate(self.down_blocks[::-1]): if div >= 2: div = int(div) - # print(f"down block: {i} divisor: {div}") + # logger.info(f"down block: {i} divisor: {div}") for resnet 
in down_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if down_block.downsamplers is not None: - # print("has downsample") + # logger.info("has downsample") for downsample in down_block.downsamplers: downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) div *= 2 @@ -307,7 +310,7 @@ def forward(self, x): def downsample_forward(self, _self, num_slices, hidden_states): assert hidden_states.shape[1] == _self.channels assert _self.use_conv and _self.padding == 0 - print("downsample forward", num_slices, hidden_states.shape) + logger.info(f"downsample forward {num_slices} {hidden_states.shape}") org_device = hidden_states.device cpu_device = torch.device("cpu") @@ -350,7 +353,7 @@ def downsample_forward(self, _self, num_slices, hidden_states): hidden_states = torch.cat([hidden_states, x], dim=2) hidden_states = hidden_states.to(org_device) - # print("downsample forward done", hidden_states.shape) + # logger.info(f"downsample forward done {hidden_states.shape}") return hidden_states @@ -426,7 +429,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.up_blocks) - 1)) - print(f"initial divisor: {div}") + logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -436,11 +439,11 @@ def forward(*args, **kwargs): for i, up_block in enumerate(self.up_blocks): if div >= 2: div = int(div) - # print(f"up block: {i} divisor: {div}") + # logger.info(f"up block: {i} divisor: {div}") for resnet in up_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if up_block.upsamplers is not None: - # print("has upsample") + # logger.info("has upsample") for upsample in up_block.upsamplers: upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) div *= 2 @@ -528,7 +531,7 @@ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): del x hidden_states = torch.cat(sliced, dim=2) - # print("us hidden_states", hidden_states.shape) + # logger.info(f"us hidden_states {hidden_states.shape}") del sliced hidden_states = hidden_states.to(org_device) diff --git a/library/stable_cascade.py b/library/stable_cascade.py new file mode 100644 index 000000000..cbccaa9b3 --- /dev/null +++ b/library/stable_cascade.py @@ -0,0 +1,1512 @@ +# コードは Stable Cascade からコピーし、一部修正しています。元ライセンスは MIT です。 +# The code is copied from Stable Cascade and modified. The original license is MIT. +# https://github.com/Stability-AI/StableCascade + +import math +from types import SimpleNamespace +from typing import List, Optional +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +import torchvision + + +# region VectorQuantize + +# from torchtools https://github.com/pabloppp/pytorch-tools +# 依存ライブラリを増やしたくないのでここにコピペ + + +class vector_quantize(torch.autograd.Function): + @staticmethod + def forward(ctx, x, codebook): + with torch.no_grad(): + codebook_sqr = torch.sum(codebook**2, dim=1) + x_sqr = torch.sum(x**2, dim=1, keepdim=True) + + dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) + _, indices = dist.min(dim=1) + + ctx.save_for_backward(indices, codebook) + ctx.mark_non_differentiable(indices) + + nn = torch.index_select(codebook, 0, indices) + return nn, indices + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. 
the codebook + indices, codebook = ctx.saved_tensors + + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output) + + return (grad_inputs, grad_codebook) + + +class VectorQuantize(nn.Module): + def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): + """ + Takes an input of variable size (as long as the last dimension matches the embedding size). + Returns one tensor containing the nearest neighbour embeddings to each of the inputs, + with the same size as the input, vq and commitment components for the loss as a tuple + in the second output and the indices of the quantized vectors in the third: + quantized, (vq_loss, commit_loss), indices + """ + super(VectorQuantize, self).__init__() + + self.codebook = nn.Embedding(k, embedding_size) + self.codebook.weight.data.uniform_(-1.0 / k, 1.0 / k) + self.vq = vector_quantize.apply + + self.ema_decay = ema_decay + self.ema_loss = ema_loss + if ema_loss: + self.register_buffer("ema_element_count", torch.ones(k)) + self.register_buffer("ema_weight_sum", torch.zeros_like(self.codebook.weight)) + + def _laplace_smoothing(self, x, epsilon): + n = torch.sum(x) + return (x + epsilon) / (n + x.size(0) * epsilon) * n + + def _updateEMA(self, z_e_x, indices): + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) + + self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count) + self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) + self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum) + + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + + def idx2vq(self, idx, dim=-1): + q_idx = self.codebook(idx) + if dim != -1: + q_idx = q_idx.movedim(-1, dim) + return q_idx + + def forward(self, x, get_losses=True, dim=-1): + if dim != -1: + x = x.movedim(dim, -1) + z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x + z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) + vq_loss, commit_loss = None, None + if self.ema_loss and self.training: + self._updateEMA(z_e_x.detach(), indices.detach()) + # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss + z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) + if get_losses: + vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() + commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() + + z_q_x = z_q_x.view(x.shape) + if dim != -1: + z_q_x = z_q_x.movedim(-1, dim) + return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) + + +# endregion + + +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s(weights="DEFAULT").features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def encode(self, x): + """ + VAE と同じように使えるようにするためのメソッド。正しくはちゃんと呼び出し側で分けるべきだが、暫定的な対応。 + The method to make it usable like VAE. 
It should be separated properly, but it is a temporary response. + """ + # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + x = self(x) + return SimpleNamespace(latent_dist=SimpleNamespace(sample=lambda: x)) + + +# なんかわりと乱暴な実装(;'∀') +# 一から学習することもないだろうから、無効化しておく + +# class Linear(torch.nn.Linear): +# def reset_parameters(self): +# return None + +# class Conv2d(torch.nn.Conv2d): +# def reset_parameters(self): +# return None + +from torch.nn import Conv2d +from torch.nn import Linear + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c + c_skip, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c) + ) + + self.gradient_checkpointing = False + + def set_gradient_checkpointing(self, value): + self.gradient_checkpointing = value + + def forward_body(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + def forward(self, x, x_skip=None): + if self.training and self.gradient_checkpointing: + # logger.info("ResnetBlock2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, x_skip) + else: + x = self.forward_body(x, x_skip) + + return x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c)) + + self.gradient_checkpointing = False + + def set_gradient_checkpointing(self, value): + self.gradient_checkpointing = value + + def forward_body(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, 
self_attn=self.self_attn) + return x + + def forward(self, x, kv): + if self.training and self.gradient_checkpointing: + # logger.info("AttnBlock: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, kv) + else: + x = self.forward_body(x, kv) + + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0): + super().__init__() + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c) + ) + + self.gradient_checkpointing = False + + def set_gradient_checkpointing(self, value): + self.gradient_checkpointing = value + + def forward_body(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + def forward(self, x): + if self.training and self.gradient_checkpointing: + # logger.info("FeedForwardBlock: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x) + else: + x = self.forward_body(x) + + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=["sca"]): + super().__init__() + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True): + super().__init__() + assert mode in ["up", "down"] + interpolation = ( + nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) if enabled else nn.Identity() + ) + mapping = nn.Conv2d(c_in, c_out, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation]) + + self.mode = mode + + self.gradient_checkpointing = False + + def set_gradient_checkpointing(self, value): + self.gradient_checkpointing = value + + def forward_body(self, x): + org_dtype = x.dtype + for i, block in enumerate(self.blocks): + # 公式の実装では、常に float で計算しているが、すこしでもメモリを節約するために bfloat16 + Upsample のみ float に変換する + # In the official implementation, it always calculates in float, but for the sake of saving memory, it converts to float only for bfloat16 + Upsample + if x.dtype == torch.bfloat16 and (self.mode == "up" and i == 0 or self.mode != "up" and i == 1): + x = x.float() + x = block(x) + x = x.to(org_dtype) + return x + + def forward(self, x): + if self.training and self.gradient_checkpointing: + # logger.info("UpDownBlock2d: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x) + else: + x = self.forward_body(x) + + return x + + +class StageAResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, 
elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential(nn.ReplicationPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c)) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = StageAResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append( + nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + ) + ) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = StageAResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1) + ) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor, None, None, None + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +r""" + +https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_b_3b.yaml + +# GLOBAL STUFF +model_version: 3B +dtype: bfloat16 + +# For demonstration purposes in reconstruct_images.ipynb +webdataset_path: file:inference/imagenet_1024.tar +batch_size: 4 +image_size: 1024 
+grad_accum_steps: 1 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +stage_a_checkpoint_path: models/stage_a.safetensors +generator_checkpoint_path: models/stage_b_bf16.safetensors +""" + + +class StageB(nn.Module): + def __init__( + self, + c_in=4, + c_out=4, + c_r=64, + patch_size=2, + c_cond=1280, + c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], + blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], + level_config=["CT", "CT", "CTA", "CTA"], + c_clip=1280, + c_clip_seq=4, + c_effnet=16, + c_pixels=3, + kernel_size=3, + dropout=[0, 0, 0.1, 0.1], + self_attn=True, + t_conds=["sca"], + ): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + self.pixels_mapper = nn.Sequential( + nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == "C": + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "F": + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append( + nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + ) + ) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append( + nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 
1], kernel_size=2, stride=2), + ) + ) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) + ): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) + ): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) + ): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return 
level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) + ): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) + ): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) + ): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True) + ) + x = x + nn.functional.interpolate( + self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True + ) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) + + +r""" + +https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_c_3b.yaml + +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +""" + + +class StageC(nn.Module): + def __init__( + self, + c_in=16, + c_out=16, + c_r=64, + patch_size=1, + c_cond=2048, + c_hidden=[2048, 2048], + nhead=[32, 32], + blocks=[[8, 24], [24, 8]], + block_repeat=[[1, 1], [1, 1]], + level_config=["CTA", "CTA"], + c_clip_text=1280, + c_clip_text_pooled=1280, + c_clip_img=768, + c_clip_seq=4, + kernel_size=3, + dropout=[0.1, 0.1], + self_attn=True, + t_conds=["sca", "crp"], + switch_level=[False], + ): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) + self.clip_txt_pooled_mapper = 
nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) + self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == "C": + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "F": + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append( + nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode="down", enabled=switch_level[i - 1]), + ) + ) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append( + nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode="up", enabled=switch_level[i - 1]), + ) + ) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + 
nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def set_gradient_checkpointing(self, value): + for block in self.down_blocks + self.up_blocks: + for layer in block: + if hasattr(layer, "set_gradient_checkpointing"): + layer.set_gradient_checkpointing(value) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( + clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1 + ) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) + ): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) + ): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) + ): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) + ): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = 
torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) + ): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) + ): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + # ControlNet is not supported yet + # if cnet is not None: + # cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +# Fast Decoder for Stage C latents. E.g. 
16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return self.blocks(x) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model): + # deprecated + + # self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None + # is_eval の処理をここでやるのは微妙なので別のところでやる + # is_unconditional もここでやるのは微妙なので別のところでやる + # clip_image はとりあえずサポートしない + if captions is not None: + clip_tokens_unpooled = tokenizer( + captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + ).to(text_model.device) + text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True) + else: + text_encoder_output = text_model(input_ids, output_hidden_states=True) + + text_embeddings = text_encoder_output.hidden_states[-1] + text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1) + + return text_embeddings, text_pooled_embeddings + # return {"clip_text": text_embeddings, "clip_text_pooled": text_pooled_embeddings} # , "clip_img": image_embeddings} + + +# region gdf + + +class SimpleSampler: + def __init__(self, gdf): + self.gdf = gdf + self.current_step = -1 + + def __call__(self, *args, **kwargs): + self.current_step += 1 + return self.step(*args, **kwargs) + + def init_x(self, shape): + return torch.randn(*shape) + + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + raise NotImplementedError("You should override the 'apply' function.") + + +class DDIMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): + a, b = self.gdf.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1)) + + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1)) + + sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 + # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + return x 
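+
+# Note on the samplers (descriptive comments, a sketch only): with VPScaler (defined later in
+# this file) the input scaler returns a = sqrt(sigmoid(logSNR)) and b = sqrt(1 - a^2), so a
+# noised sample is x = a * x0 + b * epsilon.  DDIMSampler.step above then rebuilds
+#     x_prev = a_prev * x0 + sqrt(b_prev^2 - sigma_tau^2) * epsilon + sigma_tau * noise
+# with sigma_tau = eta * sqrt(b_prev^2 / b^2) * sqrt(1 - a^2 / a_prev^2); eta = 0 gives the
+# deterministic DDIM update, DDPMSampler below calls the same step with eta = 1, and
+# LCMSampler discards epsilon and re-noises x0 directly.
+#
+# A minimal commented-out usage sketch. Names such as stage_c, conditions, unconditions and
+# latent_shape are placeholders; the gdf.sample arguments follow sample_image_inference in
+# stable_cascade_utils.py, while the component combination passed to GDF is one plausible
+# choice from the classes defined below, not necessarily the training script's exact setup:
+#
+#   gdf = GDF(schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), input_scaler=VPScaler(),
+#             target=EpsilonTarget(), noise_cond=CosineTNoiseCond(), loss_weight=P2LossWeight())
+#   for x0, x, pred in gdf.sample(stage_c, conditions, latent_shape, unconditions,
+#                                 device=device, cfg=4.0, shift=2, timesteps=20, t_start=1.0):
+#       pass  # after the last step, x0 is the sampled Stage C latent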
+ + +class DDPMSampler(DDIMSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): + return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) + + +class LCMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1)) + return x0 * a_prev + torch.randn_like(epsilon) * b_prev + + +class GDF: + def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0): + self.schedule = schedule + self.input_scaler = input_scaler + self.target = target + self.noise_cond = noise_cond + self.loss_weight = loss_weight + self.offset_noise = offset_noise + + def setup_limits(self, stretch_max=True, stretch_min=True, shift=1): + stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift) + return stretched_limits + + def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None): + if epsilon is None: + epsilon = torch.randn_like(x0) + if self.offset_noise > 0: + if offset is None: + offset = torch.randn([x0.size(0), x0.size(1)] + [1] * (len(x0.shape) - 2)).to(x0.device) + epsilon = epsilon + offset * self.offset_noise + logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device) + a, b = self.input_scaler(logSNR) # B + if len(a.shape) == 1: + a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1)) # BxCxHxW + target = self.target(x0, epsilon, logSNR, a, b) + + # noised, noise, logSNR, t_cond + return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift) + + def undiffuse(self, x, logSNR, pred): + a, b = self.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1] * (len(x.shape) - 1)), b.view(-1, *[1] * (len(x.shape) - 1)) + return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b) + + def sample( + self, + model, + model_inputs, + shape, + unconditional_inputs=None, + sampler=None, + schedule=None, + t_start=1.0, + t_end=0.0, + timesteps=20, + x_init=None, + cfg=3.0, + cfg_t_stop=None, + cfg_t_start=None, + cfg_rho=0.7, + sampler_params=None, + shift=1, + device="cpu", + ): + sampler_params = {} if sampler_params is None else sampler_params + if sampler is None: + sampler = DDPMSampler(self) + r_range = torch.linspace(t_start, t_end, timesteps + 1) + schedule = self.schedule if schedule is None else schedule + logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(device) + + x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = { + k: ( + torch.cat([v, v_u], dim=0) + if isinstance(v, torch.Tensor) + else ( + [ + ( + torch.cat([vi, vi_u], dim=0) + if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) + else None + ) + for vi, vi_u in zip(v, v_u) + ] + if isinstance(v, list) + else ( + {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} + if isinstance(v, dict) + else None + ) + ) + ) + for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) + } + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) 
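+            # Classifier-free guidance (descriptive note): when CFG is active for this timestep,
+            # the conditional and unconditional inputs were concatenated above, so one batched
+            # forward pass yields both predictions; they are combined with
+            # torch.lerp(pred_unconditional, pred, cfg), and if cfg_rho > 0 the guided prediction
+            # is rescaled so its std partially (by cfg_rho) matches that of the conditional
+            # prediction.  A list/tuple cfg is interpolated linearly from cfg[0] at t = 1 to
+            # cfg[1] at t = 0.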
+ if ( + cfg is not None + and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) + and (cfg_t_start is None or r_range[i].item() <= cfg_t_start) + ): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1 - r_range[i].item()) + pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2) + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos / (std_cfg + 1e-9)) + pred_cfg * (1 - cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x, noise_cond, **model_inputs) + x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) + x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i + 1], **sampler_params) + altered_vars = yield (x0, x, pred) + + # Update some running variables if the user wants + if altered_vars is not None: + cfg = altered_vars.get("cfg", cfg) + cfg_rho = altered_vars.get("cfg_rho", cfg_rho) + sampler = altered_vars.get("sampler", sampler) + model_inputs = altered_vars.get("model_inputs", model_inputs) + x = altered_vars.get("x", x) + x_init = altered_vars.get("x_init", x_init) + + +class BaseSchedule: + def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): + self.setup(*args, **kwargs) + self.limits = None + self.discrete_steps = discrete_steps + self.shift = shift + if force_limits: + self.reset_limits() + + def reset_limits(self, shift=1, disable=False): + try: + self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max + return self.limits + except Exception: + print("WARNING: this schedule doesn't support t and will be unbounded") + return None + + def setup(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overridden") + + def schedule(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, t, *args, shift=1, **kwargs): + if isinstance(t, torch.Tensor): + batch_size = None + if self.discrete_steps is not None: + if t.dtype != torch.long: + t = (t * (self.discrete_steps - 1)).round().long() + t = t / (self.discrete_steps - 1) + t = t.clamp(0, 1) + else: + batch_size = t + t = None + logSNR = self.schedule(t, batch_size, *args, **kwargs) + if shift * self.shift != 1: + logSNR += 2 * np.log(1 / (shift * self.shift)) + if self.limits is not None: + logSNR = logSNR.clamp(*self.limits) + return logSNR + + +class CosineSchedule(BaseSchedule): + def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.norm_instead = norm_instead + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def schedule(self, t, batch_size): + if t is None: + t = (1 - torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) + s, min_var = self.s.to(t.device), self.min_var.to(t.device) + var = torch.cos((s + t) / (1 + s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var + if self.norm_instead: + var = var * (self.clamp_range[1] - self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var / (1 - var)).log() + return logSNR + + +class BaseScaler: + def __init__(self): + self.stretched_limits = None + + def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): + min_logSNR = 
schedule(torch.ones(1), shift=shift) + max_logSNR = schedule(torch.zeros(1), shift=shift) + + min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] + max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] + self.stretched_limits = [min_a, max_a, min_b, max_b] + return self.stretched_limits + + def stretch_limits(self, a, b): + min_a, max_a, min_b, max_b = self.stretched_limits + return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b) + + def scalers(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR): + a, b = self.scalers(logSNR) + if self.stretched_limits is not None: + a, b = self.stretch_limits(a, b) + return a, b + + +class VPScaler(BaseScaler): + def scalers(self, logSNR): + a_squared = logSNR.sigmoid() + a = a_squared.sqrt() + b = (1 - a_squared).sqrt() + return a, b + + +class EpsilonTarget: + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon + + def x0(self, noised, pred, logSNR, a, b): + return (noised - pred * b) / a + + def epsilon(self, noised, pred, logSNR, a, b): + return pred + + +class BaseNoiseCond: + def __init__(self, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + self.shift = shift + self.clamp_range = clamp_range + self.setup(*args, **kwargs) + + def setup(self, *args, **kwargs): + pass # this method is optional, override it if required + + def cond(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR): + if self.shift != 1: + logSNR = logSNR.clone() + 2 * np.log(self.shift) + return self.cond(logSNR).clamp(*self.clamp_range) + + +class CosineTNoiseCond(BaseNoiseCond): + def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def cond(self, logSNR): + var = logSNR.sigmoid() + var = var.clamp(*self.clamp_range) + s, min_var = self.s.to(var.device), self.min_var.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t + + +# --- Loss Weighting +class BaseLossWeight: + def weight(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + if shift != 1: + logSNR = logSNR.clone() + 2 * np.log(shift) + return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) + + +# class ComposedLossWeight(BaseLossWeight): +# def __init__(self, div, mul): +# self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul +# self.div = [div] if isinstance(div, BaseLossWeight) else div + +# def weight(self, logSNR): +# prod, div = 1, 1 +# for m in self.mul: +# prod *= m.weight(logSNR) +# for d in self.div: +# div *= d.weight(logSNR) +# return prod/div + +# class ConstantLossWeight(BaseLossWeight): +# def __init__(self, v=1): +# self.v = v + +# def weight(self, logSNR): +# return torch.ones_like(logSNR) * self.v + +# class SNRLossWeight(BaseLossWeight): +# def weight(self, logSNR): +# return logSNR.exp() + + +class P2LossWeight(BaseLossWeight): + def __init__(self, k=1.0, gamma=1.0, s=1.0): + self.k, self.gamma, self.s = k, gamma, s + + def weight(self, logSNR): + return (self.k + (logSNR * self.s).exp()) ** -self.gamma + + +# class 
SNRPlusOneLossWeight(BaseLossWeight): +# def weight(self, logSNR): +# return logSNR.exp() + 1 + +# class MinSNRLossWeight(BaseLossWeight): +# def __init__(self, max_snr=5): +# self.max_snr = max_snr + +# def weight(self, logSNR): +# return logSNR.exp().clamp(max=self.max_snr) + +# class MinSNRPlusOneLossWeight(BaseLossWeight): +# def __init__(self, max_snr=5): +# self.max_snr = max_snr + +# def weight(self, logSNR): +# return (logSNR.exp() + 1).clamp(max=self.max_snr) + +# class TruncatedSNRLossWeight(BaseLossWeight): +# def __init__(self, min_snr=1): +# self.min_snr = min_snr + +# def weight(self, logSNR): +# return logSNR.exp().clamp(min=self.min_snr) + +# class SechLossWeight(BaseLossWeight): +# def __init__(self, div=2): +# self.div = div + +# def weight(self, logSNR): +# return 1/(logSNR/self.div).cosh() + +# class DebiasedLossWeight(BaseLossWeight): +# def weight(self, logSNR): +# return 1/logSNR.exp().sqrt() + +# class SigmoidLossWeight(BaseLossWeight): +# def __init__(self, s=1): +# self.s = s + +# def weight(self, logSNR): +# return (logSNR * self.s).sigmoid() + + +class AdaptiveLossWeight(BaseLossWeight): + def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): + self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1) + self.bucket_losses = torch.ones(buckets) + self.weight_range = weight_range + + def weight(self, logSNR): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) + return (1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) + + def update_buckets(self, logSNR, loss, beta=0.99): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() + self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta) + + +# endregion gdf diff --git a/library/stable_cascade_utils.py b/library/stable_cascade_utils.py new file mode 100644 index 000000000..fe2804196 --- /dev/null +++ b/library/stable_cascade_utils.py @@ -0,0 +1,615 @@ +import argparse +import json +import math +import os +import time +from typing import List +import numpy as np +import toml + +import torch +import torchvision +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig +from accelerate import init_empty_weights, Accelerator, PartialState +from PIL import Image + +from library import stable_cascade as sc + +from library.sdxl_model_util import _load_state_dict_on_device +from library.device_utils import clean_memory_on_device +from library.train_util import ( + save_sd_model_on_epoch_end_or_stepwise_common, + save_sd_model_on_train_end_common, + line_to_prompt_dict, + get_hidden_states_stable_cascade, +) +from library import sai_model_spec + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +EFFNET_PREPROCESS = torchvision.transforms.Compose( + [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))] +) + +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz" + + +def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): + resolution_multiple = 42.67 + latent_height = math.ceil(height / compression_factor_b) + latent_width = math.ceil(width / compression_factor_b) + 
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) + + latent_height = math.ceil(height / compression_factor_a) + latent_width = math.ceil(width / compression_factor_a) + stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) + + return stage_c_latent_shape, stage_b_latent_shape + + +# region load and save + + +def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder: + logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}") + effnet = sc.EfficientNetEncoder() + effnet_checkpoint = load_file(effnet_checkpoint_path) + info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"]) + logger.info(info) + del effnet_checkpoint + return effnet + + +def load_tokenizer(args: argparse.Namespace): + # TODO commonize with sdxl_train_util.load_tokenizers + logger.info("prepare tokenizers") + + original_paths = [CLIP_TEXT_MODEL_NAME] + tokenizers = [] + for i, original_path in enumerate(original_paths): + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) + + if tokenizer is None: + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + tokenizers.append(tokenizer) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + logger.info(f"update token length: {args.max_token_length}") + + return tokenizers[0] + + +def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC: + # Generator + logger.info(f"Instantiating Stage C generator") + with init_empty_weights(): + generator_c = sc.StageC() + logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}") + stage_c_checkpoint = load_file(stage_c_checkpoint_path) + logger.info(f"Loading state dict") + info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype) + logger.info(info) + return generator_c + + +def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB: + logger.info(f"Instantiating Stage B generator") + with init_empty_weights(): + generator_b = sc.StageB() + logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}") + stage_b_checkpoint = load_file(stage_b_checkpoint_path) + logger.info(f"Loading state dict") + info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype) + logger.info(info) + return generator_b + + +def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False): + # CLIP encoders + logger.info(f"Loading CLIP text model") + if save_text_model or text_model_checkpoint_path is None: + logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}") + text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME) + + if save_text_model: + sd = text_model.state_dict() + logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}") + save_file(sd, text_model_checkpoint_path) + else: + logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}") + + # copy from sdxl_model_util.py 
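+        # Descriptive note: the config below mirrors SDXL's second text encoder
+        # (OpenCLIP ViT-bigG/14: hidden_size 1280, 32 layers, 20 heads, projection_dim 1280),
+        # i.e. the same text encoder Stage C is conditioned on, so a locally saved checkpoint
+        # can be loaded without fetching CLIP-ViT-bigG-14-laion2B-39B-b160k again.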
+ text_model2_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + text_model = CLIPTextModelWithProjection(text_model2_cfg) + + text_model_checkpoint = load_file(text_model_checkpoint_path) + info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype) + logger.info(info) + + return text_model + + +def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA: + logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}") + stage_a = sc.StageA().to(device) + stage_a_checkpoint = load_file(stage_a_checkpoint_path) + info = stage_a.load_state_dict( + stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"] + ) + logger.info(info) + return stage_a + + +def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer: + logger.info(f"Loading Previewer from {previewer_checkpoint_path}") + previewer = sc.Previewer().to(device) + previewer_checkpoint = load_file(previewer_checkpoint_path) + info = previewer.load_state_dict( + previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"] + ) + logger.info(info) + return previewer + + +def get_sai_model_spec(args): + timestamp = time.time() + + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + metadata = sai_model_spec.build_metadata( + None, + False, + False, + False, + False, + False, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=False, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, # None or int + stable_cascade=True, + ) + return metadata + + +def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata): + state_dict = stage_c.state_dict() + if save_dtype is not None: + state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()} + + save_file(state_dict, ckpt_file, metadata=sai_metadata) + + # save text model + if text_model is not None: + text_model_sd = text_model.state_dict() + + if save_dtype is not None: + text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()} + + text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors" + save_file(text_model_sd, text_model_ckpt_file) + + +def save_stage_c_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + stage_c, + text_model, +): + def stage_c_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(args) + stage_c_saver_common(ckpt_file, 
stage_c, text_model, save_dtype, sai_metadata) + + save_sd_model_on_epoch_end_or_stepwise_common( + args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None + ) + + +def save_stage_c_model_on_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + stage_c, + text_model, +): + def stage_c_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(args) + stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata) + + save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None) + + +# endregion + +# region sample generation + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + previewer, + tokenizer, + text_encoder, + stage_c, + gdf, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + stage_c = accelerator.unwrap_model(stage_c) + text_encoder = accelerator.unwrap_model(text_encoder) + + # read prompts + if args.sample_prompts.endswith(".txt"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif args.sample_prompts.endswith(".toml"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif args.sample_prompts.endswith(".json"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + prompts = json.load(f) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. 
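+        # Descriptive note: with a single process the prompts are simply rendered one by one
+        # under torch.no_grad(); the else branch below slices the prompt list round-robin
+        # (prompts[i::num_processes]) and hands each slice to one process via
+        # distributed_state.split_between_processes, so every GPU renders its own subset.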
+ with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + tokenizer, + text_encoder, + stage_c, + previewer, + gdf, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + tokenizer, + text_encoder, + stage_c, + previewer, + gdf, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, + ) + + # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. + # with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + tokenizer, + text_model, + stage_c, + previewer, + gdf, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 1024) + height = prompt_dict.get("height", 1024) + scale = prompt_dict.get("scale", 4) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + negative_prompt = "" if negative_prompt is None else negative_prompt + cfg = scale + timesteps = sample_steps + shift = 2 + t_start = 1.0 + + stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1) + + # PREPARE CONDITIONS + input_ids = tokenizer( + [prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + )["input_ids"].to(text_model.device) + cond_text, cond_pooled = 
get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model) + + input_ids = tokenizer( + [negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + )["input_ids"].to(text_model.device) + uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model) + + device = accelerator.device + dtype = stage_c.dtype + cond_text = cond_text.to(device, dtype=dtype) + cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype) + + uncond_text = uncond_text.to(device, dtype=dtype) + uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype) + + zero_img_emb = torch.zeros(1, 768, device=device) + + # 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく + conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb} + unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb} + + with torch.no_grad(): # , torch.cuda.amp.autocast(dtype=dtype): + sampling_c = gdf.sample( + stage_c, + conditions, + stage_c_latent_shape, + unconditions, + device=device, + cfg=cfg, + shift=shift, + timesteps=timesteps, + t_start=t_start, + ) + for sampled_c, _, _ in tqdm(sampling_c, total=timesteps): + sampled_c = sampled_c + + sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype) + image = previewer(sampled_c)[0] + image = torch.clamp(image, 0, 1) + image = image.cpu().numpy().transpose(1, 2, 0) + image = image * 255 + image = image.astype(np.uint8) + image = Image.fromarray(image) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + +# endregion + + +def add_effnet_arguments(parser): + parser.add_argument( + "--effnet_checkpoint_path", + type=str, + required=True, + help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス", + ) + return parser + + +def add_text_model_arguments(parser): + parser.add_argument( + "--text_model_checkpoint_path", + type=str, + help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス", + ) + parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path") + return parser + + +def add_stage_a_arguments(parser): + parser.add_argument( + "--stage_a_checkpoint_path", + type=str, + required=True, + help="path to Stage A checkpoint / Stage Aのチェックポイントのパス", + ) + return parser + + +def add_stage_b_arguments(parser): + parser.add_argument( + "--stage_b_checkpoint_path", + type=str, + required=True, + help="path to Stage B checkpoint / Stage Bのチェックポイントのパス", + ) + return parser + + +def add_stage_c_arguments(parser): + 
parser.add_argument( + "--stage_c_checkpoint_path", + type=str, + required=True, + help="path to Stage C checkpoint / Stage Cのチェックポイントのパス", + ) + return parser + + +def add_previewer_arguments(parser): + parser.add_argument( + "--previewer_checkpoint_path", + type=str, + required=False, + help="path to previewer checkpoint / previewerのチェックポイントのパス", + ) + return parser + + +def add_training_arguments(parser): + parser.add_argument( + "--adaptive_loss_weight", + action="store_true", + help="if specified, use adaptive loss weight. if not, use P2 loss weight" + + " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する", + ) diff --git a/library/train_util.py b/library/train_util.py index ba428e508..ff08d5f8d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6,6 +6,7 @@ import datetime import importlib import json +import logging import pathlib import re import shutil @@ -19,8 +20,7 @@ Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs -import gc +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math import os @@ -31,7 +31,12 @@ import toml from tqdm import tqdm + import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms @@ -64,7 +69,12 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +from library.utils import setup_logging + +setup_logging() +import logging +logger = logging.getLogger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel @@ -73,6 +83,8 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +HIGH_VRAM = False + # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" EPOCH_FILE_NAME = "{}-{:06d}" @@ -121,6 +133,7 @@ ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz" class ImageInfo: @@ -211,7 +224,7 @@ def add_if_new_reso(self, reso): self.reso_to_id[reso] = bucket_id self.resos.append(reso) self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) + # logger.info(reso, bucket_id, len(self.buckets)) def round_to_steps(self, x): x = int(x + 0.5) @@ -237,7 +250,7 @@ def select_bucket(self, image_width, image_height): scale = reso[0] / image_width resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) - # print("use predef", image_width, image_height, reso, resized_size) + # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") else: # 縮小のみを行う if image_width * image_height > self.max_area: @@ -256,21 +269,21 @@ def select_bucket(self, image_width, image_height): b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) ar_height_rounded = b_width_in_hr / b_height_rounded - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) + # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - 
aspect_ratio): resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) else: resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) - # print(resized_size) + # logger.info(resized_size) else: resized_size = (image_width, image_height) # リサイズは不要 # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) bucket_width = resized_size[0] - resized_size[0] % self.reso_steps bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + # logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}") reso = (bucket_width, bucket_height) @@ -779,15 +792,15 @@ def make_buckets(self): bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) min_size and max_size are ignored when enable_bucket is False """ - print("loading image sizes.") + logger.info("loading image sizes.") for info in tqdm(self.image_data.values()): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) if self.enable_bucket: - print("make buckets") + logger.info("make buckets") else: - print("prepare dataset") + logger.info("prepare dataset") # bucketを作成し、画像をbucketに振り分ける if self.enable_bucket: @@ -802,7 +815,7 @@ def make_buckets(self): if not self.bucket_no_upscale: self.bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -813,7 +826,7 @@ def make_buckets(self): image_width, image_height ) - # print(image_info.image_key, image_info.bucket_reso) + # logger.info(image_info.image_key, image_info.bucket_reso) img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() @@ -831,20 +844,20 @@ def make_buckets(self): # bucket情報を表示、格納する if self.enable_bucket: self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") img_ar_errors = np.array(img_ar_errors) mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") + logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -861,7 +874,7 @@ def make_buckets(self): # num_of_image_types = len(set(bucket)) # bucket_batch_size = min(self.batch_size, num_of_image_types) # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) # for batch_index in range(batch_count): # 
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) # ↑ここまで @@ -898,9 +911,9 @@ def is_text_encoder_output_cacheable(self): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching latents.") + logger.info("caching latents.") image_infos = list(self.image_data.values()) @@ -910,7 +923,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # split by resolution batches = [] batch = [] - print("checking cache validity...") + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -919,11 +932,11 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix if not is_main_process: # store to info only continue - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug, divisor) if cache_available: # do not add to batch continue @@ -947,7 +960,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - print("caching latents...") + logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) @@ -955,21 +968,25 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix ): - assert len(tokenizers) == 2, "only support SDXL" + """ + 最後の Text Encoder の pool がキャッシュされる。 + The last Text Encoder's pool is cached. 
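+        Text Encoder が 1 つだけ渡された場合(Stable Cascade)、input_ids2 / hidden_state2 は None になる。
+        When only one Text Encoder is passed (Stable Cascade), input_ids2 / hidden_state2 are set to None.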
+ """ + # assert len(tokenizers) == 2, "only support SDXL" # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching text encoder outputs.") + logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) - print("checking cache existence...") + logger.info("checking cache existence...") image_infos_to_cache = [] for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only @@ -994,7 +1011,7 @@ def cache_text_encoder_outputs( batches = [] for info in image_infos_to_cache: input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) if len(tokenizers) > 1 else None batch.append((info, input_ids1, input_ids2)) if len(batch) >= self.batch_size: @@ -1005,11 +1022,11 @@ def cache_text_encoder_outputs( batches.append(batch) # iterate batches: call text encoder and cache outputs for memory or disk - print("caching text encoder outputs...") + logger.info("caching text encoder outputs...") for batch in tqdm(batches): infos, input_ids1, input_ids2 = zip(*batch) input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None cache_batch_text_encoder_outputs( infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype ) @@ -1258,7 +1275,9 @@ def __getitem__(self, index): # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) + example["text_encoder_outputs2_list"] = ( + torch.stack(text_encoder_outputs2_list) if text_encoder_outputs2_list[0] is not None else None + ) example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) if images[0] is not None: @@ -1315,7 +1334,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): if self.caching_mode == "text": input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) - input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) + input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None else: input_ids1 = None input_ids2 = None @@ -1404,7 +1423,7 @@ def read_caption(img_path, caption_extension): try: lines = f.readlines() except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() @@ -1413,11 +1432,11 @@ def read_caption(img_path, caption_extension): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") + logger.warning(f"not directory: {subset.image_dir}") return [], [] img_paths = 
glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] @@ -1425,7 +1444,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: - print( + logger.warning( f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" ) captions.append("") @@ -1444,36 +1463,38 @@ def load_dreambooth_dir(subset: DreamBoothSubset): number_of_missing_captions_to_show = 5 remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - print( + logger.warning( f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" ) for i, missing_caption in enumerate(missing_captions): if i >= number_of_missing_captions_to_show: - print(missing_caption + f"... and {remaining_missing_captions} more") + logger.warning(missing_caption + f"... and {remaining_missing_captions} more") break - print(missing_caption) + logger.warning(missing_caption) return img_paths, captions - print("prepare images.") + logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 reg_infos: List[ImageInfo] = [] for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue img_paths, captions = load_dreambooth_dir(subset) if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" + ) continue if subset.is_reg: @@ -1491,15 +1512,15 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - print(f"{num_train_images} train images with repeating.") + logger.info(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images - print(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") + logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 @@ -1544,27 +1565,29 @@ def __init__( for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with 
metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue # メタデータを読み込む if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") + logger.info(f"loading existing metadata: {subset.metadata_file}") with open(subset.metadata_file, "rt", encoding="utf-8") as f: metadata = json.load(f) else: raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します" + ) continue tags_list = [] @@ -1583,12 +1606,15 @@ def __init__( # なければnpzを探す if abs_path is None: - if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" - else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path + abs_path = os.path.splitext(image_key)[0] + ".npz" + if not os.path.exists(abs_path): + abs_path = os.path.splitext(image_key)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX + if not os.path.exists(abs_path): + abs_path = os.path.join(subset.image_dir, image_key + ".npz") + if not os.path.exists(abs_path): + abs_path = os.path.join(subset.image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX) + if not os.path.exists(abs_path): + abs_path = None assert abs_path is not None, f"no image / 画像がありません: {image_key}" @@ -1608,7 +1634,7 @@ def __init__( if not subset.color_aug and not subset.random_crop: # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + image_info.latents_npz = self.image_key_to_npz_file(subset, image_key) self.register_image(image_info, subset) @@ -1622,7 +1648,7 @@ def __init__( # check existence of all npz files use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) if use_npz_latents: - flip_aug_in_subset = False + # flip_aug_in_subset = False npz_any = False npz_all = True @@ -1632,9 +1658,12 @@ def __init__( has_npz = image_info.latents_npz is not None npz_any = npz_any or has_npz - if subset.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - flip_aug_in_subset = True + # flip は同一の .npz 内に格納するようにした: + # そのためここでチェック漏れがあり実行時にエラーになる可能性があるので要検討 + # if subset.flip_aug: + # has_npz = has_npz and image_info.latents_npz_flipped is not None + # flip_aug_in_subset = True + npz_all = npz_all and has_npz if npz_any and not npz_all: @@ -1642,14 +1671,16 @@ def __init__( if not npz_any: use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") - if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + logger.warning( + f"some of npz file does not exist. 
ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" + ) + # if flip_aug_in_subset: + # logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") # check min/max bucket size sizes = set() @@ -1665,7 +1696,9 @@ def __init__( if sizes is None: if use_npz_latents: use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.warning( + f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" + ) assert ( resolution is not None @@ -1679,8 +1712,8 @@ def __init__( self.bucket_no_upscale = bucket_no_upscale else: if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") + logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") self.enable_bucket = True assert ( @@ -1694,34 +1727,29 @@ def __init__( # npz情報をきれいにしておく if not use_npz_latents: for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None + image_info.latents_npz = None # image_info.latents_npz_flipped = def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + ".npz" + npz_file_norm = base_name + ".npz" + if not os.path.exists(npz_file_norm): + npz_file_norm = base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + "_flip.npz" - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip + return npz_file_norm # if not full path, check image_dir. 
if image_dir is None, return None if subset.image_dir is None: - return None, None + return None # image_key is relative path npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") - npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None + npz_file_norm = os.path.join(subset.image_dir, base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX) + if os.path.exists(npz_file_norm): + return npz_file_norm - return npz_file_norm, npz_file_flip + return None class ControlNetDataset(BaseDataset): @@ -1803,7 +1831,7 @@ def __init__( assert subset is not None, "internal error: subset not found" if not os.path.isdir(subset.conditioning_data_dir): - print(f"not directory: {subset.conditioning_data_dir}") + logger.warning(f"not directory: {subset.conditioning_data_dir}") continue img_basename = os.path.basename(info.absolute_path) @@ -1861,7 +1889,9 @@ def __getitem__(self, index): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + cond_img = cv2.resize( + cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + ) # INTER_AREAでやりたいのでcv2でリサイズ # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -1921,17 +1951,26 @@ def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + logger.info(f"[Dataset {i}]") + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, + tokenizers, + text_encoders, + device, + weight_dtype, + cache_to_disk=False, + is_main_process=True, + cache_file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, ): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") - dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs( + tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix + ) def set_caching_mode(self, caching_mode): for dataset in self.datasets: @@ -1964,8 +2003,8 @@ def disable_token_padding(self): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): - expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, divisor: int = 8) -> bool: + expected_latents_size = (reso[1] // divisor, reso[0] // divisor) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): return False @@ -2014,12 +2053,15 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): - 
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info( + "`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します" + ) epoch = 1 while True: - print(f"\nepoch: {epoch}") + logger.info(f"") + logger.info(f"epoch: {epoch}") steps = (epoch - 1) * len(train_dataset) + 1 indices = list(range(len(train_dataset))) @@ -2029,11 +2071,11 @@ def debug_dataset(train_dataset, show_input_ids=False): for i, idx in enumerate(indices): train_dataset.set_current_epoch(epoch) train_dataset.set_current_step(steps) - print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") example = train_dataset[idx] if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") + logger.info(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], @@ -2046,26 +2088,26 @@ def debug_dataset(train_dataset, show_input_ids=False): example["flippeds"], ) ): - print( + logger.info( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: - print(f"input ids: {iid}") - if "input_ids2" in example: - print(f"input ids2: {example['input_ids2'][j]}") + logger.info(f"input ids: {iid}") + if "input_ids2" in example and example["input_ids2"] is not None: + logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] - print(f"image size: {im.size()}") + logger.info(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) if "conditioning_images" in example: cond_img = example["conditioning_images"][j] - print(f"conditioning image size: {cond_img.size()}") + logger.info(f"conditioning image size: {cond_img.size()}") cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) cond_img = np.transpose(cond_img, (1, 2, 0)) cond_img = cond_img[:, :, ::-1] @@ -2213,12 +2255,12 @@ def trim_and_resize_if_required( if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) + # logger.info(f"w {trim_size} {p}") image = image[:, p : p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) + # logger.info(f"h {trim_size} {p}) image = image[p : p + reso[1]] # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない @@ -2231,7 +2273,7 @@ def trim_and_resize_if_required( def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: Union[AutoencoderKL, torch.nn.Module], cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2278,32 
+2320,44 @@ def cache_batch_latents( if flip_aug: info.latents_flipped = flipped_latent - # FIXME this slows down caching a lot, specify this as an option - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if not HIGH_VRAM: + clean_memory_on_device(vae.device) def cache_batch_text_encoder_outputs( image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype ): input_ids1 = input_ids1.to(text_encoders[0].device) - input_ids2 = input_ids2.to(text_encoders[1].device) + input_ids2 = input_ids2.to(text_encoders[1].device) if input_ids2 is not None else None with torch.no_grad(): - b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( - max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - dtype, - ) + # TODO SDXL と Stable Cascade で統一する + if len(tokenizers) == 1: + # Stable Cascade + b_hidden_state1, b_pool2 = get_hidden_states_stable_cascade( + max_token_length, input_ids1, tokenizers[0], text_encoders[0], dtype + ) + + b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768 + b_pool2 = b_pool2.detach().to("cpu") # b,1280 + + b_hidden_state2 = [None] * input_ids1.shape[0] + else: + # SDXL + b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( + max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + dtype, + ) # ここでcpuに移動しておかないと、上書きされてしまう b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768 - b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280 + b_hidden_state2 = b_hidden_state2.detach().to("cpu") if b_hidden_state2[0] is not None else b_hidden_state2 # b,n*75+2,1280 b_pool2 = b_pool2.detach().to("cpu") # b,1280 for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2): @@ -2316,18 +2370,25 @@ def cache_batch_text_encoder_outputs( def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): - np.savez( - npz_path, - hidden_state1=hidden_state1.cpu().float().numpy(), - hidden_state2=hidden_state2.cpu().float().numpy(), - pool2=pool2.cpu().float().numpy(), - ) + save_kwargs = { + "hidden_state1": hidden_state1.cpu().float().numpy(), + "pool2": pool2.cpu().float().numpy(), + } + if hidden_state2 is not None: + save_kwargs["hidden_state2"] = hidden_state2.cpu().float().numpy() + np.savez(npz_path, **save_kwargs) + # np.savez( + # npz_path, + # hidden_state1=hidden_state1.cpu().float().numpy(), + # hidden_state2=hidden_state2.cpu().float().numpy() if hidden_state2 is not None else None, + # pool2=pool2.cpu().float().numpy(), + # ) def load_text_encoder_outputs_from_disk(npz_path): with np.load(npz_path) as f: hidden_state1 = torch.from_numpy(f["hidden_state1"]) - hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None + hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f and f["hidden_state2"] is not None else None pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None return hidden_state1, hidden_state2, pool2 @@ -2456,7 +2517,7 @@ def get_git_revision_hash() -> str: # def replace_unet_cross_attn_to_xformers(): -# print("CrossAttention.forward has been replaced to enable xformers.") +# logger.info("CrossAttention.forward has been replaced to enable xformers.") # try: # import xformers.ops # except ImportError: @@ -2499,10 +2560,10 @@ def get_git_revision_hash() -> str: # 
diffusers.models.attention.CrossAttention.forward = forward_xformers def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -2510,7 +2571,7 @@ def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdp unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_sdpa(True) @@ -2521,17 +2582,17 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform replace_vae_attn_to_memory_efficient() elif xformers: # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") + logger.info("Use Diffusers xformers for VAE") vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") + logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + logger.info("forward_flash_attn") q_bucket_size = 512 k_bucket_size = 1024 @@ -2674,9 +2735,20 @@ def get_sai_model_spec( return metadata +def add_tokenizer_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + + def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) @@ -2686,12 +2758,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル", ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) + add_tokenizer_arguments(parser) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -2716,7 +2783,10 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない" + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", ) parser.add_argument( @@ -2763,13 +2833,23 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: 
argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") parser.add_argument( - "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名" + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" ) parser.add_argument( - "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類" + "--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名" + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", ) parser.add_argument( "--huggingface_path_in_repo", @@ -2805,10 +2885,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="precision in saving / 保存時に精度を変更して保存する", ) parser.add_argument( - "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", ) parser.add_argument( - "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", ) parser.add_argument( "--save_n_epoch_ratio", @@ -2860,7 +2946,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) - parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") + parser.add_argument( + "--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う" + ) parser.add_argument( "--dynamo_backend", type=str, @@ -2878,7 +2966,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", ) parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") @@ -2910,7 +3001,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument("--full_fp16", action="store_true", help="fp16 training including 
gradients / 勾配も含めてfp16で学習する") parser.add_argument( @@ -2952,7 +3047,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=["tensorboard", "wandb", "all"], help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", ) - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) parser.add_argument( "--log_tracker_name", type=str, @@ -3035,13 +3132,24 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + parser.add_argument( + "--highvram", + action="store_true", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", ) parser.add_argument( - "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" ) - parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する") parser.add_argument( "--sample_every_n_epochs", type=int, @@ -3049,7 +3157,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", ) parser.add_argument( - "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル" + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", ) parser.add_argument( "--sample_sampler", @@ -3126,17 +3237,32 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + r""" + Verify training arguments. 
Also reflect highvram option to global variable + 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する + """ + if args.highvram: + print("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True - print( + logger.warning( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) + if not hasattr(args, "v_parameterization"): + # Stable Cascade: skip following checks + return + + if args.v_parameterization and not args.v2: + logger.warning( + "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" + ) + if args.v2 and args.clip_skip is not None: + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time # # Listを使って数えてもいいけど並べてしまえ # if args.noise_offset is not None and args.multires_noise_iterations is not None: @@ -3164,7 +3290,7 @@ def verify_training_args(args: argparse.Namespace): ) if args.zero_terminal_snr and not args.v_parameterization: - print( + logger.warning( f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected" + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) @@ -3174,8 +3300,12 @@ def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool ): # dataset common - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする") + parser.add_argument( + "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" + ) + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" @@ -3211,8 +3341,12 @@ def add_dataset_arguments( default=None, help="suffix for caption text / captionのテキストの末尾に付ける文字列", ) - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument( + "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" + ) + parser.add_argument( + "--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする" + ) parser.add_argument( "--face_crop_aug_range", type=str, @@ -3225,7 +3359,9 @@ def add_dataset_arguments( help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", ) parser.add_argument( - "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)" + "--debug_dataset", + action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)", ) parser.add_argument( "--resolution", @@ -3238,14 +3374,18 @@ def 
add_dataset_arguments( action="store_true", help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", ) - parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") + parser.add_argument( + "--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ" + ) parser.add_argument( "--cache_latents_to_disk", action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) parser.add_argument( - "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" + "--enable_bucket", + action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") @@ -3256,7 +3396,9 @@ def add_dataset_arguments( help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( @@ -3300,13 +3442,20 @@ def add_dataset_arguments( if support_dreambooth: # DreamBooth dataset - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + parser.add_argument( + "--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ" + ) if support_caption: # caption dataset - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") parser.add_argument( - "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数" + "--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル" + ) + parser.add_argument( + "--dataset_repeats", + type=int, + default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数", ) @@ -3334,7 +3483,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar if args.output_config: # check if config file exists if os.path.exists(config_path): - print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + logger.error(f"Config file already exists. Aborting... 
/ 出力先の設定ファイルが既に存在します: {config_path}") exit(1) # convert args to dictionary @@ -3362,14 +3511,14 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar with open(config_path, "w") as f: toml.dump(args_dict, f) - print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") exit(0) if not os.path.exists(config_path): - print(f"{config_path} not found.") + logger.info(f"{config_path} not found.") exit(1) - print(f"Loading settings from {config_path}...") + logger.info(f"Loading settings from {config_path}...") with open(config_path, "r") as f: config_dict = toml.load(f) @@ -3388,7 +3537,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - print(args.config_file) + logger.info(args.config_file) return args @@ -3403,11 +3552,11 @@ def resume_from_local_or_hf_if_specified(accelerator, args): return if not args.resume_from_huggingface: - print(f"resume training from local state: {args.resume}") + logger.info(f"resume training from local state: {args.resume}") accelerator.load_state(args.resume) return - print(f"resume training from huggingface state: {args.resume}") + logger.info(f"resume training from huggingface state: {args.resume}") repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] path_in_repo = "/".join(args.resume.split("/")[2:]) revision = None @@ -3419,7 +3568,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): repo_type = "model" else: path_in_repo, revision, repo_type = divided - print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") list_files = huggingface_util.list_dir( repo_id=repo_id, @@ -3444,7 +3593,9 @@ def task(): loop = asyncio.get_event_loop() results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) if len(results) == 0: - raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした") + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) dirname = os.path.dirname(results[0]) accelerator.load_state(dirname) @@ -3491,7 +3642,7 @@ def get_optimizer(args, trainable_params): # value = tuple(value) optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + # logger.info(f"optkwargs {optimizer}_{kwargs}") lr = args.learning_rate optimizer = None @@ -3501,7 +3652,7 @@ def get_optimizer(args, trainable_params): import lion_pytorch except ImportError: raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") + logger.info(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3512,14 +3663,14 @@ def get_optimizer(args, trainable_params): raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") if optimizer_type == "AdamW8bit".lower(): - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") optimizer_class = bnb.optim.AdamW8bit optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type 
== "SGDNesterov8bit".lower(): - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print( + logger.warning( f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" ) optimizer_kwargs["momentum"] = 0.9 @@ -3528,7 +3679,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.Lion8bit except AttributeError: @@ -3536,7 +3687,7 @@ def get_optimizer(args, trainable_params): "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedAdamW8bit except AttributeError: @@ -3544,7 +3695,7 @@ def get_optimizer(args, trainable_params): "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedLion8bit except AttributeError: @@ -3555,7 +3706,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): - print(f"use PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3569,7 +3720,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW32bit".lower(): - print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3583,16 +3734,18 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + logger.info( + f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) optimizer_kwargs["momentum"] = 0.9 optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): - # check lr and lr_count, and print warning + # check lr and lr_count, and logger.info warning actual_lr = lr lr_count = 1 if type(trainable_params) == list and 
type(trainable_params[0]) == dict: @@ -3603,12 +3756,12 @@ def get_optimizer(args, trainable_params): lr_count = len(lrs) if actual_lr <= 0.1: - print( + logger.warning( f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) - print("recommend option: lr=1.0 / 推奨は1.0です") + logger.warning("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: - print( + logger.warning( f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) @@ -3624,25 +3777,25 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdam".lower(): optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") @@ -3655,7 +3808,7 @@ def get_optimizer(args, trainable_params): except ImportError: raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - print(f"use Prodigy optimizer | {optimizer_kwargs}") + logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") optimizer_class = prodigyopt.Prodigy optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3664,14 +3817,16 @@ def get_optimizer(args, trainable_params): if "relative_step" not in optimizer_kwargs: optimizer_kwargs["relative_step"] = True # default if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") if 
optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") + logger.info(f"relative_step is true / relative_stepがtrueです") if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") args.learning_rate = None # trainable_paramsがgroupだった時の処理:lrを削除する @@ -3683,37 +3838,37 @@ def get_optimizer(args, trainable_params): if has_group_lr: # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") args.unet_lr = None args.text_encoder_lr = None if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど lr = None else: if args.max_grad_norm != 0.0: - print( + logger.warning( f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" ) if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") optimizer_class = transformers.optimization.Adafactor optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") + logger.info(f"use {optimizer_type} | {optimizer_kwargs}") if "." not in optimizer_type: optimizer_module = torch.optim else: @@ -3759,7 +3914,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): # using any lr_scheduler from other library if args.lr_scheduler_type: lr_scheduler_type = args.lr_scheduler_type - print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") if "." 
not in lr_scheduler_type: # default to use torch.optim lr_scheduler_module = torch.optim.lr_scheduler else: @@ -3775,7 +3930,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): type(optimizer) == transformers.optimization.Adafactor ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(":")[1]) - # print("adafactor scheduler init lr", initial_lr) + # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) name = SchedulerType(name) @@ -3840,20 +3995,20 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): if support_metadata: if args.in_json is not None and (args.color_aug or args.random_crop): - print( + logger.warning( f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") + logger.info("prepare tokenizer") original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH tokenizer: CLIPTokenizer = None if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 if tokenizer is None: @@ -3863,10 +4018,10 @@ def load_tokenizer(args: argparse.Namespace): tokenizer = CLIPTokenizer.from_pretrained(original_path) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) return tokenizer @@ -3888,7 +4043,9 @@ def prepare_accelerator(args: argparse.Namespace): log_with = args.log_with if log_with in ["tensorboard", "all"]: if logging_dir is None: - raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) if log_with in ["wandb", "all"]: try: import wandb @@ -3907,9 +4064,13 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None, + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) accelerator = Accelerator( @@ -3920,6 +4081,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, ) + print("accelerator device:", accelerator.device) 
return accelerator @@ -3946,17 +4108,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 ) else: # Diffusers model is loaded to CPU - print(f"load Diffusers pretrained models: {name_or_path}") + logger.info(f"load Diffusers pretrained models: {name_or_path}") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -3967,7 +4129,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une # Diffusers U-Net to original U-Net # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう - # print(f"unet config: {unet.config}") + # logger.info(f"unet config: {unet.config}") original_unet = UNet2DConditionModel( unet.config.sample_size, unet.config.attention_head_dim, @@ -3977,12 +4139,12 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return text_encoder, vae, unet, load_stable_diffusion_format @@ -3991,7 +4153,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( args, @@ -4006,8 +4168,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return text_encoder, vae, unet, load_stable_diffusion_format @@ -4058,7 +4219,9 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod # v1: ... の三連を ... 
へ戻す states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append( + encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] + ) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) @@ -4179,6 +4342,54 @@ def get_hidden_states_sdxl( return hidden_states1, hidden_states2, pool2 +def get_hidden_states_stable_cascade( + max_token_length: int, + input_ids2: torch.Tensor, + tokenizer2: CLIPTokenizer, + text_encoder2: CLIPTextModel, + weight_dtype: Optional[str] = None, + accelerator: Optional[Accelerator] = None, +): + # ここに Stable Cascade 用のコードがあるのはとても気持ち悪いが、変に整理するよりわかりやすいので、とりあえずこのまま + # It's very awkward to have Stable Cascade code here, but it's easier to understand than to organize it in a strange way, so for now it's as it is. + + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids2.size()[0] + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-1] # ** last layer ** + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2) + pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + + # v2: ... ... の三連を ... ... 
へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + if weight_dtype is not None: + # this is required for additional network training + hidden_states2 = hidden_states2.to(weight_dtype) + + return hidden_states2, pool2 + + def default_if_none(value, default): return default if value is None else value @@ -4300,7 +4511,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( ckpt_name = get_step_ckpt_name(args, ext, global_step) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: @@ -4315,7 +4527,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) if os.path.exists(remove_ckpt_file): - print(f"removing old checkpoint: {remove_ckpt_file}") + logger.info(f"removing old checkpoint: {remove_ckpt_file}") os.remove(remove_ckpt_file) else: @@ -4324,7 +4536,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( else: out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - print(f"\nsaving model: {out_dir}") + logger.info("") + logger.info(f"saving model: {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4338,7 +4551,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) if os.path.exists(remove_out_dir): - print(f"removing old model: {remove_out_dir}") + logger.info(f"removing old model: {remove_out_dir}") shutil.rmtree(remove_out_dir) if args.save_state: @@ -4351,13 +4564,14 @@ def save_sd_model_on_epoch_end_or_stepwise_common( def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - print(f"\nsaving state at epoch {epoch_no}") + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs @@ -4365,20 +4579,21 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - print(f"\nsaving state at 
step {step_no}") + logger.info("") + logger.info(f"saving state at step {step_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps @@ -4390,21 +4605,22 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n if remove_step_no > 0: state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - print("\nsaving last state.") + logger.info("") + logger.info("saving last state.") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading last state to huggingface.") + logger.info("uploading last state to huggingface.") huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) @@ -4453,7 +4669,7 @@ def save_sd_model_on_train_end_common( ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") sd_saver(ckpt_file, epoch, global_step) if args.huggingface_repo_id is not None: @@ -4462,7 +4678,7 @@ def save_sd_model_on_train_end_common( out_dir = os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") + logger.info(f"save trained model as Diffusers to {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4572,7 +4788,7 @@ def get_my_scheduler( # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") scheduler.config.clip_sample = True return scheduler @@ -4631,8 +4847,8 @@ def line_to_prompt_dict(line: str) -> dict: continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) return prompt_dict @@ -4654,6 +4870,7 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ + if steps == 0: if not args.sample_at_first: return @@ -4668,13 +4885,16 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.error(f"No prompt file / プロンプトファイルがありません: 
{args.sample_prompts}") return + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず - vae.to(device) + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4684,10 +4904,6 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - - # with open(args.sample_prompts, "rt", encoding="utf-8") as f: - # prompts = f.readlines() - if args.sample_prompts.endswith(".txt"): with open(args.sample_prompts, "r", encoding="utf-8") as f: lines = f.readlines() @@ -4700,12 +4916,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - schedulers: dict = {} + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) - schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, @@ -4718,105 +4933,58 @@ def sample_images_common( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(device) - + pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - - with torch.no_grad(): - # with accelerator.autocast(): - for i, prompt_dict in enumerate(prompts): - if not accelerator.is_main_process: - continue - - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - - assert isinstance(prompt_dict, dict) - negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 30) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 7.5) - seed = prompt_dict.get("seed") - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scheduler = schedulers.get(sampler_name) - if scheduler is None: - scheduler = get_my_scheduler( - sample_sampler=sampler_name, - v_parameterization=args.v_parameterization, - ) - schedulers[sampler_name] = scheduler - pipeline.scheduler = scheduler - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") - if seed is not None: - print(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - 
num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) - image = pipeline.latents_to_image(latents)[0] + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass - image.save(os.path.join(save_dir, img_filename)) + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + + # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. 
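+    # clean_memory_on_device (called below) is expected to collect garbage and empty the cache for the given device, equivalent to the gc.collect() + torch.cuda.empty_cache() pattern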
+ # with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) if cuda_rng_state is not None: @@ -4824,8 +4992,105 @@ def sample_images_common( vae.to(org_vae_device) +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + pipeline, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, + controlnet=None, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + + image = pipeline.latents_to_image(latents)[0] + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + # endregion + # region 前処理用 @@ -4844,7 +5109,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader 
will accept it tensor_pil = transforms.functional.pil_to_tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor_pil, img_path) diff --git a/library/utils.py b/library/utils.py index 7d801a676..3037c055d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,6 +1,266 @@ +import logging +import sys import threading +import torch +from torchvision import transforms from typing import * +from diffusers import EulerAncestralDiscreteScheduler +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() \ No newline at end of file + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +def add_logging_arguments(parser): + parser.add_argument( + "--console_log_level", + type=str, + default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", + ) + parser.add_argument( + "--console_log_file", + type=str, + default=None, + help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", + ) + parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") + + +def setup_logging(args=None, log_level=None, reset=False): + if logging.root.handlers: + if reset: + # remove all handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + else: + return + + # log_level can be set by the caller or by the args, the caller has priority. 
If not set, use INFO + if log_level is None and args is not None: + log_level = args.console_log_level + if log_level is None: + log_level = "INFO" + log_level = getattr(logging, log_level) + + msg_init = None + if args is not None and args.console_log_file: + handler = logging.FileHandler(args.console_log_file, mode="w") + else: + handler = None + if not args or not args.console_log_simple: + try: + from rich.logging import RichHandler + from rich.console import Console + from rich.logging import RichHandler + + handler = RichHandler(console=Console(stderr=True)) + except ImportError: + # print("rich is not installed, using basic logging") + msg_init = "rich is not installed, using basic logging" + + if handler is None: + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False + + formatter = logging.Formatter( + fmt="%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logging.root.setLevel(log_level) + logging.root.addHandler(handler) + + if msg_init is not None: + logger = logging.getLogger(__name__) + logger.info(msg_init) + + + +# TODO make inf_utils.py + + +# region Gradual Latent hires fix + + +class GradualLatent: + def __init__( + self, + ratio, + start_timesteps, + every_n_steps, + ratio_step, + s_noise=1.0, + gaussian_blur_ksize=None, + gaussian_blur_sigma=0.5, + gaussian_blur_strength=0.5, + unsharp_target_x=True, + ): + self.ratio = ratio + self.start_timesteps = start_timesteps + self.every_n_steps = every_n_steps + self.ratio_step = ratio_step + self.s_noise = s_noise + self.gaussian_blur_ksize = gaussian_blur_ksize + self.gaussian_blur_sigma = gaussian_blur_sigma + self.gaussian_blur_strength = gaussian_blur_strength + self.unsharp_target_x = unsharp_target_x + + def __str__(self) -> str: + return ( + f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" + ) + + def apply_unshark_mask(self, x: torch.Tensor): + if self.gaussian_blur_ksize is None: + return x + blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) + # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) + mask = (x - blurred) * self.gaussian_blur_strength + sharpened = x + mask + return sharpened + + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unshark_mask(x) + + return x + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + self.gradual_latent = None + + def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): + self.resized_size = size + self.gradual_latent = gradual_latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> 
Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + # logger.warning( + print( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. 
Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + device = model_output.device + if self.resized_size is None: + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + s_noise = 1.0 + else: + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise + + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up * s_noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +# endregion diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 51f581b29..6ec60a89b 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -2,10 +2,13 @@ import os import torch from safetensors.torch import load_file - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(file): - print(f"loading: {file}") + logger.info(f"loading: {file}") if os.path.splitext(file)[1] == ".safetensors": sd = load_file(file) else: @@ -17,16 +20,16 @@ def main(file): for key in keys: if "lora_up" in key or "lora_down" in key: values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") + logger.info(f"number of LoRA modules: {len(values)}") if args.show_all_keys: for key in [k for k in keys if k not in values]: values.append((key, sd[key])) - print(f"number of all modules: {len(values)}") + logger.info(f"number of all modules: {len(values)}") for key, value in values: value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 4ebfef7a4..c9377bee8 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -2,7 +2,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -125,7 +128,7 @@ def set_cond_image(self, cond_image): return # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - # print(f"C {self.lllite_name}, 
cond_image.shape={cond_image.shape}") + # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -155,7 +158,7 @@ def forward(self, x): cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している @@ -286,7 +289,7 @@ def create_modules( # create module instances self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) - print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): return x # dummy @@ -319,7 +322,7 @@ def load_weights(self, file): return info def apply_to(self): - print("applying LLLite for U-Net...") + logger.info("applying LLLite for U-Net...") for module in self.unet_modules: module.apply_to() self.add_module(module.lllite_name, module) @@ -374,19 +377,19 @@ def save_weights(self, file, dtype, metadata): # sdxl_original_unet.USE_REENTRANT = False # test shape etc - print("create unet") + logger.info("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create ControlNet-LLLite") + logger.info("create ControlNet-LLLite") control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") - print(control_net) + logger.info(control_net) - # print number of parameters - print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + # logger.info number of parameters + logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}") input() @@ -398,12 +401,12 @@ def save_weights(self, file, dtype, metadata): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -414,12 +417,12 @@ def save_weights(self, file, dtype, metadata): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") batch_size = 1 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 @@ -439,7 +442,7 @@ def save_weights(self, file, dtype, metadata): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(f"{sample_param}") # from safetensors.torch import save_file diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 026880015..65b3520cf 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -6,7 +6,10 @@ from typing import Optional, List, Type import torch 
from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -270,7 +273,7 @@ def apply_to_modules( # create module instances self.lllite_modules = apply_to_modules(self, target_modules) - print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") + logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") # def prepare_optimizer_params(self): def prepare_params(self): @@ -281,8 +284,8 @@ def prepare_params(self): train_params.append(p) else: non_train_params.append(p) - print(f"count of trainable parameters: {len(train_params)}") - print(f"count of non-trainable parameters: {len(non_train_params)}") + logger.info(f"count of trainable parameters: {len(train_params)}") + logger.info(f"count of non-trainable parameters: {len(non_train_params)}") for p in non_train_params: p.requires_grad_(False) @@ -388,7 +391,7 @@ def load_lllite_weights(self, file, non_lllite_unet_sd=None): matches = pattern.findall(module_name) if matches is not None: for m in matches: - print(module_name, m) + logger.info(f"{module_name} {m}") module_name = module_name.replace(m, m.replace("_", "@")) module_name = module_name.replace("_", ".") module_name = module_name.replace("@", "_") @@ -407,7 +410,7 @@ def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kw def replace_unet_linear_and_conv2d(): - print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") + logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") sdxl_original_unet.torch.nn.Linear = LLLiteLinear sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d @@ -419,10 +422,10 @@ def replace_unet_linear_and_conv2d(): replace_unet_linear_and_conv2d() # test shape etc - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModelControlNetLLLite() - print("enable ControlNet-LLLite") + logger.info("enable ControlNet-LLLite") unet.apply_lllite(32, 64, None, False, 1.0) unet.to("cuda") # .to(torch.float16) @@ -439,14 +442,14 @@ def replace_unet_linear_and_conv2d(): # unet_sd[converted_key] = model_sd[key] # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) - # print(info) + # logger.info(info) - # print(unet) + # logger.info(unet) - # print number of parameters + # logger.info number of parameters params = unet.prepare_params() - print("number of parameters", sum(p.numel() for p in params)) - # print("type any key to continue") + logger.info(f"number of parameters {sum(p.numel() for p in params)}") + # logger.info("type any key to continue") # input() unet.set_use_memory_efficient_attention(True, False) @@ -455,12 +458,12 @@ def replace_unet_linear_and_conv2d(): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -471,13 +474,13 @@ def replace_unet_linear_and_conv2d(): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") 
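+    # quick smoke test: run a few optimization steps on random tensors to confirm the LLLite-enabled U-Net trains end to end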
steps = 10 batch_size = 1 sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 x = torch.randn(batch_size, 4, 128, 128).cuda() @@ -494,9 +497,9 @@ def replace_unet_linear_and_conv2d(): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(sample_param) # from safetensors.torch import save_file - # print("save weights") + # logger.info("save weights") # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py index e5a55d198..262699014 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -15,7 +15,10 @@ from typing import List, Tuple, Union import torch from torch import nn - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class DyLoRAModule(torch.nn.Module): """ @@ -223,7 +226,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -267,11 +270,11 @@ def __init__( self.apply_to_conv = apply_to_conv if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") if self.apply_to_conv: - print(f"apply LoRA to Conv2d with kernel size (3,3).") + logger.info("apply LoRA to Conv2d with kernel size (3,3).") # create module instances def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: @@ -308,7 +311,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules return loras self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -316,7 +319,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras = create_modules(True, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -336,12 +339,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -359,12 +362,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - 
print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -375,7 +378,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") """ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0abee9836..1184cd8a5 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -10,7 +10,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name): if model_util.is_safetensors(file_name): @@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit): rank = value.size()[0] if rank > max_rank: max_rank = rank - print(f"Max rank: {max_rank}") + logger.info(f"Max rank: {max_rank}") rank = unit split_models = [] new_alpha = None while rank < max_rank: - print(f"Splitting rank {rank}") + logger.info(f"Splitting rank {rank}") new_sd = {} for key, value in lora_sd.items(): if "lora_down" in key: @@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit): # なぜかscaleするとおかしくなる…… # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] # scale = math.sqrt(this_rank / rank) # rank is > unit - # print(key, value.size(), this_rank, rank, value, scale) + # logger.info(key, value.size(), this_rank, rank, value, scale) # new_alpha = value * scale # always same # new_sd[key] = new_alpha new_sd[key] = value @@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit): def split(args): - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model) - print("Splitting Model...") + logger.info("Splitting Model...") original_rank, split_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") @@ -94,7 +97,7 @@ def split(args): filename, ext = os.path.splitext(args.save_to) model_file_name = filename + f"-{new_rank:04d}{ext}" - print(f"saving model to: {model_file_name}") + logger.info(f"saving model to: {model_file_name}") save_to_file(model_file_name, state_dict, new_metadata) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index b9027adba..43c1d0058 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,7 +11,10 @@ from tqdm import tqdm from library import sai_model_spec, model_util, sdxl_model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # CLAMP_QUANTILE = 0.99 # MIN_DIFF = 1e-1 @@ -66,14 +69,14 @@ def str_to_dtype(p): # load models if not sdxl: - print(f"loading original SD model : {model_org}") + logger.info(f"loading original SD model : {model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] if load_dtype is not None: text_encoder_o = text_encoder_o.to(load_dtype) unet_o = unet_o.to(load_dtype) - print(f"loading tuned SD model : {model_tuned}") + logger.info(f"loading tuned SD model : 
{model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] if load_dtype is not None: @@ -85,7 +88,7 @@ def str_to_dtype(p): device_org = load_original_model_to if load_original_model_to else "cpu" device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" - print(f"loading original SDXL model : {model_org}") + logger.info(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) @@ -95,7 +98,7 @@ def str_to_dtype(p): text_encoder_o2 = text_encoder_o2.to(load_dtype) unet_o = unet_o.to(load_dtype) - print(f"loading original SDXL model : {model_tuned}") + logger.info(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) @@ -135,7 +138,7 @@ def str_to_dtype(p): # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True - print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") + logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") diffs[lora_name] = diff @@ -144,7 +147,7 @@ def str_to_dtype(p): del text_encoder if not text_encoder_different: - print("Text encoder is same. Extract U-Net only.") + logger.warning("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] diffs = {} # clear diffs @@ -166,7 +169,7 @@ def str_to_dtype(p): del unet_t # make LoRA with svd - print("calculating by svd") + logger.info("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): @@ -185,7 +188,7 @@ def str_to_dtype(p): if device: mat = mat.to(device) - # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -230,7 +233,7 @@ def str_to_dtype(p): lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) - print(f"Loading extracted LoRA weights: {info}") + logger.info(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): @@ -257,7 +260,7 @@ def str_to_dtype(p): metadata.update(sai_metadata) lora_network_save.save_weights(save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {save_to}") + logger.info(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/lora.py b/networks/lora.py index 0c75cd428..948b30b0e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -11,7 +11,12 @@ import numpy as np import torch import re +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -46,7 +51,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: 
{self.lora_dim}") # else: self.lora_dim = lora_dim @@ -177,7 +182,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -216,7 +221,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info(f"default_forward {self.lora_name} {x.size()}") return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -245,7 +250,8 @@ def get_mask_for_x(self, x): if mask is None: # raise ValueError(f"mask is None for resolution {area}") # emb_layers in SDXL doesn't have mask - # print(f"mask is None for resolution {area}, {x.size()}") + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: @@ -263,6 +269,8 @@ def regional_forward(self, x): lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) lx = lx * mask x = self.org_forward(x) @@ -291,7 +299,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") return query def sub_prompt_forward(self, x): @@ -306,7 +314,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -314,7 +322,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -332,7 +340,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -351,7 +359,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, 
self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") # if num_sub_prompts > num of LoRAs, fill with zero for i in range(len(masks)): if masks[i] is None: @@ -374,7 +382,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") return out @@ -511,7 +519,9 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -520,7 +530,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -540,13 +550,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -586,7 +596,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -598,14 +608,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -613,24 +623,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -711,7 +721,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -786,20 +796,26 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -884,15 +900,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -900,15 +916,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -926,6 +942,10 @@ def set_multiplier(self, multiplier): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -939,12 +959,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -966,12 +986,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -982,7 +1002,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1113,7 +1133,7 @@ def set_region(self, sub_prompt_index, is_last_network, mask): for lora in self.text_encoder_loras + self.unet_loras: lora.set_network(self) - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): self.batch_size = batch_size 
self.num_sub_prompts = num_sub_prompts self.current_size = (height, width) @@ -1128,7 +1148,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m @@ -1139,6 +1159,13 @@ def resize_add(mh, mw): resize_add(h, w) if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + h = (h + 1) // 2 w = (w + 1) // 2 diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4d..b99b02442 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -9,8 +9,15 @@ import numpy as np from tqdm import tqdm from transformers import CLIPTextModel + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -248,7 +255,7 @@ def create_network_from_weights( elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -291,12 +298,12 @@ def __init__( super().__init__() self.multiplier = multiplier - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") # convert SDXL Stability AI's U-Net modules to Diffusers converted = self.convert_unet_modules(modules_dim, modules_alpha) if converted: - print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") + logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") # create module instances def create_modules( @@ -331,7 +338,7 @@ def create_modules( lora_name = lora_name.replace(".", "_") if lora_name not in modules_dim: - # print(f"skipped {lora_name} (not found in modules_dim)") + # logger.info(f"skipped {lora_name} (not found in modules_dim)") skipped.append(lora_name) continue @@ -362,18 +369,18 @@ def create_modules( text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") if len(skipped_te) > 0: - print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") + logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") # extend U-Net target modules to include Conv2d 3x3 target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras: List[LoRAModule] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") if len(skipped_un) > 0: - print(f"skipped {len(skipped_un)} modules because of missing weight for 
U-Net.") + logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") # assertion names = set() @@ -420,11 +427,11 @@ def set_multiplier(self, multiplier): def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") for lora in self.text_encoder_loras: lora.apply_to(multiplier) if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") for lora in self.unet_loras: lora.apply_to(multiplier) @@ -433,16 +440,16 @@ def unapply_to(self): lora.unapply_to() def merge_to(self, multiplier=1.0): - print("merge LoRA weights to original weights") + logger.info("merge LoRA weights to original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.merge_to(multiplier) - print(f"weights are merged") + logger.info(f"weights are merged") def restore_from(self, multiplier=1.0): - print("restore LoRA weights from original weights") + logger.info("restore LoRA weights from original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.restore_from(multiplier) - print(f"weights are restored") + logger.info(f"weights are restored") def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # convert SDXL Stability AI's state dict to Diffusers' based state dict @@ -463,7 +470,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): my_state_dict = self.state_dict() for key in state_dict.keys(): if state_dict[key].size() != my_state_dict[key].size(): - # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") + # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") state_dict[key] = state_dict[key].view(my_state_dict[key].size()) return super().load_state_dict(state_dict, strict) @@ -476,7 +483,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline import torch - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_preferred_device() parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") @@ -490,7 +497,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): image_prefix = args.model_id.replace("/", "_") + "_" # load Diffusers model - print(f"load model from {args.model_id}") + logger.info(f"load model from {args.model_id}") pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] if args.sdxl: # use_safetensors=True does not work with 0.18.2 @@ -503,7 +510,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] # load LoRA weights - print(f"load LoRA weights from {args.lora_weights}") + logger.info(f"load LoRA weights from {args.lora_weights}") if os.path.splitext(args.lora_weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -512,10 +519,10 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): lora_sd = torch.load(args.lora_weights) # create by LoRA weights and load weights - print(f"create LoRA network") + logger.info(f"create LoRA network") lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) - print(f"load LoRA 
network weights") + logger.info(f"load LoRA network weights") lora_network.load_state_dict(lora_sd) lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this @@ -544,34 +551,34 @@ def seed_everything(seed): random.seed(seed) # create image with original weights - print(f"create image with original weights") + logger.info(f"create image with original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "original.png") # apply LoRA network to the model: slower than merge_to, but can be reverted easily - print(f"apply LoRA network to the model") + logger.info(f"apply LoRA network to the model") lora_network.apply_to(multiplier=1.0) - print(f"create image with applied LoRA") + logger.info(f"create image with applied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "applied_lora.png") # unapply LoRA network to the model - print(f"unapply LoRA network to the model") + logger.info(f"unapply LoRA network to the model") lora_network.unapply_to() - print(f"create image with unapplied LoRA") + logger.info(f"create image with unapplied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unapplied_lora.png") # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) - print(f"merge LoRA network to the model") + logger.info(f"merge LoRA network to the model") lora_network.merge_to(multiplier=1.0) - print(f"create image with LoRA") + logger.info(f"create image with LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "merged_lora.png") @@ -579,31 +586,31 @@ def seed_everything(seed): # restore (unmerge) LoRA weights: numerically unstable # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない # 保存したstate_dictから元の重みを復元するのが確実 - print(f"restore (unmerge) LoRA weights") + logger.info(f"restore (unmerge) LoRA weights") lora_network.restore_from(multiplier=1.0) - print(f"create image without LoRA") + logger.info(f"create image without LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unmerged_lora.png") # restore original weights - print(f"restore original weights") + logger.info(f"restore original weights") pipe.unet.load_state_dict(org_unet_sd) pipe.text_encoder.load_state_dict(org_text_encoder_sd) if args.sdxl: pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) - print(f"create image with restored original weights") + logger.info(f"create image with restored original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "restore_original.png") # use convenience function to merge LoRA weights - print(f"merge LoRA weights with convenience function") + logger.info(f"merge LoRA weights with convenience function") merge_lora_weights(pipe, lora_sd, multiplier=1.0) - print(f"create image with merged LoRA weights") + logger.info(f"create image with merged LoRA weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py index a357d7f7f..919222ce8 100644 --- 
a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -14,7 +14,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -49,7 +52,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -197,7 +200,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -236,7 +239,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info("default_forward", self.lora_name, x.size()) return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -278,7 +281,7 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) lx = lx * mask x = self.org_forward(x) @@ -307,7 +310,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) return query def sub_prompt_forward(self, x): @@ -322,7 +325,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -330,7 +333,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -348,7 +351,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -367,7 +370,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", 
self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # for i in range(len(masks)): # if masks[i] is None: # masks[i] = torch.zeros_like(masks[-1]) @@ -389,7 +392,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond) return out @@ -526,7 +529,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -535,7 +538,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -555,13 +558,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -601,7 +604,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -613,14 +616,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -628,24 +631,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -801,20 +804,20 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -899,15 +902,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -915,15 +918,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -954,12 +957,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -981,12 +984,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -997,7 +1000,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1144,7 +1147,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd1..6aaa58107 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -5,27 +5,34 @@ import library.train_util as train_util import argparse from transformers import CLIPTokenizer + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() import library.model_util as 
model_util import lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = get_preferred_device() def interrogate(args): weights_dtype = torch.float16 # いろいろ準備する - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") args.pretrained_model_name_or_path = args.sd_model args.vae = None text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) - print(f"loading LoRA: {args.model}") + logger.info(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい @@ -35,11 +42,11 @@ def interrogate(args): has_te_weight = True break if not has_te_weight: - print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") + logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") return del vae - print("loading tokenizer") + logger.info("loading tokenizer") if args.v2: tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: @@ -53,7 +60,7 @@ def interrogate(args): # トークンをひとつひとつ当たっていく token_id_start = 0 token_id_end = max(tokenizer.all_special_ids) - print(f"interrogate tokens are: {token_id_start} to {token_id_end}") + logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") def get_all_embeddings(text_encoder): embs = [] @@ -79,24 +86,24 @@ def get_all_embeddings(text_encoder): embs.extend(encoder_hidden_states) return torch.stack(embs) - print("get original text encoder embeddings.") + logger.info("get original text encoder embeddings.") orig_embs = get_all_embeddings(text_encoder) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) info = network.load_state_dict(weights_sd, strict=False) - print(f"Loading LoRA weights: {info}") + logger.info(f"Loading LoRA weights: {info}") network.to(DEVICE, dtype=weights_dtype) network.eval() del unet - print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") - print("get text encoder embeddings with lora.") + logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") + logger.info("get text encoder embeddings with lora.") lora_embs = get_all_embeddings(text_encoder) # 比べる:とりあえず単純に差分の絶対値で - print("comparing...") + logger.info("comparing...") diffs = {} for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): diff = torch.mean(torch.abs(orig_emb - lora_emb)) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 71492621e..fea8a3f32 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -7,7 +7,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if 
os.path.splitext(file_name)[1] == ".safetensors": @@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "alpha" in key: continue @@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -239,7 +242,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) @@ -264,18 +267,18 @@ def str_to_dtype(p): ) if args.v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint( args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, 
vae ) else: state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -289,12 +292,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index ffd6b2b40..334d127b7 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -6,7 +6,10 @@ from safetensors.torch import load_file, save_file import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': @@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = None dim = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if 'alpha' in key: if key in merged_sd: @@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype): dim = lora_sd[key].size()[0] merged_sd[key] = lora_sd[key] * ratio - print(f"dim (rank): {dim}, alpha: {alpha}") + logger.info(f"dim (rank): {dim}, alpha: {alpha}") if alpha is None: alpha = dim @@ -142,19 +145,21 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"\nsaving SD model to: {args.save_to}") + logger.info("") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: 
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + logger.info(f"") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/oft.py b/networks/oft.py index 1d088f877..461a98698 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -8,7 +8,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -237,7 +240,7 @@ def __init__( self.dim = dim self.alpha = alpha - print( + logger.info( f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" ) @@ -258,7 +261,7 @@ def create_modules( if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): oft_name = prefix + "." + name + "." + child_name oft_name = oft_name.replace(".", "_") - # print(oft_name) + # logger.info(oft_name) oft = module_class( oft_name, @@ -279,7 +282,7 @@ def create_modules( target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) - print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") + logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") # assertion names = set() @@ -316,7 +319,7 @@ def is_mergeable(self): # TODO refactor to common function with apply_to def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - print("enable OFT for U-Net") + logger.info("enable OFT for U-Net") for oft in self.unet_ofts: sd_for_lora = {} @@ -326,7 +329,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): oft.load_state_dict(sd_for_lora, False) oft.merge_to() - print(f"weights are merged") + logger.info(f"weights are merged") # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): @@ -338,11 +341,11 @@ def enumerate_params(ofts): for oft in ofts: params.extend(oft.parameters()) - # print num of params + # logger.info num of params num_params = 0 for p in params: num_params += p.numel() - print(f"OFT params: {num_params}") + logger.info(f"OFT params: {num_params}") return params param_data = {"params": enumerate_params(self.unet_ofts)} diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 03fc545e7..c5932a893 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -8,6 +8,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) MIN_SV = 1e-6 @@ -206,7 +210,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn scale = network_alpha/network_dim if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") + logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") lora_down_weight = None lora_up_weight = None @@ -275,10 +279,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn del param_dict if verbose: - print(verbose_str) + logger.info(verbose_str) - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: 
{np.std(fro_list):0.3f}") - print("resizing complete") + logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") return o_lora_sd, network_dim, new_alpha @@ -304,10 +308,10 @@ def str_to_dtype(p): if save_dtype is None: save_dtype = merge_dtype - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model, merge_dtype) - print("Resizing Lora...") + logger.info("Resizing Lora...") state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) # update metadata @@ -329,7 +333,7 @@ def str_to_dtype(p): metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index c513eb59f..3383a80de 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,7 +8,10 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # W <- W + U * D weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) + # logger.info(module_name, down_weight.size(), up_weight.size()) if len(weight.size()) == 2: # linear weight = weight + ratio * (up_weight @ down_weight) * scale @@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, 
lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -243,7 +246,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") ( text_model1, @@ -265,14 +268,14 @@ def str_to_dtype(p): None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") sdxl_model_util.save_stable_diffusion_checkpoint( args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -286,7 +289,7 @@ def str_to_dtype(p): ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b36..cb00a6000 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,4 +1,3 @@ -import math import argparse import os import time @@ -8,7 +7,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -41,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): - print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") + logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -56,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # merge - print(f"merging...") + logger.info(f"merging...") for key in 
tqdm(list(lora_sd.keys())): if "lora_down" not in key: continue @@ -73,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty out_dim = up_weight.size()[0] conv2d = len(down_weight.size()) == 4 kernel_size = None if not conv2d else down_weight.size()[2:4] - # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # make original weight if not exist if lora_module_name not in merged_sd: @@ -110,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty merged_sd[lora_module_name] = weight # extract from merged weights - print("extract new lora...") + logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): @@ -188,7 +190,7 @@ def str_to_dtype(p): args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype ) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -203,12 +205,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, metadata) diff --git a/requirements.txt b/requirements.txt index d07b73671..a54b3b473 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,5 +29,7 @@ huggingface-hub==0.20.1 protobuf==3.20.3 # open clip for SDXL open-clip-torch==2.20.0 +# For logging +rich==13.7.0 # for kohya_ss library -e . 
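The `rich` dependency added above backs the logging setup that every script in this patch switches to. As a rough sketch of that shared pattern (assuming `library.utils.setup_logging` installs the handlers once, using `rich` where available), the per-script boilerplate looks like this:

```python
# Minimal sketch of the logging pattern this diff applies to each script.
# Assumption: library.utils.setup_logging() configures the root handlers once
# (rich-formatted when the `rich` package from requirements.txt is importable).
from library.utils import setup_logging

setup_logging()

import logging

logger = logging.getLogger(__name__)

# print(...) call sites are then replaced according to severity, mirroring the hunks above:
logger.info("loading model...")                 # ordinary progress messages
logger.warning("block_dims is not specified")   # recoverable misconfiguration
logger.error("unknown lr_weight argument")      # invalid input
```

With the module-level `logger` in place, the remaining hunks only need to change the call sites.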
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 0db9e340e..641b3209f 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -16,10 +16,9 @@ import diffusers import numpy as np -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() import torchvision @@ -55,6 +54,13 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -76,12 +82,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -89,7 +95,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -106,7 +112,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -162,7 +168,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -218,7 +224,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -340,6 +346,8 @@ def __init__( self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent: GradualLatent = None + # Textual Inversion def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids @@ -352,7 +360,7 @@ def get_token_replacer(self, tokenizer): token_replacements = self.token_replacements_list[tokenizer_index] def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) + # logger.info("replace_tokens", tokens, "=>", token_replacements) if isinstance(tokens, torch.Tensor): tokens = tokens.tolist() @@ -370,6 +378,14 @@ def replace_tokens(tokens): def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets 
+ def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + print("gradual_latent is disabled") + self.gradual_latent = None + else: + print(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + @torch.no_grad() def __call__( self, @@ -444,7 +460,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -548,7 +564,7 @@ def __call__( text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt if init_image is not None and self.clip_vision_model is not None: - print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) @@ -640,8 +656,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -704,7 +719,116 @@ def __call__( control_net.set_cond_image(None) each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + # # first, we downscale the latents to the half of the size + # # 最初に1/2に縮小する + # height, width = latents.shape[-2:] + # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( + # # latents.dtype + # # ) + # latents = latents[:, :, ::2, ::2] + # current_scale = 0.5 + + # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) 
+ # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) + # scale_step = 0.125 + + # # timesteps at which to start increasing the scale: 1000 seems to be enough + # # 拡大を開始するtimesteps: 1000で十分そうである + # start_timesteps = 1000 + + # # how many steps to wait before increasing the scale again + # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) + # # large values leads to flat images + + # # 何ステップごとに拡大するか + # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) + # # 大きすぎると細部が書き込まれずのっぺりした感じになる + # every_n_steps = 5 + + # scale_step = input("scale step:") + # scale_step = float(scale_step) + # start_timesteps = input("start timesteps:") + # start_timesteps = int(start_timesteps) + # every_n_steps = input("every n steps:") + # every_n_steps = int(every_n_steps) + + # # for i, t in enumerate(tqdm(timesteps)): + # i = 0 + # last_step = 0 + # while i < len(timesteps): + # t = timesteps[i] + # print(f"[{i}] t={t}") + + # print(i, t, current_scale, latents.shape) + # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: + # if i == last_step: + # pass + # else: + # print("upscale") + # current_scale = min(current_scale + scale_step, 1.0) + + # h = int(height * current_scale) // 8 * 8 + # w = int(width * current_scale) // 8 * 8 + + # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( + # latents.dtype + # ) + # last_step = i + # i = max(0, i - every_n_steps + 1) + + # diff = timesteps[i] - timesteps[last_step] + # # resized_init_noise = torch.nn.functional.interpolate( + # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False + # # ).to(latents.dtype) + # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) + # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) + # # latents += torch.randn_like(latents) / 100 * diff + # continue + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + print("gradual_latent is not supported for this scheduler. 
Ignoring.") + print(self.scheduler.__class__.__name__) + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -715,7 +839,7 @@ def __call__( if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): - print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False @@ -773,6 +897,8 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None + i += 1 + if return_latents: return latents @@ -780,8 +906,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents.to(self.vae.dtype)).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -796,8 +921,7 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if output_type == "pil": # image = self.numpy_to_pil(image) @@ -935,7 +1059,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -965,7 +1089,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("warning: Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -1238,7 +1362,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -1306,9 +1430,8 @@ def replacer(): # endregion - # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -1323,6 +1446,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -1378,7 +1502,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) # schedulerを用意する @@ -1406,7 +1530,7 @@ def main(args): scheduler_module = diffusers.schedulers.scheduling_euler_discrete has_clip_sample = False elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete has_clip_sample = False elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": @@ -1452,7 +1576,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -1461,7 +1585,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -1493,11 +1617,11 @@ def __getattr__(self, item): # ↓以下は結局PipeでFalseに設定されるので意味がなかった # # clip_sample=Trueにする # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") # scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: @@ -1522,7 +1646,7 @@ def __getattr__(self, item): vae_dtype = dtype if args.no_half_vae: - print("set vae_dtype to float32") + logger.info("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) vae.eval() @@ -1547,10 +1671,10 @@ def __getattr__(self, item): network_merge = args.network_merge_n_models else: network_merge = 0 - print(f"network_merge: {network_merge}") + logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - print("import 
network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -1568,7 +1692,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -1576,7 +1700,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs @@ -1586,20 +1710,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -1613,7 +1737,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -1622,7 +1746,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -1639,7 +1763,7 @@ def __getattr__(self, item): # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.control_net_lllite_models: for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") + logger.info(f"loading ControlNet-LLLite: {model_file}") from safetensors.torch import load_file @@ -1670,7 +1794,7 @@ def __getattr__(self, item): control_nets.append((control_net, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder1.to(memory_format=torch.channels_last) text_encoder2.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) @@ -1694,7 +1818,7 @@ def __getattr__(self, item): args.clip_skip, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -1703,6 +1827,29 @@ def __getattr__(self, item): if args.ds_depth_1 is not None: unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, 
args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -1736,7 +1883,7 @@ def __getattr__(self, item): token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -1766,10 +1913,10 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -1795,7 +1942,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -1811,14 +1958,14 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -1826,22 +1973,22 @@ def resize_images(imgs, size): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: 
{args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -1870,17 +2017,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -1905,14 +2052,16 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.warning( + f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -1938,7 +2087,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 @@ -1950,7 +2099,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -1995,7 +2144,7 @@ def scale_and_round(x): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2041,7 +2190,7 @@ def scale_and_round(x): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), ( width, height, @@ -2063,6 +2212,7 @@ def scale_and_round(x): prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2093,11 +2243,16 @@ def scale_and_round(x): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2161,7 +2316,7 @@ def scale_and_round(x): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... 
done") images = pipe( prompts, @@ -2195,8 +2350,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -2212,6 +2367,8 @@ def scale_and_round(x): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) metadata.add_text("original-height", str(original_height)) metadata.add_text("original-width", str(original_width)) metadata.add_text("original-height-negative", str(original_height_negative)) @@ -2240,7 +2397,9 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -2253,7 +2412,8 @@ def scale_and_round(x): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -2300,76 +2460,84 @@ def scale_and_round(x): ds_timesteps_2 = args.ds_timesteps_2 ds_ratio = args.ds_ratio + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", 
parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2378,25 +2546,25 @@ def scale_and_round(x): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2404,47 +2572,131 @@ def scale_and_round(x): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # 
gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + print(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + print(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2452,6 +2704,30 @@ def scale_and_round(x): ds_depth_1 = args.ds_depth_1 or 3 unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, 
us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -2462,7 +2738,7 @@ def scale_and_round(x): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.error("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seeds = iter_seed @@ -2472,7 +2748,7 @@ def scale_and_round(x): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2488,7 +2764,7 @@ def scale_and_round(x): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2513,7 +2789,9 @@ def scale_and_round(x): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -2548,18 +2826,25 @@ def scale_and_round(x): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -2571,7 +2856,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -2582,10 +2869,16 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--H", type=int, default=None, help="image height, in pixel 
space / 生成画像高さ") parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") parser.add_argument( - "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", ) parser.add_argument( - "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", ) parser.add_argument( "--original_height_negative", @@ -2599,8 +2892,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", ) - parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") - parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") parser.add_argument( "--vae_batch_size", @@ -2614,7 +2911,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", ) - parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument( "--sampler", @@ -2646,9 +2945,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -2679,25 +2983,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional 
network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -2712,7 +3037,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -2729,7 +3056,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -2738,7 +3068,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -2746,7 +3078,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", 
) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -2761,11 +3096,18 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( - "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", ) # parser.add_argument( # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" @@ -2814,6 +3156,45 @@ def setup_parser() -> argparse.ArgumentParser: "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" ) + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) @@ -2825,4 +3206,5 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 15a70678f..084735665 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -8,10 +8,9 @@ import random from einops import repeat import numpy as np -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, get_preferred_device init_ipex() from tqdm import tqdm @@ -23,6 +22,10 @@ from library import model_util, sdxl_model_util import networks.lora as lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 @@ -85,7 +88,7 @@ def get_timestep_embedding(x, outdim): guidance_scale = 7 seed = None # 1 - DEVICE = "cuda" + DEVICE = get_preferred_device() DTYPE = torch.float16 # bfloat16 may work parser = argparse.ArgumentParser() @@ -140,7 +143,7 @@ def get_timestep_embedding(x, outdim): vae_dtype = DTYPE if DTYPE == torch.float16: - print("use float32 for vae") + logger.info("use float32 for vae") vae_dtype = torch.float32 vae.to(DEVICE, dtype=vae_dtype) vae.eval() @@ -187,7 +190,7 @@ def generate_image(prompt, prompt2, negative_prompt, seed=None): emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) - # print("emb1", emb1.shape) + # logger.info("emb1", emb1.shape) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right @@ -217,7 +220,7 @@ def call_text_encoder(text, text2): enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] - # print("hidden_states2", text_embedding2_penu.shape) + # logger.info("hidden_states2", text_embedding2_penu.shape) text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish @@ -226,7 +229,7 @@ def call_text_encoder(text, text2): # cond c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) - # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) + # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) # uncond @@ -323,4 +326,4 @@ def call_text_encoder(text, text2): seed = int(seed) generate_image(prompt, prompt2, negative_prompt, seed) - print("Done!") + logger.info("Done!") diff --git a/sdxl_train.py b/sdxl_train.py index a3f6f3a17..aa161e8ac 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -1,7 +1,6 @@ # training with captions import argparse -import gc import math import os from multiprocessing import Value @@ -9,10 +8,9 @@ import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -20,6 +18,14 @@ from library 
import sdxl_model_util import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -91,8 +97,11 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) - assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -117,18 +126,18 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -139,7 +148,7 @@ def train(args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -169,7 +178,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -185,7 +194,7 @@ def train(args): ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -252,9 +261,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -363,7 +370,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -407,8 +416,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -433,7 +441,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # accelerator.print( # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" # ) @@ -453,7 +463,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -537,7 +547,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") # get size embeddings orig_size = batch["original_sizes_hw"] @@ -724,12 +734,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): logit_scale, ckpt_info, ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) @@ -752,7 +763,9 @@ def setup_parser() -> argparse.ArgumentParser: help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", ) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--no_half_vae", diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 7a88feb84..b11999bd6 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -2,7 +2,6 @@ # training code for 
ControlNet-LLLite with passing cond_image to U-Net's forward import argparse -import gc import json import math import os @@ -13,10 +12,9 @@ import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -45,6 +43,12 @@ apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -65,6 +69,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -78,11 +83,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -114,7 +119,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -124,7 +129,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -132,7 +139,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -164,9 +171,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -231,8 +236,8 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(unet.prepare_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable 
parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -254,7 +259,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -291,8 +298,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -323,8 +329,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -341,7 +349,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -548,12 +556,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -569,8 +578,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index b94bf5c1b..89a1bc8e0 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -10,10 +9,9 @@ import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -41,6 +39,12 @@ apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -61,6 +65,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -74,11 +79,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -110,7 +115,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -120,7 +125,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -128,7 +135,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -163,9 +170,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -199,8 +204,8 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(network.prepare_optimizer_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -222,7 +227,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -264,8 +271,7 @@ def train(args): # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -296,8 +302,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -516,12 +524,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -537,8 +546,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 5d363280d..d33239d92 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,13 +1,15 @@ import argparse -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from library import sdxl_model_util, sdxl_train_util, train_util import train_network - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): @@ -60,13 +62,12 @@ def cache_text_encoder_outputs_if_needed( if args.cache_text_encoder_outputs: if not args.lowram: # メモリ消費を減らす - print("move vae and unet to cpu to save memory") + logger.info("move vae and unet to cpu to save memory") org_vae_device = 
vae.device org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -81,11 +82,10 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) if not args.lowram: - print("move vae and unet back to original device") + logger.info("move vae and unet back to original device") vae.to(org_vae_device) unet.to(org_unet_device) else: @@ -143,7 +143,7 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") return encoder_hidden_states1, encoder_hidden_states2, pool2 diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index df3937135..b9a948bb2 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -2,10 +2,11 @@ import os import regex -import torch -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex init_ipex() + import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/stable_cascade_gen_img.py b/stable_cascade_gen_img.py new file mode 100644 index 000000000..b7e5fe4ef --- /dev/null +++ b/stable_cascade_gen_img.py @@ -0,0 +1,304 @@ +import argparse +import math +import os +import random +import time +import numpy as np + +from safetensors.torch import load_file, save_file +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig +from PIL import Image +from accelerate import init_empty_weights + +import library.stable_cascade as sc +import library.stable_cascade_utils as sc_utils +import library.device_utils as device_utils +from library import train_util +from library.sdxl_model_util import _load_state_dict_on_device + + +def main(args): + device = device_utils.get_preferred_device() + + loading_device = device if not args.lowvram else "cpu" + text_model_device = "cpu" + + dtype = torch.float32 + if args.bf16: + dtype = torch.bfloat16 + elif args.fp16: + dtype = torch.float16 + + text_model_dtype = torch.float32 + + # EfficientNet encoder + effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device) + effnet.eval().requires_grad_(False).to(loading_device) + + generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device) + generator_c.eval().requires_grad_(False).to(loading_device) + + generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device) + generator_b.eval().requires_grad_(False).to(loading_device) + + # CLIP encoders + tokenizer = sc_utils.load_tokenizer(args) + + text_model = sc_utils.load_clip_text_model( + args.text_model_checkpoint_path, text_model_dtype, text_model_device, 
args.save_text_model + ) + text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device) + + # image_model = ( + # CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device) + # ) + + # vqGAN + stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device) + stage_a.eval().requires_grad_(False) + + # previewer + if args.previewer_checkpoint_path is not None: + previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device) + previewer.eval().requires_grad_(False) + else: + previewer = None + + # 謎のクラス gdf + gdf_c = sc.GDF( + schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=sc.VPScaler(), + target=sc.EpsilonTarget(), + noise_cond=sc.CosineTNoiseCond(), + loss_weight=None, + ) + gdf_b = sc.GDF( + schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=sc.VPScaler(), + target=sc.EpsilonTarget(), + noise_cond=sc.CosineTNoiseCond(), + loss_weight=None, + ) + + # Stage C Parameters + + # extras.sampling_configs["cfg"] = 4 + # extras.sampling_configs["shift"] = 2 + # extras.sampling_configs["timesteps"] = 20 + # extras.sampling_configs["t_start"] = 1.0 + + # # Stage B Parameters + # extras_b.sampling_configs["cfg"] = 1.1 + # extras_b.sampling_configs["shift"] = 1 + # extras_b.sampling_configs["timesteps"] = 10 + # extras_b.sampling_configs["t_start"] = 1.0 + b_cfg = 1.1 + b_shift = 1 + b_timesteps = 10 + b_t_start = 1.0 + + # caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee" + # height, width = 1024, 1024 + + while True: + print("type caption:") + # if Ctrl+Z is pressed, it will raise EOFError + try: + caption = input() + except EOFError: + break + + caption = caption.strip() + if caption == "": + continue + + # parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift. if not specified, use default values + # e.g. 
"caption --w 4 --h 4 --l 20 --s 20 --f 1.0" + + tokens = caption.split() + width = height = 1024 + cfg = 4 + timesteps = 20 + shift = 2 + t_start = 1.0 # t_start is not an option, but it is a parameter + negative_prompt = "" + seed = None + + caption_tokens = [] + i = 0 + while i < len(tokens): + token = tokens[i] + if i == len(tokens) - 1: + caption_tokens.append(token) + i += 1 + continue + + if token == "--w": + width = int(tokens[i + 1]) + elif token == "--h": + height = int(tokens[i + 1]) + elif token == "--l": + cfg = float(tokens[i + 1]) + elif token == "--s": + timesteps = int(tokens[i + 1]) + elif token == "--f": + shift = float(tokens[i + 1]) + elif token == "--t": + t_start = float(tokens[i + 1]) + elif token == "--n": + negative_prompt = tokens[i + 1] + elif token == "--d": + seed = int(tokens[i + 1]) + else: + caption_tokens.append(token) + i += 1 + continue + + i += 2 + + caption = " ".join(caption_tokens) + + stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1) + + # PREPARE CONDITIONS + # cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model) + input_ids = tokenizer( + [caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + )["input_ids"].to(text_model.device) + cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade( + tokenizer.model_max_length, input_ids, tokenizer, text_model + ) + cond_text = cond_text.to(device, dtype=dtype) + cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype) + + # uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model) + input_ids = tokenizer( + [negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + )["input_ids"].to(text_model.device) + uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade( + tokenizer.model_max_length, input_ids, tokenizer, text_model + ) + uncond_text = uncond_text.to(device, dtype=dtype) + uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype) + + zero_img_emb = torch.zeros(1, 768, device=device) + + # 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく + conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb} + unconditions = { + "clip_text_pooled": uncond_pooled, + "clip": uncond_pooled, + "clip_text": uncond_text, + "clip_img": zero_img_emb, + } + conditions_b = {} + conditions_b.update(conditions) + unconditions_b = {} + unconditions_b.update(unconditions) + + # seed everything + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + + if args.lowvram: + generator_c = generator_c.to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + sampling_c = gdf_c.sample( + generator_c, + conditions, + stage_c_latent_shape, + unconditions, + device=device, + cfg=cfg, + shift=shift, + timesteps=timesteps, + t_start=t_start, + ) + for sampled_c, _, _ in tqdm(sampling_c, total=timesteps): + sampled_c = sampled_c + + conditions_b["effnet"] = sampled_c + unconditions_b["effnet"] = torch.zeros_like(sampled_c) + + if previewer is not None: + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + preview = previewer(sampled_c) + preview = preview.clamp(0, 1) + preview = preview.permute(0, 2, 3, 1).squeeze(0) + preview = 
preview.detach().float().cpu().numpy() + preview = Image.fromarray((preview * 255).astype(np.uint8)) + + timestamp_str = time.strftime("%Y%m%d_%H%M%S") + os.makedirs(args.outdir, exist_ok=True) + preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png")) + + if args.lowvram: + generator_c = generator_c.to(loading_device) + device_utils.clean_memory_on_device(device) + generator_b = generator_b.to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + sampling_b = gdf_b.sample( + generator_b, + conditions_b, + stage_b_latent_shape, + unconditions_b, + device=device, + cfg=b_cfg, + shift=b_shift, + timesteps=b_timesteps, + t_start=b_t_start, + ) + for sampled_b, _, _ in tqdm(sampling_b, total=b_t_start): + sampled_b = sampled_b + + if args.lowvram: + generator_b = generator_b.to(loading_device) + device_utils.clean_memory_on_device(device) + stage_a = stage_a.to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + sampled = stage_a.decode(sampled_b).float() + # print(sampled.shape, sampled.min(), sampled.max()) + + if args.lowvram: + stage_a = stage_a.to(loading_device) + device_utils.clean_memory_on_device(device) + + # float 0-1 to PIL Image + sampled = sampled.clamp(0, 1) + sampled = sampled.mul(255).to(dtype=torch.uint8) + sampled = sampled.permute(0, 2, 3, 1) + sampled = sampled.cpu().numpy() + sampled = Image.fromarray(sampled[0]) + + timestamp_str = time.strftime("%Y%m%d_%H%M%S") + os.makedirs(args.outdir, exist_ok=True) + sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + sc_utils.add_effnet_arguments(parser) + train_util.add_tokenizer_arguments(parser) + sc_utils.add_stage_a_arguments(parser) + sc_utils.add_stage_b_arguments(parser) + sc_utils.add_stage_c_arguments(parser) + sc_utils.add_previewer_arguments(parser) + sc_utils.add_text_model_arguments(parser) + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode") + args = parser.parse_args() + + main(args) diff --git a/stable_cascade_train_stage_c.py b/stable_cascade_train_stage_c.py new file mode 100644 index 000000000..aae99b90f --- /dev/null +++ b/stable_cascade_train_stage_c.py @@ -0,0 +1,546 @@ +# training with captions + +import argparse +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +import library.train_util as train_util +from library.sdxl_train_util import add_sdxl_training_arguments +import library.stable_cascade_utils as sc_utils +import library.stable_cascade as sc + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / 
weighted_captionsは現在サポートされていません" + + # TODO add assertions for other unsupported options + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer = sc_utils.load_tokenizer(args) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + loading_device = accelerator.device if args.lowram else "cpu" + effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device) + stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device) + text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device) + + if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: + # Previewer is small enough to be loaded on CPU + previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu") + previewer.eval() + else: + previewer = None + + # 学習を準備する + if cache_latents: + effnet.to(accelerator.device, dtype=effnet_dtype) + effnet.requires_grad_(False) + effnet.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + effnet, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX, + 32, + ) + effnet.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + if args.gradient_checkpointing: + accelerator.print("enable gradient checkpointing") + stage_c.set_gradient_checkpointing(True) + + train_stage_c = args.learning_rate > 0 + train_text_encoder1 = False + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder1.gradient_checkpointing_enable() + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + train_text_encoder1 = lr_te1 > 0 + assert ( + train_text_encoder1 + ), "text_encoder1 learning rate is 0. 
Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。" + + if not train_text_encoder1: + text_encoder1.to(weight_dtype) + text_encoder1.requires_grad_(train_text_encoder1) + text_encoder1.train(train_text_encoder1) + else: + text_encoder1.to(weight_dtype) + text_encoder1.requires_grad_(False) + text_encoder1.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer,), + (text_encoder1,), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + effnet.requires_grad_(False) + effnet.eval() + effnet.to(accelerator.device, dtype=effnet_dtype) + + stage_c.requires_grad_(True) + if not train_stage_c: + stage_c.to(accelerator.device, dtype=weight_dtype) # because of stage_c will not be prepared + + training_models = [] + params_to_optimize = [] + if train_stage_c: + training_models.append(stage_c) + params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate}) + + if train_text_encoder1: + training_models.append(text_encoder1) + params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for params in params_to_optimize: + for p in params["params"]: + n_params += p.numel() + + accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + stage_c.to(weight_dtype) + text_encoder1.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + stage_c.to(weight_dtype) + text_encoder1.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_stage_c: + stage_c = accelerator.prepare(stage_c) + if train_text_encoder1: + text_encoder1 = accelerator.prepare(text_encoder1) + + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # 謎のクラス GDF + gdf = sc.GDF( + schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=sc.VPScaler(), + target=sc.EpsilonTarget(), + noise_cond=sc.CosineTNoiseCond(), + loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(), + ) + + # 以下2つの変数は、どうもデフォルトのままっぽい + # gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + # 
gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + + # For --sample_at_first + sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + latents = effnet(batch["images"].to(effnet_dtype)).to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + # TODO support weighted captions + input_ids1 = input_ids1.to(accelerator.device) + # unwrap_model is fine for models not wrapped by accelerator + encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade( + args.max_token_length, + input_ids1, + tokenizer, + text_encoder1, + None if not args.full_fp16 else weight_dtype, + accelerator, + ) + else: + encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280 + + # FORWARD PASS + with torch.no_grad(): + noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1) + + zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device) + with accelerator.autocast(): + pred = stage_c( + noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb + ) + loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() + + if args.adaptive_loss_weight: + gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted + + accelerator.backward(loss_adjusted) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and 
global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sc_utils.save_stage_c_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(stage_c), + accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None, + ) + + current_loss = loss_adjusted.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sc_utils.save_stage_c_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(stage_c), + accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None, + ) + + sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf) + + is_main_process = accelerator.is_main_process + # if is_main_process: + stage_c = accelerator.unwrap_model(stage_c) + text_encoder1 = accelerator.unwrap_model(text_encoder1) + + accelerator.end_training() + + if args.save_state: # and is_main_process: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sc_utils.save_stage_c_model_on_end( + args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + sc_utils.add_effnet_arguments(parser) + sc_utils.add_stage_c_arguments(parser) + sc_utils.add_text_model_arguments(parser) + sc_utils.add_previewer_arguments(parser) + sc_utils.add_training_arguments(parser) + train_util.add_tokenizer_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_sdxl_training_arguments(parser) # cache text encoder outputs + + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder / text encoderの学習率", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 17916ef70..e25506e41 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,7 +16,10 @@ 
ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: @@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - print(f"Skipping {image_info.latents_npz} because it already exists.") + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7d9b13d68..46bffc4e0 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,7 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") 
user_config = { "datasets": [ { @@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) text_encoders = [text_encoder1, text_encoder2] @@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if os.path.exists(image_info.text_encoder_outputs_npz): - print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") continue image_info.input_ids1 = input_ids1 diff --git a/tools/canny.py b/tools/canny.py index 5e0806898..f2190975c 100644 --- a/tools/canny.py +++ b/tools/canny.py @@ -1,6 +1,10 @@ import argparse import cv2 +import logging +from library.utils import setup_logging +setup_logging() +logger = logging.getLogger(__name__) def canny(args): img = cv2.imread(args.input) @@ -10,7 +14,7 @@ def canny(args): # canny_img = 255 - canny_img cv2.imwrite(args.output, canny_img) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index fe30996aa..572ee2f0c 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -6,7 +6,10 @@ from diffusers import StableDiffusionPipeline import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def convert(args): # 引数を確認する @@ -30,7 +33,7 @@ def convert(args): # モデルを読み込む msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") + logger.info(f"loading {msg}: {args.model_to_load}") if is_load_ckpt: v2_model = args.v2 @@ -48,13 +51,13 @@ def convert(args): if args.v1 == args.v2: # 自動判定する v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ("v2" if v2_model else "v1")) + logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) else: v2_model = not args.v1 # 変換して保存する msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") + logger.info(f"converting and saving as {msg}: {args.model_to_save}") if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None @@ -70,15 +73,15 @@ def convert(args): save_dtype=save_dtype, vae=vae, ) - print(f"model saved. total converted state_dict keys: {key_count}") + logger.info(f"model saved. 
total converted state_dict keys: {key_count}") else: - print( + logger.info( f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" ) model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index 68dec6cae..bbc643edc 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,6 +15,10 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) KP_REYE = 11 KP_LEYE = 19 @@ -24,7 +28,7 @@ def detect_faces(detector, image, min_size): preds = detector(image) # bgr - # print(len(preds)) + # logger.info(len(preds)) faces = [] for pred in preds: @@ -78,7 +82,7 @@ def process(args): assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" # アニメ顔検出モデルを読み込む - print("loading face detector.") + logger.info("loading face detector.") detector = create_detector('yolov3') # cropの引数を解析する @@ -97,7 +101,7 @@ def process(args): crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] # 画像を処理する - print("processing.") + logger.info("processing.") output_extension = ".png" os.makedirs(args.dst_dir, exist_ok=True) @@ -111,7 +115,7 @@ def process(args): if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if image.shape[2] == 4: - print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + logger.warning(f"image has alpha. 
ignore / 画像の透明度が設定されているため無視します: {path}") image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい h, w = image.shape[:2] @@ -144,11 +148,11 @@ def process(args): # 顔サイズを基準にリサイズする scale = args.resize_face_size / face_size if scale < cur_crop_width / w: - print( + logger.warning( f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_width / w if scale < cur_crop_height / h: - print( + logger.warning( f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_height / h elif crop_h_ratio is not None: @@ -157,10 +161,10 @@ def process(args): else: # 切り出しサイズ指定あり if w < cur_crop_width: - print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_width / w if h < cur_crop_height: - print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_height / h if args.resize_fit: scale = max(cur_crop_width / w, cur_crop_height / h) @@ -198,7 +202,7 @@ def process(args): face_img = face_img[y:y + cur_crop_height] # # debug - # print(path, cx, cy, angle) + # logger.info(path, cx, cy, angle) # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) # cv2.imshow("image", crp) # if cv2.waitKey() == 27: diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa3390..f05cf7194 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -11,10 +11,16 @@ import numpy as np import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torch import nn from tqdm import tqdm from PIL import Image - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -216,7 +222,7 @@ def upscale( upsampled_images = upsampled_images / 127.5 - 1.0 # convert upsample images to latents with batch size - # print("Encoding upsampled (LANCZOS4) images...") + # logger.info("Encoding upsampled (LANCZOS4) images...") upsampled_latents = [] for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): batch = upsampled_images[i : i + vae_batch_size].to(vae.device) @@ -227,7 +233,7 @@ def upscale( upsampled_latents = torch.cat(upsampled_latents, dim=0) # upscale (refine) latents with this model with batch size - print("Upscaling latents...") + logger.info("Upscaling latents...") upscaled_latents = [] for i in range(0, upsampled_latents.shape[0], batch_size): with torch.no_grad(): @@ -242,7 +248,7 @@ def create_upscaler(**kwargs): weights = kwargs["weights"] model = Upscaler() - print(f"Loading weights from {weights}...") + logger.info(f"Loading weights from {weights}...") if os.path.splitext(weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -255,20 +261,20 @@ def create_upscaler(**kwargs): # another interface: upscale images with a model for given images from command line def upscale_images(args: argparse.Namespace): - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + DEVICE = get_preferred_device() us_dtype = torch.float16 # TODO: support fp32/bf16 os.makedirs(args.output_dir, exist_ok=True) # load VAE with Diffusers assert args.vae_path is not None, 
"VAE path is required" - print(f"Loading VAE from {args.vae_path}...") + logger.info(f"Loading VAE from {args.vae_path}...") vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.to(DEVICE, dtype=us_dtype) # prepare model - print("Preparing model...") + logger.info("Preparing model...") upscaler: Upscaler = create_upscaler(weights=args.weights) - # print("Loading weights from", args.weights) + # logger.info("Loading weights from", args.weights) # upscaler.load_state_dict(torch.load(args.weights)) upscaler.eval() upscaler.to(DEVICE, dtype=us_dtype) @@ -303,14 +309,14 @@ def upscale_images(args: argparse.Namespace): image_debug.save(dest_file_name) # upscale - print("Upscaling...") + logger.info("Upscaling...") upscaled_latents = upscaler.upscale( vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size ) upscaled_latents /= 0.18215 # decode with batch - print("Decoding...") + logger.info("Decoding...") upscaled_images = [] for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): with torch.no_grad(): diff --git a/tools/merge_models.py b/tools/merge_models.py index 391bfe677..8f1fbf2f8 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -5,7 +5,10 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def is_unet_key(key): # VAE or TextEncoder, the last one is for SDXL @@ -45,10 +48,10 @@ def merge(args): # check if all models are safetensors for model in args.models: if not model.endswith("safetensors"): - print(f"Model {model} is not a safetensors model") + logger.info(f"Model {model} is not a safetensors model") exit() if not os.path.isfile(model): - print(f"Model {model} does not exist") + logger.info(f"Model {model} does not exist") exit() assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" @@ -65,7 +68,7 @@ def merge(args): if merged_sd is None: # load first model - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") merged_sd = {} with safe_open(model, framework="pt", device=args.device) as f: for key in tqdm(f.keys()): @@ -81,11 +84,11 @@ def merge(args): value = ratio * value.to(dtype) # first model's value * ratio merged_sd[key] = value - print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) continue # load other models - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") with safe_open(model, framework="pt", device=args.device) as f: model_keys = f.keys() @@ -93,7 +96,7 @@ def merge(args): _, new_key = replace_text_encoder_key(key) if new_key not in merged_sd: if args.show_skipped and new_key not in first_model_keys: - print(f"Skip: {new_key}") + logger.info(f"Skip: {new_key}") continue value = f.get_tensor(key) @@ -104,7 +107,7 @@ def merge(args): for key in merged_sd.keys(): if key in model_keys: continue - print(f"Key {key} not in model {model}, use first model's value") + logger.warning(f"Key {key} not in model {model}, use first model's value") if key in supplementary_key_ratios: supplementary_key_ratios[key] += ratio else: @@ -112,7 +115,7 @@ def merge(args): # add supplementary keys' value (including 
VAE and TextEncoder) if len(supplementary_key_ratios) > 0: - print("add first model's value") + logger.info("add first model's value") with safe_open(args.models[0], framework="pt", device=args.device) as f: for key in tqdm(f.keys()): _, new_key = replace_text_encoder_key(key) if new_key not in supplementary_key_ratios: continue if is_unet_key(new_key): # not VAE or TextEncoder - print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") + logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") value = f.get_tensor(key) # original key @@ -134,7 +137,7 @@ def merge(args): if not output_file.endswith(".safetensors"): output_file = output_file + ".safetensors" - print(f"Saving to {output_file}...") + logger.info(f"Saving to {output_file}...") # convert to save_dtype for k in merged_sd.keys(): @@ -142,7 +145,7 @@ def merge(args): save_file(merged_sd, output_file) - print("Done!") + logger.info("Done!") if __name__ == "__main__": diff --git a/tools/original_control_net.py b/tools/original_control_net.py index cd47bd76a..5640d542d 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -7,7 +7,10 @@ from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ControlNetInfo(NamedTuple): unet: Any @@ -51,7 +54,7 @@ def load_control_net(v2, unet, model): # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む # state dictを読み込む - print(f"ControlNet: loading control SD model : {model}") + logger.info(f"ControlNet: loading control SD model : {model}") if model_util.is_safetensors(model): ctrl_sd_sd = load_file(model) @@ -61,7 +64,7 @@ def load_control_net(v2, unet, model): # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む is_difference = "difference" in ctrl_sd_sd - print("ControlNet: loading difference:", is_difference) + logger.info(f"ControlNet: loading difference: {is_difference}") # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく # またTransfer Controlの元weightとなる @@ -89,13 +92,13 @@ def load_control_net(v2, unet, model): # ControlNetのU-Netを作成する ctrl_unet = UNet2DConditionModel(**unet_config) info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) - print("ControlNet: loading Control U-Net:", info) + logger.info(f"ControlNet: loading Control U-Net: {info}") # U-Net以外のControlNetを作成する # TODO support middle only ctrl_net = ControlNet() info = ctrl_net.load_state_dict(zero_conv_sd) - print("ControlNet: loading ControlNet:", info) + logger.info(f"ControlNet: loading ControlNet: {info}") ctrl_unet.to(unet.device, dtype=unet.dtype) ctrl_net.to(unet.device, dtype=unet.dtype) @@ -117,7 +120,7 @@ def canny(img): return canny - print("Unsupported prep type:", prep_type) + logger.info(f"Unsupported prep type: {prep_type}") return None @@ -174,13 +177,26 @@ def call_unet_and_control_net( cnet_idx = step % cnet_cnt cnet_info = control_nets[cnet_idx] - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: return original_unet(sample, timestep, encoder_hidden_states) guided_hint = guided_hints[cnet_idx] + + # gradual latent support: match the size of guided_hint to the size of sample + if guided_hint.shape[-2:] != sample.shape[-2:]: + # print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}") + 
org_dtype = guided_hint.dtype + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(torch.float32) + guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic") + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(org_dtype) + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net) + outs = unet_forward( + True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net + ) outs = [o * cnet_info.weight for o in outs] # U-Net @@ -192,7 +208,7 @@ def call_unet_and_control_net( # ControlNet cnet_outs_list = [] for i, cnet_info in enumerate(control_nets): - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: continue guided_hint = guided_hints[i] @@ -232,7 +248,7 @@ def unet_forward( upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - print("Forward upsample size to force interpolation output size.") + logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 1. time diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 2d3224c4e..b8069fc1d 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,10 @@ import math from PIL import Image import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): # Split the max_resolution string by "," and strip any whitespaces @@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi image.save(os.path.join(dst_img_folder, new_filename), quality=100) proc = "Resized" if current_pixels > max_pixels else "Saved" - print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") # If other files with same basename, copy them with resolution suffix if copy_associated_files: @@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi continue for max_resolution in max_resolutions: new_asoc_file = base + '+' + max_resolution + ext - print(f"Copy {asoc_file} as {new_asoc_file}") + logger.info(f"Copy {asoc_file} as {new_asoc_file}") shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) diff --git a/tools/show_metadata.py b/tools/show_metadata.py index 92ca7b1c8..05bfbe0a4 100644 --- a/tools/show_metadata.py +++ b/tools/show_metadata.py @@ -1,6 +1,10 @@ import json import argparse from safetensors import safe_open +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) @@ -10,10 +14,10 @@ metadata = f.metadata() if metadata is None: - print("No metadata found") + logger.error("No metadata found") else: # metadata is json dict, but not pretty printed # sort by key and pretty print 
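The gradual-latent change to `tools/original_control_net.py` above resizes the ControlNet hint to the current U-Net sample resolution, casting bfloat16 hints to float32 around the bicubic interpolation because bicubic resizing is not reliably available for bf16 tensors. A minimal standalone sketch of that pattern follows; the function name `resize_hint_to_sample` is illustrative and not part of the repository:

```python
import torch
import torch.nn.functional as F


def resize_hint_to_sample(guided_hint: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    # Match the hint's spatial size (H, W) to the U-Net sample; needed when
    # gradual latent changes the latent resolution during sampling.
    if guided_hint.shape[-2:] == sample.shape[-2:]:
        return guided_hint
    org_dtype = guided_hint.dtype
    if org_dtype == torch.bfloat16:
        # bicubic interpolation may not support bf16, so round-trip via fp32
        guided_hint = guided_hint.to(torch.float32)
    guided_hint = F.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
    return guided_hint.to(org_dtype)
```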
print(json.dumps(metadata, indent=4, sort_keys=True)) - \ No newline at end of file + diff --git a/tools/stable_cascade_cache_latents.py b/tools/stable_cascade_cache_latents.py new file mode 100644 index 000000000..2ac875930 --- /dev/null +++ b/tools/stable_cascade_cache_latents.py @@ -0,0 +1,191 @@ +# Stable Cascadeのlatentsをdiskにキャッシュする +# cache latents of Stable Cascade to disk + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import stable_cascade_utils as sc_utils +from library import config_util +from library import train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache latents arg + assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # tokenizerを準備する:datasetを動かすために必要 + tokenizer = sc_utils.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + # datasetのcache_latentsを呼ばなければ、生の画像が返る + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + logger.info("load model") + effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device) + effnet.to(accelerator.device, dtype=effnet_dtype) + effnet.requires_grad_(False) + effnet.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("latents") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 
1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + b_size = len(batch["images"]) + vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size + flip_aug = batch["flip_aug"] + random_crop = batch["random_crop"] + bucket_reso = batch["bucket_reso"] + + # バッチを分割して処理する + for i in range(0, b_size, vae_batch_size): + images = batch["images"][i : i + vae_batch_size] + absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] + resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] + + image_infos = [] + for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.image = image + image_info.bucket_reso = bucket_reso + image_info.resized_size = resized_size + image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX + + if args.skip_existing: + if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32): + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") + continue + + image_infos.append(image_info) + + if len(image_infos) > 0: + train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_tokenizer_arguments(parser) + sc_utils.add_effnet_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う", + ) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args) diff --git a/tools/stable_cascade_cache_text_encoder_outputs.py b/tools/stable_cascade_cache_text_encoder_outputs.py new file mode 100644 index 000000000..240aaecee --- /dev/null +++ b/tools/stable_cascade_cache_text_encoder_outputs.py @@ -0,0 +1,183 @@ +# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import config_util +from library import train_util +from library import sdxl_train_util +from library import stable_cascade_utils as sc_utils +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.utils import setup_logging + +setup_logging() +import 
logging + +logger = logging.getLogger(__name__) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache arg + assert ( + args.cache_text_encoder_outputs_to_disk + ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # tokenizerを準備する:datasetを動かすために必要 + tokenizer = sc_utils.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + + # モデルを読み込む + logger.info("load model") + text_encoder = sc_utils.load_clip_text_model( + args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model + ) + text_encoders = [text_encoder] + for text_encoder in text_encoders: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("text") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + absolute_paths = batch["absolute_paths"] + input_ids1_list = batch["input_ids1_list"] + + image_infos = [] + for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.text_encoder_outputs_npz = 
os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + + if args.skip_existing: + if os.path.exists(image_info.text_encoder_outputs_npz): + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + continue + + image_info.input_ids1 = input_ids1 + image_infos.append(image_info) + + if len(image_infos) > 0: + b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) + train_util.cache_batch_text_encoder_outputs( + image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype + ) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching text encoder outputs for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_tokenizer_arguments(parser) + sc_utils.add_text_model_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args) diff --git a/train_controlnet.py b/train_controlnet.py index 7b0b2bbfe..8963a5d62 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -10,10 +9,9 @@ import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -35,6 +33,12 @@ pyramid_noise_like, apply_noise_offset, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -56,6 +60,7 @@ def train(args): # training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -69,11 +74,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -103,7 +108,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -114,7 +119,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -219,10 +224,8 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -253,7 +256,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -309,8 +314,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -332,7 +339,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -567,12 +574,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, controlnet, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) diff --git a/train_db.py b/train_db.py index 888cad25e..c89caaf2c 100644 --- a/train_db.py +++ b/train_db.py @@ -1,7 +1,6 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc import argparse import itertools import math @@ -10,10 +9,9 @@ import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -35,6 +33,12 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # perlin_noise, @@ -42,6 +46,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -54,11 +59,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -93,13 +98,13 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") if args.gradient_accumulation_steps > 1: - print( + logger.warning( f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" ) - print( + logger.warning( f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" ) @@ -138,9 +143,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -193,7 +196,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -264,7 +269,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -449,12 +454,13 @@ def train(args): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) diff --git a/train_network.py b/train_network.py index 8d102ae8f..af15560ce 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,5 @@ import importlib import argparse -import gc import math import os import sys @@ -11,13 +10,13 @@ import toml from tqdm import tqdm -import torch -from torch.nn.parallel import DistributedDataParallel as DDP - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() +from torch.nn.parallel import DistributedDataParallel as DDP + from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util @@ -41,6 +40,12 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) class NetworkTrainer: @@ -136,6 +141,7 @@ def train(self, args): training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -153,18 +159,18 @@ def train(self, args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -175,7 +181,7 @@ def train(self, args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -204,7 +210,7 @@ def train(self, args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -217,7 +223,7 @@ def train(self, args): self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する - print("preparing accelerator") + logger.info("preparing accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -266,9 +272,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -310,7 +314,7 @@ def train(self, args): if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( + logger.warning( "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" ) args.scale_weight_norms = False @@ -938,12 +942,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) @@ -951,7 +956,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) parser.add_argument( "--save_model_as", type=str, @@ -963,10 +970,17 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") parser.add_argument( - "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", ) parser.add_argument( "--network_alpha", @@ -981,14 +995,25 @@ def setup_parser() -> argparse.ArgumentParser: help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 
訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", ) parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument( - "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" ) parser.add_argument( - "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", ) parser.add_argument( "--dim_from_weights", diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 441c1e00b..a78a37b2c 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,15 +1,13 @@ import argparse -import gc import math import os from multiprocessing import Value import toml from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -32,6 +30,12 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -168,6 +172,7 @@ def train(self, args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -178,7 +183,7 @@ def train(self, args): tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -288,7 +293,7 @@ def train(self, args): ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -363,9 +368,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -736,12 +739,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -757,7 +761,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -767,7 +773,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true", diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7046a4808..3f9155978 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -1,16 +1,14 @@ import importlib import argparse -import gc import math import os import toml from multiprocessing import Value from tqdm import tqdm -import torch - -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed @@ -36,6 +34,12 @@ ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -94,12 +98,13 @@ def train(args): if args.output_name is None: args.output_name = args.token_string use_template = args.use_object_template or args.use_style_template + setup_logging(args, reset=True) train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: - print( + logger.warning( "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" ) assert ( @@ -114,7 +119,7 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -127,7 +132,7 @@ def train(args): if args.init_word is not None: init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( + logger.warning( f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 
初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" ) else: @@ -141,7 +146,7 @@ def train(args): ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") + logger.info(f"tokens are added: {token_ids}") assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" @@ -169,7 +174,7 @@ def train(args): tokenizer.add_tokens(token_strings_XTI) token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) - print(f"tokens are added (XTI): {token_ids_XTI}") + logger.info(f"tokens are added (XTI): {token_ids_XTI}") # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -178,7 +183,7 @@ def train(args): if init_token_ids is not None: for i, token_id in enumerate(token_ids_XTI): token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) # load weights if args.weights is not None: @@ -186,22 +191,22 @@ def train(args): assert len(token_ids) == len( embeddings ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) + # logger.info(token_ids, embeddings.size()) for token_id, embedding in zip(token_ids_XTI, embeddings): token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + logger.info(f"weights loaded") - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -209,14 +214,14 @@ def train(args): else: use_dreambooth_method = args.in_json is None if use_dreambooth_method: - print("Use DreamBooth method.") + logger.info("Use DreamBooth method.") user_config = { "datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -240,7 +245,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print(f"use template for training captions. is object: {args.use_object_template}") + logger.info(f"use template for training captions. 
is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] @@ -264,7 +269,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, show_input_ids=True) return if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return if cache_latents: @@ -286,9 +291,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -297,7 +300,7 @@ def train(args): text_encoder.gradient_checkpointing_enable() # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + logger.info("prepare optimizer, data loader etc.") trainable_params = text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -318,7 +321,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + logger.info( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -332,7 +337,7 @@ def train(args): ) index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) + # logger.info(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -370,15 +375,17 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + logger.info("running training / 学習開始") + logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + logger.info(f" num epochs / epoch数: {num_train_epochs}") + logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") + logger.info( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + logger.info(f" 
gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -393,17 +400,20 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) # function for saving/removing def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, embs, save_dtype) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) @@ -411,12 +421,13 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + logger.info(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + logger.info("") + logger.info(f"epoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 text_encoder.train() @@ -586,7 +597,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def save_weights(file, updated_embs, save_dtype): @@ -647,6 +658,7 @@ def load_weights(file): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -662,7 +674,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -672,7 +686,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true",
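Throughout the training scripts touched in this change, the manual `gc.collect()` / `torch.cuda.empty_cache()` pairs are replaced by a single `clean_memory_on_device(accelerator.device)` call imported from `library/device_utils.py`. That module is not part of this excerpt, so the sketch below is only a hypothetical approximation of such a helper, inferred from the calls it replaces; the actual implementation may differ:

```python
import gc

import torch


def clean_memory_on_device(device: torch.device) -> None:
    # Hypothetical sketch: collect Python-side garbage first, then ask the
    # accelerator backend behind `device` to release its cached allocator memory.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "xpu":
        # assumes an XPU-enabled PyTorch / intel_extension_for_pytorch build
        torch.xpu.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()
```

In the touched scripts the helper is called right after the VAE is moved back to the CPU once latents caching finishes, e.g. `vae.to("cpu")` followed by `clean_memory_on_device(accelerator.device)`.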