
Commit

Merge branch 'sdxl-ctrl-net' of https://github.com/kohya-ss/sd-scripts into sdxl_controlnet

# Conflicts:
#	library/train_util.py
sdbds committed Oct 10, 2024
2 parents 1776955 + 886f753 commit 65b08d1
Showing 60 changed files with 18,364 additions and 969 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/typos.yml
@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4

- name: typos-action
-        uses: crate-ci/typos@v1.21.0
+        uses: crate-ci/typos@v1.24.3
640 changes: 637 additions & 3 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/config_README-en.md
@@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d

 * `batch_size`
   * This corresponds to the command-line argument `--train_batch_size`.
+* `max_bucket_reso`, `min_bucket_reso`
+  * Specify the maximum and minimum bucket resolutions. They must be divisible by `bucket_reso_steps`.

These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
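The divisibility requirement is easy to check up front. A minimal sketch, assuming hypothetical names (this helper is not part of sd-scripts):

# Illustrative check for the documented constraint; names are hypothetical.
def validate_bucket_resos(min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int = 64) -> None:
    for name, reso in (("min_bucket_reso", min_bucket_reso), ("max_bucket_reso", max_bucket_reso)):
        if reso % bucket_reso_steps != 0:
            raise ValueError(f"{name}={reso} is not divisible by bucket_reso_steps={bucket_reso_steps}")

validate_bucket_resos(256, 1024)    # fine: both divisible by 64
# validate_bucket_resos(250, 1024)  # raises ValueError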

2 changes: 2 additions & 0 deletions docs/config_README-ja.md
@@ -118,6 +118,8 @@ Settings usable with both the DreamBooth method and the fine tuning method

 * `batch_size`
   * Equivalent to the command-line argument `--train_batch_size`.
+* `max_bucket_reso`, `min_bucket_reso`
+  * Specify the maximum and minimum bucket resolutions. They must be divisible by `bucket_reso_steps`.

These settings are fixed per dataset. In other words, subsets belonging to the same dataset share these settings.
2 changes: 1 addition & 1 deletion docs/train_README-ja.md
@@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
 Please look into the details yourself.
-When using an arbitrary scheduler, specify its option arguments with `--scheduler_args`, just as with an arbitrary optimizer.
+When using an arbitrary scheduler, specify its option arguments with `--lr_scheduler_args`, just as with an arbitrary optimizer.
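For example, when `--lr_scheduler_type` selects torch's `CosineAnnealingLR`, the call might look like `--lr_scheduler_args T_max=100 eta_min=1e-7` (illustrative values; the accepted keys are whatever the chosen scheduler class accepts).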
### About specifying the optimizer
2 changes: 1 addition & 1 deletion docs/train_README-zh.md
@@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
 For details, please research on your own.
-To use an arbitrary scheduler, specify its optional arguments with `--scheduler_args`, just as with an arbitrary optimizer.
+To use an arbitrary scheduler, specify its optional arguments with `--lr_scheduler_args`, just as with an arbitrary optimizer.
### About specifying the optimizer
Use the `--optimizer_args` option to specify optimizer option arguments. Multiple values can be given in key=value format, separated by commas. For example, to set parameters of the AdamW optimizer: `--optimizer_args weight_decay=0.01 betas=.9,.999`.
63 changes: 46 additions & 17 deletions fine_tune.py
@@ -10,7 +10,7 @@
from tqdm import tqdm

import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
@@ -39,6 +39,7 @@
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
+import library.strategy_sd as strategy_sd


def train(args):
@@ -52,7 +53,15 @@ def train(args):
     if args.seed is not None:
         set_seed(args.seed)  # initialize the random seed

-    tokenizer = train_util.load_tokenizer(args)
+    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+    # Prepare the caching strategy. This must be set before preparing the dataset,
+    # because the dataset may use it during initialization.
+    if cache_latents:
+        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+            False, args.cache_latents_to_disk, args.vae_batch_size, False
+        )
+        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
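The `set_strategy` calls above suggest a class-level registry that later code can query from anywhere. A minimal sketch of that pattern, with method names modeled on the calls in this diff; the actual implementation lives in `library/strategy_base.py` and may differ:

# Minimal sketch of a class-level strategy registry; illustrative only.
from typing import Optional


class TokenizeStrategy:
    _strategy: Optional["TokenizeStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "TokenizeStrategy") -> None:
        # assumption: registering twice is treated as a programming error
        if cls._strategy is not None:
            raise RuntimeError(f"{cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        return cls._strategy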

    # prepare the dataset
if args.dataset_class is None:
@@ -81,16 +90,18 @@ def train(args):
]
}

-        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
-        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+        train_dataset_group = train_util.load_arbitrary_dataset(args)

current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

+    train_dataset_group.verify_bucket_reso_steps(64)

if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -165,8 +176,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
-        with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)

+        train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)

vae.to("cpu")
clean_memory_on_device(accelerator.device)

@@ -192,6 +204,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
text_encoder.eval()

+    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -214,7 +229,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

-    # dataloaderを準備する
+    # prepare the dataloader
+    # Strategies are set here because they cannot be referenced from another process;
+    # they are copied together with the dataset. Some strategies may be None.
+    train_dataset_group.set_current_strategies()

    # Note: persistent_workers cannot be used when the number of DataLoader processes is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -317,7 +336,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

# For --sample_at_first
-    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+    train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+    )
+    if len(accelerator.trackers) > 0:
+        # log an empty object to commit the sample images to wandb
+        accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -342,19 +366,22 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
+                        # TODO move to strategy_sd.py
encoder_hidden_states = get_weighted_text_embeddings(
-                            tokenizer,
+                            tokenize_strategy.tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
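The list-shaped arguments reflect the multi-encoder generalization: SDXL carries two text encoders, so the strategy interface takes one token batch per encoder and returns one hidden-state tensor per encoder; SD passes single-element lists, hence the `[0]`. A self-contained sketch of that shape, with a dummy strategy standing in for `SdTextEncodingStrategy`:

# Sketch of the list-per-encoder interface assumed above; illustrative only.
from typing import List

import torch


class DummyTextEncodingStrategy:
    def encode_tokens(self, tokenize_strategy, models: List, tokens: List[torch.Tensor]) -> List[torch.Tensor]:
        # One output per (encoder, token batch) pair; SDXL would pass two of each.
        return [model(t) for model, t in zip(models, tokens)]


encoder = torch.nn.Embedding(49408, 768)   # stand-in for a CLIP text encoder
tokens = torch.randint(0, 49408, (1, 77))  # one tokenized caption
strategy = DummyTextEncodingStrategy()
hidden_states = strategy.encode_tokens(None, [encoder], [tokens])[0]
print(hidden_states.shape)                  # torch.Size([1, 77, 768])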

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -409,7 +436,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
global_step += 1

train_util.sample_images(
-                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

                # save the model every specified number of steps
@@ -434,7 +461,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

            current_loss = loss.detach().item()  # this is an average, so batch size should not matter
-            if args.logging_dir is not None:
+            if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -447,7 +474,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if global_step >= args.max_train_steps:
break

-        if args.logging_dir is not None:
+        if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

@@ -472,7 +499,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
)

-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(
+            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
+        )

is_main_process = accelerator.is_main_process
if is_main_process:
8 changes: 5 additions & 3 deletions finetune/tag_images_by_wd14_tagger.py
@@ -11,7 +11,7 @@
from tqdm import tqdm

import library.train_util as train_util
-from library.utils import setup_logging
+from library.utils import setup_logging, pil_resize

setup_logging()
import logging
@@ -42,8 +42,10 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

-    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
-    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
+    if size > IMAGE_SIZE:
+        image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_AREA)
+    else:
+        image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))

image = image.astype(np.float32)
return image
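The new branch keeps `cv2.INTER_AREA` for shrinking but hands enlargement to a PIL-based resize. A self-contained sketch of that split follows; this `pil_resize` is a stand-in with assumed behavior (BGR array to PIL, Lanczos resample, back to BGR) and may differ from the one in `library/utils.py`:

# Stand-in illustration of the downscale/upscale split; assumptions noted above.
import cv2
import numpy as np
from PIL import Image


def pil_resize(image: np.ndarray, size: tuple) -> np.ndarray:
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    resized = pil_image.resize(size, Image.LANCZOS)
    return cv2.cvtColor(np.array(resized), cv2.COLOR_RGB2BGR)


def resize_square(image: np.ndarray, target: int) -> np.ndarray:
    # assumes a square BGR uint8 image, as produced by the padding above
    if image.shape[0] > target:
        # downscaling: area averaging resists aliasing artifacts
        return cv2.resize(image, (target, target), interpolation=cv2.INTER_AREA)
    # upscaling: PIL's Lanczos filter
    return pil_resize(image, (target, target))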

