
Commit

use no schedule
sdbds committed Apr 7, 2024
1 parent 788ee7a commit f5f3bb0
Showing 9 changed files with 113 additions and 41 deletions.
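
Across all nine trainers the change is the same: when the configured optimizer is a schedule-free one (for example AdamWScheduleFree from the schedulefree package), no learning-rate scheduler is built or stepped, so lr_scheduler is left out of the accelerator.prepare() call. Below is a minimal sketch of that pattern, not code from this commit; the schedulefree import, the placeholder model/dataloader, and the optimizer_type value are assumptions for illustration.

# Minimal sketch of the pattern this commit applies (assumed names, not repository code).
import schedulefree  # assumed: the schedule-free optimizer package
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)                                            # placeholder model
train_dataloader = DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)  # placeholder data

optimizer_type = "AdamWScheduleFree"  # in the scripts this comes from args.optimizer_type
if optimizer_type.lower().endswith("schedulefree"):
    # schedule-free optimizers adjust their effective step size internally, so no lr_scheduler is created
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-4)
    lr_scheduler = None
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

On the command line this corresponds to passing an --optimizer_type value whose name ends in "ScheduleFree"; the exact optimizer names available depend on how the optimizer loader registers them.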
27 changes: 20 additions & 7 deletions fine_tune.py
@@ -255,18 +255,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
     else:
         # apparently accelerator takes care of things for us
         if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scale in fp16
     if args.full_fp16:
16 changes: 12 additions & 4 deletions sdxl_train.py
@@ -415,9 +415,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2=text_encoder2 if train_text_encoder2 else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
@@ -428,7 +433,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+        else:
+            optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # move to CPU when caching TextEncoder outputs
     if args.cache_text_encoder_outputs:
5 changes: 4 additions & 1 deletion sdxl_train_control_net_lllite.py
@@ -286,7 +286,10 @@ def train(args):
         unet.to(weight_dtype)
 
     # apparently accelerator takes care of things for us
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> this was done for the original U-Net, so it can actually be removed
11 changes: 8 additions & 3 deletions sdxl_train_control_net_lllite_old.py
@@ -254,9 +254,14 @@ def train(args):
         network.to(weight_dtype)
 
     # apparently accelerator takes care of things for us
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, network, optimizer, train_dataloader = accelerator.prepare(
+            unet, network, optimizer, train_dataloader
+        )
+    else:
+        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, network, optimizer, train_dataloader, lr_scheduler
+        )
     network: control_net_lllite.ControlNetLLLite
 
     if args.gradient_checkpointing:
11 changes: 8 additions & 3 deletions train_controlnet.py
@@ -276,9 +276,14 @@ def train(args):
         controlnet.to(weight_dtype)
 
     # apparently accelerator takes care of things for us
-    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        controlnet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        controlnet, optimizer, train_dataloader = accelerator.prepare(
+            controlnet, optimizer, train_dataloader
+        )
+    else:
+        controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            controlnet, optimizer, train_dataloader, lr_scheduler
+        )
 
     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)
27 changes: 20 additions & 7 deletions train_db.py
@@ -229,19 +229,32 @@ def train(args):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
 
     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
             training_models = [unet, text_encoder]
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
             training_models = [unet]
 
     if not train_text_encoder:
24 changes: 17 additions & 7 deletions train_network.py
@@ -420,9 +420,14 @@ def train(self, args):
                 text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
                 network=network,
             )
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                ds_model, optimizer, train_dataloader = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader
+                )
+            else:
+                ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    ds_model, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = ds_model
         else:
             if train_unet:
@@ -437,10 +442,15 @@ def train(self, args):
                 text_encoders = [text_encoder]
             else:
                 pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
-
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
+
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                network, optimizer, train_dataloader = accelerator.prepare(
+                    network, optimizer, train_dataloader
+                )
+            else:
+                network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    network, optimizer, train_dataloader, lr_scheduler
+                )
             training_model = network
 
         if args.gradient_checkpointing:
22 changes: 16 additions & 6 deletions train_textual_inversion.py
@@ -416,14 +416,24 @@ def train(self, args):
 
         # apparently accelerator takes care of things for us
         if len(text_encoders) == 1:
-            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
+                    text_encoder_or_list, optimizer, train_dataloader
+                )
+            else:
+                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
+                )
 
         elif len(text_encoders) == 2:
-            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
+                    text_encoders[0], text_encoders[1], optimizer, train_dataloader
+                )
+            else:
+                text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
+                )
 
             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
 
11 changes: 8 additions & 3 deletions train_textual_inversion_XTI.py
@@ -335,9 +335,14 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # apparently accelerator takes care of things for us
-    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        text_encoder, optimizer, train_dataloader = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader
+        )
+    else:
+        text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # logger.info(len(index_no_updates), torch.sum(index_no_updates))
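Not shown in this diff is the matching change needed in the training loop: with a schedule-free optimizer there is no lr_scheduler to step, and the schedule-free optimizers additionally expect optimizer.train() during training and optimizer.eval() before evaluation or saving. A hedged sketch of how a loop can guard for this, continuing the sketch above and assuming the conditional prepare left lr_scheduler as None for schedule-free runs (model, train_dataloader, and the loss computation are placeholders, not code from this repository):

# Hedged sketch of the corresponding loop handling; assumed names, not repository code.
if hasattr(optimizer, "train"):
    optimizer.train()  # schedule-free optimizers keep separate train/eval parameter states
for step, batch in enumerate(train_dataloader):
    with accelerator.accumulate(model):
        loss = model(batch[0]).pow(2).mean()    # placeholder loss
        accelerator.backward(loss)
        optimizer.step()
        if lr_scheduler is not None:            # nothing to step for schedule-free optimizers
            lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)
if hasattr(optimizer, "eval"):
    optimizer.eval()  # switch to the evaluation weights before saving or sampling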
