Commit 1bb33b7

Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3

sdbds committed Oct 10, 2024
2 parents e50a9f8 + 83e3048
Showing 7 changed files with 200 additions and 111 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:

### Recent Updates

+Oct 6, 2024:
+- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is now automatically detected as dev or schnell, so schnell models are loaded correctly (see the detection sketch below). Note that LoRA training and fine-tuning with schnell models are unverified.
+- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Specify the full path to the parent directory of `transformer` or to `diffusion_pytorch_model-00001-of-00003.safetensors`. Diffusers-format CLIP/T5XXL is not supported, and saving is supported only in BFL format.
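As a rough illustration (a hedged sketch, not necessarily the exact logic of `flux_utils.check_flux_state_dict_diffusers_schnell`): dev checkpoints carry guidance-distillation weights that schnell lacks, and Diffusers-format checkpoints use different key prefixes than BFL single-file checkpoints. The key names below are assumptions.

```python
# Hedged sketch: infer BFL/Diffusers layout and dev/schnell from state-dict keys.
from safetensors import safe_open


def inspect_flux_checkpoint(path: str) -> tuple[bool, bool]:
    with safe_open(path, framework="pt") as f:
        keys = list(f.keys())
    # Diffusers FLUX transformers use "transformer_blocks." key prefixes;
    # BFL single-file checkpoints use "double_blocks." / "single_blocks.".
    is_diffusers = any(k.startswith("transformer_blocks.") for k in keys)
    # schnell is not guidance-distilled, so its guidance-embedder keys are absent.
    guidance_key = "guidance_embedder" if is_diffusers else "guidance_in."
    is_schnell = not any(guidance_key in k for k in keys)
    return is_diffusers, is_schnell
```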

Sep 26, 2024:
The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.
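For reference, a hypothetical fine-tuning invocation using the new option might look like the following; the paths and the swap count of 18 are illustrative placeholders, and a real run needs the usual dataset, optimizer, and precision arguments.

```bash
accelerate launch flux_train.py \
  --pretrained_model_name_or_path /path/to/flux1-dev.safetensors \
  --blocks_to_swap 18 \
  --output_dir /path/to/output
```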

15 changes: 6 additions & 9 deletions flux_minimal_inference.py
@@ -419,9 +419,6 @@ def encode(prpt: str):
     steps = args.steps
     guidance_scale = args.guidance
 
-    name = "schnell" if "schnell" in args.ckpt_path else "dev"  # TODO change this to a more robust way
-    is_schnell = name == "schnell"
-
     def is_fp8(dt):
         return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]

@@ -455,12 +452,8 @@ def is_fp8(dt):
     # if is_fp8(t5xxl_dtype):
    #     t5xxl = accelerator.prepare(t5xxl)
 
-    t5xxl_max_length = 256 if is_schnell else 512
-    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
-    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
-
     # DiT
-    model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
+    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
     model.eval()
     logger.info(f"Casting model to {flux_dtype}")
     model.to(flux_dtype)  # make sure model is dtype
@@ -469,8 +462,12 @@ def is_fp8(dt):
     # if args.offload:
    #     model = model.to("cpu")
 
+    t5xxl_max_length = 256 if is_schnell else 512
+    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
+    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
+
     # AE
-    ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
+    ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
     ae.eval()
     # if is_fp8(ae_dtype):
    #     ae = accelerator.prepare(ae)
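Taken together, the inference-side change is an API migration: `load_flow_model` and `load_ae` no longer take a `name` argument, and `load_flow_model` returns an `(is_schnell, model)` tuple inferred from the checkpoint contents. A hedged caller sketch (paths are placeholders; import paths are assumed from this repository's layout):

```python
import torch

from library import flux_utils, strategy_flux  # import paths assumed

ckpt_path = "/path/to/flux1-schnell.safetensors"  # placeholder
ae_path = "/path/to/ae.safetensors"  # placeholder

# The model variant is now inferred from the checkpoint, not from its filename.
is_schnell, model = flux_utils.load_flow_model(ckpt_path, None, "cpu")
ae = flux_utils.load_ae(ae_path, torch.float32, "cpu")

# schnell uses a shorter T5-XXL context (256 tokens) than dev (512).
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(256 if is_schnell else 512)
```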
17 changes: 9 additions & 8 deletions flux_train.py
@@ -137,16 +137,16 @@ def train(args):

     train_dataset_group.verify_bucket_reso_steps(16)  # TODO confirm this is correct

+    _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
     if args.debug_dataset:
         if args.cache_text_encoder_outputs:
             strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                 strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                     args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
                 )
             )
-        name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
         t5xxl_max_token_length = (
-            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512)
+            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
         )
         strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))

@@ -177,12 +177,11 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
 
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
@@ -196,7 +195,7 @@ def train(args):

     # prepare tokenize strategy
     if args.t5xxl_max_token_length is None:
-        if name == "schnell":
+        if is_schnell:
             t5xxl_max_token_length = 256
         else:
             t5xxl_max_token_length = 512
@@ -258,8 +257,8 @@ def train(args):
     clean_memory_on_device(accelerator.device)
 
     # load FLUX
-    flux = flux_utils.load_flow_model(
-        name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    _, flux = flux_utils.load_flow_model(
+        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
     )
 
     if args.gradient_checkpointing:
@@ -294,7 +293,7 @@ def train(args):

     if not cache_latents:
         # load VAE here if not cached
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
         ae.requires_grad_(False)
         ae.eval()
         ae.to(accelerator.device, dtype=weight_dtype)
@@ -706,7 +705,9 @@ def optimizer_hook(parameter: torch.Tensor):
     accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
 
     # For --sample_at_first
+    optimizer_eval_fn()
     flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
+    optimizer_train_fn()
     if len(accelerator.trackers) > 0:
         # log empty object to commit the sample images to wandb
         accelerator.log({}, step=0)
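The new `optimizer_eval_fn()` / `optimizer_train_fn()` calls bracket the initial sampling. A plausible reading (an assumption, not confirmed by this diff alone) is that they switch optimizers such as schedule-free variants between their evaluation and training parameter sets, acting as a no-op for ordinary optimizers. A minimal sketch of the bracketing pattern:

```python
# Hedged sketch of the eval/train bracketing around sampling; the callables
# here are stand-ins for whatever the trainer provides.
from typing import Callable


def sample_with_eval_weights(
    optimizer_eval_fn: Callable[[], None],
    optimizer_train_fn: Callable[[], None],
    sample_fn: Callable[[], None],
) -> None:
    optimizer_eval_fn()  # switch to evaluation weights before sampling
    try:
        sample_fn()
    finally:
        optimizer_train_fn()  # restore training weights afterwards
```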
17 changes: 7 additions & 10 deletions flux_train_network.py
@@ -2,7 +2,7 @@
 import copy
 import math
 import random
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from accelerate import Accelerator
@@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
+        self.is_schnell: Optional[bool] = None
 
     def assert_extra_args(self, args, train_dataset_group):
         super().assert_extra_args(args, train_dataset_group)
@@ -57,19 +58,15 @@ def assert_extra_args(self, args, train_dataset_group):

         train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
 
-    def get_flux_model_name(self, args):
-        return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
-
     def load_target_model(self, args, weight_dtype, accelerator):
         # currently offload to cpu for some models
-        name = self.get_flux_model_name(args)
-
         # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
         loading_dtype = None if args.fp8_base else weight_dtype
 
         # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
-        model = flux_utils.load_flow_model(
-            name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
+        self.is_schnell, model = flux_utils.load_flow_model(
+            args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
         )
         if args.fp8_base:
             # check dtype of model
@@ -100,7 +97,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
         elif t5xxl.dtype == torch.float8_e4m3fn:
             logger.info("Loaded fp8 T5XXL model")
 
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
 
         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

@@ -142,10 +139,10 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
         return flux_lower
 
     def get_tokenize_strategy(self, args):
-        name = self.get_flux_model_name(args)
+        _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
 
         if args.t5xxl_max_token_length is None:
-            if name == "schnell":
+            if is_schnell:
                 t5xxl_max_token_length = 256
             else:
                 t5xxl_max_token_length = 512
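Per the call sites in this diff, `check_flux_state_dict_diffusers_schnell` returns a 3-tuple whose middle element is the schnell flag; judging by its name, the first element plausibly indicates Diffusers format, though that is an assumption. A minimal usage sketch (placeholder path, import path assumed):

```python
from library import flux_utils  # import path assumed

is_diffusers, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(
    "/path/to/flux1-dev.safetensors"  # placeholder
)
t5xxl_max_token_length = 256 if is_schnell else 512
```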
