update for eval()
sdbds committed Apr 10, 2024
1 parent 6116b62 commit 073ee01
Showing 9 changed files with 56 additions and 9 deletions.
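
Every file touched by this commit applies the same pattern for schedule-free optimizers (for example the facebookresearch "schedulefree" package, whose optimizers expose optimizer.train() and optimizer.eval()): the optimizer is switched into train mode at the start of each training step and back into eval mode once the step has finished, so that any evaluation, sampling, or checkpoint saving that follows sees the weights the schedule-free method intends for inference. A minimal sketch of that usage, assuming the schedulefree package and a toy model (neither is part of this diff):

    import torch
    from schedulefree import AdamWScheduleFree

    model = torch.nn.Linear(4, 1)
    optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3)

    for step in range(100):
        optimizer.train()   # switch optimizer state to training mode before the step
        loss = model(torch.randn(8, 4)).square().mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        optimizer.eval()    # switch back before any eval, sampling, or saving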
7 changes: 5 additions & 2 deletions fine_tune.py
@@ -335,10 +335,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

    for m in training_models:
        m.train()
    if (args.optimizer_type.lower().endswith("schedulefree")):
        optimizer.train()

    for step, batch in enumerate(train_dataloader):
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.train()
        current_step.value = global_step
        with accelerator.accumulate(*training_models):
            with torch.no_grad():
@@ -409,6 +409,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.eval()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
7 changes: 5 additions & 2 deletions sdxl_train.py
@@ -509,10 +509,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

    for m in training_models:
        m.train()
    if (args.optimizer_type.lower().endswith("schedulefree")):
        optimizer.train()

    for step, batch in enumerate(train_dataloader):
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.train()
        current_step.value = global_step
        with accelerator.accumulate(*training_models):
            if "latents" in batch and batch["latents"] is not None:
@@ -640,6 +640,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.eval()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
10 changes: 8 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -292,13 +292,14 @@ def train(args):
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    if args.gradient_checkpointing:
        unet.train()  # according to TI example in Diffusers, train is required -> this was switched to the original U-Net, so it can actually be removed
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.train()
        unet.train()  # according to TI example in Diffusers, train is required -> this was switched to the original U-Net, so it can actually be removed

    else:
        unet.eval()
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.eval()
        unet.eval()

    # when caching the TextEncoder outputs, move them to the CPU
    if args.cache_text_encoder_outputs:
@@ -397,6 +398,8 @@ def remove_model(old_ckpt_name):
        current_epoch.value = epoch + 1

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            with accelerator.accumulate(unet):
                with torch.no_grad():
@@ -492,6 +495,9 @@ def remove_model(old_ckpt_name):
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
5 changes: 5 additions & 0 deletions sdxl_train_control_net_lllite_old.py
@@ -366,6 +366,8 @@ def remove_model(old_ckpt_name):
        network.on_epoch_start()  # train()

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            with accelerator.accumulate(network):
                with torch.no_grad():
@@ -462,6 +464,9 @@ def remove_model(old_ckpt_name):
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
5 changes: 5 additions & 0 deletions train_controlnet.py
@@ -398,6 +398,8 @@ def remove_model(old_ckpt_name):
        current_epoch.value = epoch + 1

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            with accelerator.accumulate(controlnet):
                with torch.no_grad():
@@ -477,6 +479,9 @@ def remove_model(old_ckpt_name):
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
7 changes: 5 additions & 2 deletions train_db.py
@@ -315,13 +315,13 @@ def train(args):

        # train the Text Encoder up to the specified number of steps: the state at the start of each epoch
        unet.train()
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.train()
        # train==True is required to enable gradient_checkpointing
        if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
            text_encoder.train()

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            # stop training the Text Encoder at the specified number of steps
            if global_step == args.stop_text_encoder_training:
@@ -403,6 +403,9 @@ def train(args):
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
10 changes: 9 additions & 1 deletion train_network.py
@@ -455,9 +455,10 @@ def train(self, args):

    if args.gradient_checkpointing:
        # according to TI example in Diffusers, train is required
        unet.train()
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.train()
        unet.train()

        for t_enc in text_encoders:
            t_enc.train()

@@ -466,6 +467,8 @@ def train(self, args):
                t_enc.text_model.embeddings.requires_grad_(True)

    else:
        if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.eval()
        unet.eval()
        for t_enc in text_encoders:
            t_enc.eval()
@@ -814,6 +817,8 @@ def remove_model(old_ckpt_name):
        accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            with accelerator.accumulate(training_model):
                on_step_start(text_encoder, unet)
@@ -931,6 +936,9 @@ def remove_model(old_ckpt_name):
                else:
                    keys_scaled, mean_norm, maximum_norm = None, None, None

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
9 changes: 9 additions & 0 deletions train_textual_inversion.py
@@ -462,8 +462,12 @@ def train(self, args):
        unet.to(accelerator.device, dtype=weight_dtype)
        if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
            # TODO the U-Net was replaced with the original implementation, so this should no longer be needed; verify later and remove
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            unet.train()
        else:
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()
            unet.eval()

        if not cache_latents:  # when not caching latents, the VAE is used, so prepare it
@@ -567,6 +571,8 @@ def remove_model(old_ckpt_name):
        loss_total = 0

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            with accelerator.accumulate(text_encoders[0]):
                with torch.no_grad():
@@ -637,6 +643,9 @@ def remove_model(old_ckpt_name):
                        index_no_updates
                    ]

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
5 changes: 5 additions & 0 deletions train_textual_inversion_XTI.py
@@ -447,6 +447,8 @@ def remove_model(old_ckpt_name):
        loss_total = 0

        for step, batch in enumerate(train_dataloader):
            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.train()
            current_step.value = global_step
            with accelerator.accumulate(text_encoder):
                with torch.no_grad():
@@ -515,6 +517,9 @@ def remove_model(old_ckpt_name):
                        index_no_updates
                    ]

            if (args.optimizer_type.lower().endswith("schedulefree")):
                optimizer.eval()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
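
A note on the guard that appears throughout: it is a suffix match on the lowercased optimizer name, so any --optimizer_type value ending in "ScheduleFree" takes the new branch, while every other optimizer skips the extra calls. The names below are illustrative examples, not a list taken from this repository:

    # illustration of the guard condition used in this commit; optimizer names are assumed examples
    for name in ("AdamW", "Lion", "AdamWScheduleFree", "SGDScheduleFree"):
        print(name, name.lower().endswith("schedulefree"))
    # AdamW False, Lion False, AdamWScheduleFree True, SGDScheduleFree True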
