diff --git a/fine_tune.py b/fine_tune.py
index 17d091408..38b1962f6 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -335,10 +335,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
         for m in training_models:
             m.train()
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.train()
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
@@ -409,6 +409,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/sdxl_train.py b/sdxl_train.py
index 2590d36c1..09ca438f9 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -509,10 +509,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
         for m in training_models:
             m.train()
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.train()
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -640,6 +640,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 6cbb4741b..056f197f4 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -292,13 +292,14 @@ def train(args):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
     if args.gradient_checkpointing:
-        unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
         if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.train()
+        unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+
     else:
-        unet.eval()
         if (args.optimizer_type.lower().endswith("schedulefree")):
            optimizer.eval()
+        unet.eval()
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -397,6 +398,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(unet):
                 with torch.no_grad():
@@ -492,6 +495,9 @@ def remove_model(old_ckpt_name):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index bb48fcb14..30240dd38 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -366,6 +366,8 @@ def remove_model(old_ckpt_name):
         network.on_epoch_start()  # train()
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
@@ -462,6 +464,9 @@ def remove_model(old_ckpt_name):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_controlnet.py b/train_controlnet.py
index 6b71799dc..849e467b7 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -398,6 +398,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():
@@ -477,6 +479,9 @@ def remove_model(old_ckpt_name):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_db.py b/train_db.py
index ad55d6ce0..1f3ffd2cb 100644
--- a/train_db.py
+++ b/train_db.py
@@ -315,13 +315,13 @@ def train(args):
 
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態
         unet.train()
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.train()
         # train==True is required to enable gradient_checkpointing
         if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
@@ -403,6 +403,9 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
diff --git a/train_network.py b/train_network.py
index e7db8168c..fe491d880 100644
--- a/train_network.py
+++ b/train_network.py
@@ -455,9 +455,10 @@ def train(self, args):
 
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
-            unet.train()
             if (args.optimizer_type.lower().endswith("schedulefree")):
                 optimizer.train()
+            unet.train()
+
             for t_enc in text_encoders:
                 t_enc.train()
 
@@ -466,6 +467,8 @@ def train(self, args):
 
                 t_enc.text_model.embeddings.requires_grad_(True)
         else:
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
             unet.eval()
             for t_enc in text_encoders:
                 t_enc.eval()
@@ -814,6 +817,8 @@ def remove_model(old_ckpt_name):
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.train()
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
@@ -931,6 +936,9 @@ def remove_model(old_ckpt_name):
                     else:
                         keys_scaled, mean_norm, maximum_norm = None, None, None
 
+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 4adbc642f..bd90a4e5b 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -462,8 +462,12 @@ def train(self, args):
         unet.to(accelerator.device, dtype=weight_dtype)
         if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
             # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             unet.train()
         else:
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
             unet.eval()
 
         if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
@@ -567,6 +571,8 @@ def remove_model(old_ckpt_name):
             loss_total = 0
 
             for step, batch in enumerate(train_dataloader):
+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.train()
                 current_step.value = global_step
                 with accelerator.accumulate(text_encoders[0]):
                     with torch.no_grad():
@@ -637,6 +643,9 @@ def remove_model(old_ckpt_name):
                             index_no_updates
                         ]
 
+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 701fd1467..712af33ec 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -447,6 +447,8 @@ def remove_model(old_ckpt_name):
         loss_total = 0
 
         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
@@ -515,6 +517,9 @@ def remove_model(old_ckpt_name):
                         index_no_updates
                     ]
 
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
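
Every hunk above applies the same pattern: when the optimizer type ends with "schedulefree", the optimizer is switched into train mode right before the training step and back into eval mode once the step is done, so that any subsequent read of the weights (logging, sampling, checkpointing) sees the evaluation parameters. Below is a minimal, self-contained sketch of that train/eval toggling, assuming the facebookresearch "schedulefree" package; the linear model and toy batches are placeholders, not part of the scripts patched above.

import torch
import schedulefree  # assumption: the "schedulefree" package is installed

model = torch.nn.Linear(16, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

data = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(10)]  # toy batches

for x, y in data:
    optimizer.train()  # schedule-free optimizers must be in train mode around step()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()   # back to eval mode before validation, sampling, or saving

The per-step toggling mirrors what the args.optimizer_type.lower().endswith("schedulefree") guard enables in each training script.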