From 50bf6859faafb4dc360fd67f184f0b8893c52e7a Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 4 Dec 2024 22:29:28 +0800 Subject: [PATCH] update for Radm --- flux_train_control_net.py | 12 +++++++++--- library/train_util.py | 3 +++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 744c265b5..c9d38afbb 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -30,6 +30,7 @@ init_ipex() from accelerate.utils import set_seed +from diffusers.utils.torch_utils import is_compiled_module import library.train_util as train_util from library import ( @@ -173,6 +174,11 @@ def train(args): logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -748,7 +754,7 @@ def grad_hook(parameter: torch.Tensor): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(controlnet), + unwrap_model(controlnet), ) optimizer_train_fn() @@ -784,7 +790,7 @@ def grad_hook(parameter: torch.Tensor): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(controlnet), + unwrap_model(controlnet), ) flux_train_utils.sample_images( @@ -794,7 +800,7 @@ def grad_hook(parameter: torch.Tensor): is_main_process = accelerator.is_main_process # if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) + controlnet = unwrap_model(controlnet) accelerator.end_training() optimizer_eval_fn() diff --git a/library/train_util.py b/library/train_util.py index b78732ca1..f32fef06f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5092,6 +5092,9 @@ def get_optimizer(args, trainable_params, model=None) -> tuple[str, str, object] elif optimizer_type == "SGDScheduleFree".lower(): optimizer_class = sf.SGDScheduleFree logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") else: optimizer_class = None