From 8fd96420103e7b2611eace8640b8477211c0e27e Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 12 Oct 2024 22:54:16 +0800 Subject: [PATCH] update --- sdxl_train_controlnet.py | 779 --------------------------------------- 1 file changed, 779 deletions(-) delete mode 100644 sdxl_train_controlnet.py diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py deleted file mode 100644 index 46fff3078..000000000 --- a/sdxl_train_controlnet.py +++ /dev/null @@ -1,779 +0,0 @@ -import argparse -import math -import os -import random -from multiprocessing import Value -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device - -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from accelerate.utils import set_seed -from diffusers import DDPMScheduler, ControlNetModel -from diffusers.utils.torch_utils import is_compiled_module -from safetensors.torch import load_file -from library import ( - deepspeed_utils, - sai_model_spec, - sdxl_model_util, - sdxl_original_unet, - sdxl_train_util, -) - -import library.model_util as model_util -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - add_v_prediction_like_loss, - apply_snr_weight, - prepare_scheduler_for_custom_training, - scale_v_prediction_loss_like_noise_prediction, - apply_debiased_estimation, - gradfilter_ema, - gradfilter_ma, -) -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] - * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) - - return logs - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - sdxl_train_util.verify_sdxl_training_args(args) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - train_dataset_group.verify_bucket_reso_steps(32) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - else: - logger.warning( - "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" - ) - - if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_train_util.load_target_model( - args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype - ) - - # convert U-Net - with torch.no_grad(): - du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) - unet.to("cpu") - clean_memory_on_device(accelerator.device) - del unet - unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) - unet.load_state_dict(du_unet_sd) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # TextEncoderの出力をキャッシュする - if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() - - # モデルに xformers とか memory efficient attention を組み込む - # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if args.xformers: - unet.enable_xformers_memory_efficient_attention() - controlnet.enable_xformers_memory_efficient_attention() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - controlnet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) - logger.info(f"trainable params count: {len(trainable_params)}") - logger.info( - f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" - ) - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args, optimizer, accelerator.num_processes - ) - - # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) - elif args.full_bf16: - assert ( - args.mixed_precision == "bf16" - ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" - accelerator.print("enable full bf16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) - - if args.fused_backward_pass: - # use fused optimizer for backward pass: other optimizers will be supported in the future - import library.adafactor_fused - - library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: - if parameter.requires_grad: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - parameter.register_post_accumulate_grad_hook(__grad_hook) - - unet.requires_grad_(False) - text_encoder1.requires_grad_(False) - text_encoder2.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() - - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( - noise_scheduler - ) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - ( - "controlnet_train" - if args.log_tracker_name is None - else args.log_tracker_name - ), - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = ( - sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" - ) - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file, sai_metadata) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - sdxl_train_util.sample_images( - accelerator, - args, - 0, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) - - if args.gradfilter_ema_alpha or args.gradfilter_ma_window_size: - grads = None - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - - if args.optimizer_accumulation_steps > 0: - optimizer.optimizer.optimizer_accumulation = (step + 1) % args.optimizer_accumulation_steps != 0 - - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = ( - batch["latents"] - .to(accelerator.device) - .to(dtype=weight_dtype) - ) - else: - # latentに変換 - latents = ( - vae.encode(batch["images"].to(dtype=vae_dtype)) - .latent_dist.sample() - .to(dtype=weight_dtype) - ) - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print( - "NaN found in latents, replacing with zeros" - ) - latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - - if ( - "text_encoder_outputs1_list" not in batch - or batch["text_encoder_outputs1_list"] is None - ): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.no_grad(): - # Get the text embedding for conditioning - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = ( - train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - ) - ) - else: - encoder_hidden_states1 = ( - batch["text_encoder_outputs1_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - encoder_hidden_states2 = ( - batch["text_encoder_outputs2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - pool2 = ( - batch["text_encoder_pool2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - - # get size embeddings - orig_size = batch["original_sizes_hw"] - crop_size = batch["crop_top_lefts"] - target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings( - # orig_size, crop_size, target_size, accelerator.device - # ).to(weight_dtype) - - embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 - # concat embeddings - #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - vector_embedding_dict = { - "text_embeds": pool2, - "time_ids": embs - } - text_embedding = torch.cat( - [encoder_hidden_states1, encoder_hidden_states2], dim=2 - ).to(weight_dtype) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = ( - train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - ) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - down_block_additional_residuals=[ - sample.to(dtype=weight_dtype) for sample in down_block_res_samples - ], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - return_dict=False, - )[0] - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = train_util.conditional_loss( - noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, - ) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - - if args.gradfilter_ema_alpha: - grads = gradfilter_ema( - m=controlnet, - grads=grads, - alpha=args.gradfilter_ema_alpha, - lamb=args.gradfilter_ema_lamb, - ) - elif args.gradfilter_ma_window_size: - grads = gradfilter_ma( - m=controlnet, - grads=grads, - window_size=args.gradfilter_ma_window_size, - lamb=args.gradfilter_ma_lamb, - filter_type=args.gradfilter_ma_filter_type, - warmup=False if args.gradfilter_ma_warmup_false else True, - ) - - if not args.fused_backward_pass: - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - else: - # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook - lr_scheduler.step() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - sdxl_train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name,unwrap_model(controlnet)) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name,unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - sdxl_train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) - - # end of epoch - - if is_main_process: - controlnet = unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and (args.save_state or args.save_state_on_train_end): - train_util.save_state_on_train_end(args, accelerator) - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model( - ckpt_name, controlnet, force_sync_upload=True - ) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_masked_loss_arguments(parser) - deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) - - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - - return parser - - -if __name__ == "__main__": - # sdxl_original_unet.USE_REENTRANT = False - - parser = setup_parser() - - args = parser.parse_args() - train_util.verify_command_line_training_args(args) - args = train_util.read_config_from_file(args, parser) - - train(args)