From 50022bd42c0ec73be57505a942a3eb1be607bccf Mon Sep 17 00:00:00 2001
From: sdbds <865105819@qq.com>
Date: Mon, 1 Jul 2024 23:10:27 +0800
Subject: [PATCH] update

---
 library/train_util.py    |  2 +-
 sdxl_train_controlnet.py | 38 +++++++++++++++++++++++++++++++++-----
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 760be33eb..f9d483fb5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3176,7 +3176,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         # available backends:
         # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
         # https://pytorch.org/docs/stable/torch.compiler.html
-        choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
+        choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensorrt", "ipex", "tvm"],
         help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
     )
     parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py
index 266cceeeb..bab152f88 100644
--- a/sdxl_train_controlnet.py
+++ b/sdxl_train_controlnet.py
@@ -15,6 +15,7 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler, ControlNetModel
+from diffusers.utils.torch_utils import is_compiled_module
 from safetensors.torch import load_file
 from library import (
     deepspeed_utils,
@@ -143,6 +144,11 @@ def train(args):
     accelerator = train_util.prepare_accelerator(args)
     is_main_process = accelerator.is_main_process
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
     vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
@@ -610,14 +616,25 @@ def remove_model(old_ckpt_name):
                 progress_bar.update(1)
                 global_step += 1
 
-                # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                sdxl_train_util.sample_images(
+                    accelerator,
+                    args,
+                    None,
+                    global_step,
+                    accelerator.device,
+                    vae,
+                    [tokenizer1, tokenizer2],
+                    [text_encoder1, text_encoder2],
+                    unet,
+                    controlnet=controlnet,
+                )
 
                 # 指定ステップごとにモデルを保存
                 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
-                        save_model(ckpt_name,accelerator.unwrap_model(controlnet))
+                        save_model(ckpt_name, unwrap_model(controlnet))
 
                         if args.save_state:
                             train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -651,7 +668,7 @@ def remove_model(old_ckpt_name):
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
             if is_main_process and saving:
                 ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                save_model(ckpt_name,accelerator.unwrap_model(controlnet))
+                save_model(ckpt_name, unwrap_model(controlnet))
 
                 remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                 if remove_epoch_no is not None:
@@ -661,12 +678,23 @@ def remove_model(old_ckpt_name):
                 if args.save_state:
                     train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
 
-        # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        sdxl_train_util.sample_images(
+            accelerator,
+            args,
+            epoch + 1,
+            global_step,
+            accelerator.device,
+            vae,
+            [tokenizer1, tokenizer2],
+            [text_encoder1, text_encoder2],
+            unet,
+            controlnet=controlnet,
+        )
 
     # end of epoch
 
     if is_main_process:
-        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet = unwrap_model(controlnet)
 
     accelerator.end_training()
 