From be5860f8e266c5562f123fe9e0cb3febef615290 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:08:21 +0900 Subject: [PATCH] add schnell option to load_cn --- flux_train_control_net.py | 4 ++-- library/flux_utils.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index a17c811e3..bb27c35ed 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -259,14 +259,14 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: diff --git a/library/flux_utils.py b/library/flux_utils.py index f2759c375..8be1d63ee 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,14 +1,14 @@ -from dataclasses import replace import json import os +from dataclasses import replace from typing import List, Optional, Tuple, Union + import einops import torch - -from safetensors.torch import load_file -from safetensors import safe_open from accelerate import init_empty_weights -from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel from library.utils import setup_logging @@ -154,11 +154,9 @@ def load_ae( def load_controlnet( - ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ): logger.info("Building ControlNet") - # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL with torch.device(device): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)