From 9c4d7d56482f08618a337ab0733824ac5704b6c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Sun, 7 Apr 2024 17:15:49 +0800
Subject: [PATCH] init

---
 fine_tune.py                         |  5 ++++-
 library/train_util.py                | 15 +++++++++++++++
 sdxl_train.py                        |  5 ++++-
 sdxl_train_control_net_lllite.py     |  7 ++++++-
 sdxl_train_control_net_lllite_old.py |  7 ++++++-
 train_db.py                          |  5 ++++-
 train_network.py                     |  5 ++++-
 train_textual_inversion_XTI.py       |  7 ++++++-
 8 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index a0350ce18..aab7596e3 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -322,6 +322,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
         for m in training_models:
             m.train()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.train()
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
@@ -390,7 +392,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
diff --git a/library/train_util.py b/library/train_util.py
index 1a46f6a7d..035870134 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4012,6 +4012,21 @@ def get_optimizer(args, trainable_params):
         logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type.endswith("ScheduleFree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     if optimizer is None:
         # 任意のoptimizerを使う
diff --git a/sdxl_train.py b/sdxl_train.py
index 816598e04..3ab0513d0 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -501,6 +501,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
         for m in training_models:
             m.train()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.train()
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
@@ -626,7 +628,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 9eaaa19f2..2c0c9818e 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -290,8 +290,12 @@ def train(args):
 
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.train()
     else:
         unet.eval()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.eval()
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -481,7 +485,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index e55a58896..a383948b5 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -261,8 +261,12 @@ def train(args):
 
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.train()
     else:
         unet.eval()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.eval()
 
     network.prepare_grad_etc()
 
@@ -449,7 +453,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
diff --git a/train_db.py b/train_db.py
index 0a152f224..15e1a63c5 100644
--- a/train_db.py
+++ b/train_db.py
@@ -302,6 +302,8 @@ def train(args):
 
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態
         unet.train()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.train()
         # train==True is required to enable gradient_checkpointing
         if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
             text_encoder.train()
@@ -384,7 +386,8 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
diff --git a/train_network.py b/train_network.py
index 8fe98f126..a6b67128f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -446,6 +446,8 @@ def train(self, args):
         if args.gradient_checkpointing:
            # according to TI example in Diffusers, train is required
            unet.train()
+           if args.optimizer_type.lower().endswith("schedulefree"):
+               optimizer.train()
            for t_enc in text_encoders:
                t_enc.train()
 
@@ -900,7 +902,8 @@ def remove_model(old_ckpt_name):
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                    optimizer.step()
-                   lr_scheduler.step()
+                   if not args.optimizer_type.lower().endswith("schedulefree"):
+                       lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
 
                if args.scale_weight_norms:
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 861d48d1d..c6921c4e4 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -354,8 +354,12 @@ def train(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
         unet.train()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.train()
     else:
         unet.eval()
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer.eval()
 
     if not cache_latents:
         vae.requires_grad_(False)
@@ -496,7 +500,8 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 # Let's make sure we don't update any embedding weights besides the newly added token
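Note on usage (a minimal sketch, not part of the patch): the loop below shows how the schedule-free optimizers wired in above are driven, assuming the `schedulefree` package is installed (`pip install schedulefree`); the model and data are placeholders, not code from this repository. The two points mirrored by this patch are that no external LR scheduler is stepped, and that the optimizer itself is switched between train() and eval() modes.

    # Minimal sketch, assuming `pip install schedulefree`; model/data are placeholders.
    import torch
    import schedulefree

    model = torch.nn.Linear(10, 1)
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

    for epoch in range(2):
        model.train()
        optimizer.train()  # schedule-free optimizers must be put in train mode alongside the model
        for _ in range(100):
            x = torch.randn(8, 10)
            loss = model(x).pow(2).mean()
            loss.backward()
            optimizer.step()  # no lr_scheduler.step(): the schedule is internal to the optimizer
            optimizer.zero_grad(set_to_none=True)

        model.eval()
        optimizer.eval()  # switch to the averaged weights before evaluation or checkpointing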