
Commit

init
sdbds committed Apr 7, 2024
1 parent b748b48 commit 9c4d7d5
Showing 8 changed files with 49 additions and 7 deletions.
5 changes: 4 additions & 1 deletion fine_tune.py
@@ -322,6 +322,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for m in training_models:
    m.train()
if args.optimizer_type.lower().endswith("schedulefree"):
    optimizer.train()

for step, batch in enumerate(train_dataloader):
    current_step.value = global_step
@@ -390,7 +392,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# Checks if the accelerator has performed an optimization step behind the scenes
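The loop edits above (repeated in each script below) implement the schedule-free training pattern: the optimizer is put into train mode together with the model, and the external LR scheduler is skipped because a schedule-free optimizer carries its own averaging in place of a schedule. A minimal self-contained sketch of that pattern with a toy model, using the schedulefree package that library/train_util.py imports below (the model, data, and hyperparameters here are illustrative, not the scripts' own):

import torch
import schedulefree  # pip install schedulefree

model = torch.nn.Linear(16, 1)
# Schedule-free AdamW replaces AdamW + LR scheduler; no scheduler is created.
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for epoch in range(3):
    model.train()
    optimizer.train()  # required before taking optimizer steps
    for _ in range(100):
        x = torch.randn(8, 16)
        loss = (model(x) - x.sum(dim=1, keepdim=True)).pow(2).mean()
        loss.backward()
        optimizer.step()
        # no lr_scheduler.step() here
        optimizer.zero_grad(set_to_none=True)
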
15 changes: 15 additions & 0 deletions library/train_util.py
@@ -4012,6 +4012,21 @@ def get_optimizer(args, trainable_params):
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

if optimizer is None:
# 任意のoptimizerを使う
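The new branch forwards optimizer_kwargs straight to the schedulefree classes, so their hyperparameters are set the same way as for the other optimizers. A hedged example of direct construction for --optimizer_type=AdamWScheduleFree (the keyword names follow the schedulefree package as I understand it; the values are illustrative):

import torch
import schedulefree

params = [torch.nn.Parameter(torch.zeros(10))]

# Roughly what the branch above builds; warmup_steps stands in for scheduler
# warmup, since no external scheduler is used with a schedule-free optimizer.
optimizer = schedulefree.AdamWScheduleFree(
    params,
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    warmup_steps=100,
)

In these scripts such keyword arguments would come from the existing --optimizer_args option (for example --optimizer_args weight_decay=0.01 warmup_steps=100); the values above are only placeholders.
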
5 changes: 4 additions & 1 deletion sdxl_train.py
@@ -501,6 +501,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for m in training_models:
    m.train()
if args.optimizer_type.lower().endswith("schedulefree"):
    optimizer.train()

for step, batch in enumerate(train_dataloader):
    current_step.value = global_step
@@ -626,7 +628,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# Checks if the accelerator has performed an optimization step behind the scenes
7 changes: 6 additions & 1 deletion sdxl_train_control_net_lllite.py
@@ -290,8 +290,12 @@ def train(args):

if args.gradient_checkpointing:
    unet.train()  # according to TI example in Diffusers, train is required -> this was changed to the original U-Net, so this can actually be removed
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.train()
else:
    unet.eval()
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.eval()

# move to CPU when caching the TextEncoder outputs
if args.cache_text_encoder_outputs:
@@ -481,7 +485,8 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# Checks if the accelerator has performed an optimization step behind the scenes
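The optimizer.eval() call above is the other half of the schedule-free contract: eval mode swaps the parameters to the averaged weights, which are the ones that should be evaluated and checkpointed, and train mode swaps back before further steps. A hedged sketch of saving around that switch (the helper name and path argument are hypothetical, not part of these scripts):

import torch
import schedulefree

def save_schedule_free_checkpoint(model: torch.nn.Module,
                                  optimizer: schedulefree.AdamWScheduleFree,
                                  path: str) -> None:
    optimizer.eval()   # switch the model to the averaged (evaluation) weights
    torch.save(model.state_dict(), path)
    optimizer.train()  # switch back before taking more training steps
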
7 changes: 6 additions & 1 deletion sdxl_train_control_net_lllite_old.py
@@ -261,8 +261,12 @@ def train(args):

if args.gradient_checkpointing:
    unet.train()  # according to TI example in Diffusers, train is required -> this was changed to the original U-Net, so this can actually be removed
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.train()
else:
    unet.eval()
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.eval()

network.prepare_grad_etc()

@@ -449,7 +453,8 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# Checks if the accelerator has performed an optimization step behind the scenes
5 changes: 4 additions & 1 deletion train_db.py
@@ -302,6 +302,8 @@ def train(args):

# train the Text Encoder up to the specified number of steps: state at the start of the epoch
unet.train()
if args.optimizer_type.lower().endswith("schedulefree"):
    optimizer.train()
# train==True is required to enable gradient_checkpointing
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
    text_encoder.train()
@@ -384,7 +386,8 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# Checks if the accelerator has performed an optimization step behind the scenes
5 changes: 4 additions & 1 deletion train_network.py
@@ -446,6 +446,8 @@ def train(self, args):
if args.gradient_checkpointing:
    # according to TI example in Diffusers, train is required
    unet.train()
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.train()
    for t_enc in text_encoders:
        t_enc.train()

@@ -900,7 +902,8 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.scale_weight_norms:
7 changes: 6 additions & 1 deletion train_textual_inversion_XTI.py
@@ -354,8 +354,12 @@ def train(args):
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
    unet.train()
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.train()
else:
    unet.eval()
    if args.optimizer_type.lower().endswith("schedulefree"):
        optimizer.eval()

if not cache_latents:
    vae.requires_grad_(False)
@@ -496,7 +500,8 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
    lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# Let's make sure we don't update any embedding weights besides the newly added token