Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Jul 1, 2024
1 parent 076b971 commit 50022bd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
2 changes: 1 addition & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,7 +3176,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensorrt", "ipex", "tvm"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
Expand Down
38 changes: 33 additions & 5 deletions sdxl_train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from diffusers.utils.torch_utils import is_compiled_module
from safetensors.torch import load_file
from library import (
deepspeed_utils,
Expand Down Expand Up @@ -143,6 +144,11 @@ def train(args):
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process

def unwrap_model(model):
    """Return the bare underlying model.

    First removes any accelerate wrapping (DDP, etc.), then unwraps the
    torch.compile artifact by reading `_orig_mod` when the result is a
    compiled module.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped

# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
Expand Down Expand Up @@ -610,14 +616,25 @@ def remove_model(old_ckpt_name):
progress_bar.update(1)
global_step += 1

# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
controlnet=controlnet,
)

# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name,accelerator.unwrap_model(controlnet))
save_model(ckpt_name,unwrap_model(controlnet))

if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
Expand Down Expand Up @@ -651,7 +668,7 @@ def remove_model(old_ckpt_name):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name,accelerator.unwrap_model(controlnet))
save_model(ckpt_name,unwrap_model(controlnet))

remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
Expand All @@ -661,12 +678,23 @@ def remove_model(old_ckpt_name):
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
sdxl_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
controlnet=controlnet,
)

# end of epoch

if is_main_process:
controlnet = accelerator.unwrap_model(controlnet)
controlnet = unwrap_model(controlnet)

accelerator.end_training()

Expand Down

0 comments on commit 50022bd

Please sign in to comment.