Skip to content

Commit

Permalink
add schnell option to load_cn
Browse files Browse the repository at this point in the history
  • Loading branch information
minux302 committed Nov 29, 2024
1 parent 575f583 commit be5860f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions flux_train_control_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,14 @@ def train(args):
clean_memory_on_device(accelerator.device)

# load FLUX
_, flux = flux_utils.load_flow_model(
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
flux.requires_grad_(False)
flux.to(accelerator.device)

# load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors)
controlnet.train()

if args.gradient_checkpointing:
Expand Down
14 changes: 6 additions & 8 deletions library/flux_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from dataclasses import replace
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union

import einops
import torch

from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel

from library.utils import setup_logging

Expand Down Expand Up @@ -154,11 +154,9 @@ def load_ae(


def load_controlnet(
ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
logger.info("Building ControlNet")
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
Expand Down

0 comments on commit be5860f

Please sign in to comment.