Skip to content

Commit

Permalink
update for RAdam
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Dec 4, 2024
1 parent 8a82b8c commit 50bf685
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
12 changes: 9 additions & 3 deletions flux_train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
init_ipex()

from accelerate.utils import set_seed
from diffusers.utils.torch_utils import is_compiled_module

import library.train_util as train_util
from library import (
Expand Down Expand Up @@ -173,6 +174,11 @@ def train(args):
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)

def unwrap_model(model):
    """Return the bare underlying model.

    Peels off Accelerate's distributed wrapper first, then — if the result
    is a torch.compile()-produced OptimizedModule — returns the original
    eager module stored in ``_orig_mod`` so state_dict keys stay clean.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped

# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)

Expand Down Expand Up @@ -748,7 +754,7 @@ def grad_hook(parameter: torch.Tensor):
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(controlnet),
unwrap_model(controlnet),
)
optimizer_train_fn()

Expand Down Expand Up @@ -784,7 +790,7 @@ def grad_hook(parameter: torch.Tensor):
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(controlnet),
unwrap_model(controlnet),
)

flux_train_utils.sample_images(
Expand All @@ -794,7 +800,7 @@ def grad_hook(parameter: torch.Tensor):

is_main_process = accelerator.is_main_process
# if is_main_process:
controlnet = accelerator.unwrap_model(controlnet)
controlnet = unwrap_model(controlnet)

accelerator.end_training()
optimizer_eval_fn()
Expand Down
3 changes: 3 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5092,6 +5092,9 @@ def get_optimizer(args, trainable_params, model=None) -> tuple[str, str, object]
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "RAdamScheduleFree".lower():
optimizer_class = sf.RAdamScheduleFree
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")
else:
optimizer_class = None

Expand Down

0 comments on commit 50bf685

Please sign in to comment.