Commit

Permalink
support weighted captions for sdxl LoRA and fine tuning
kohya-ss committed Oct 9, 2024
1 parent 126159f commit 886f753
Showing 5 changed files with 45 additions and 35 deletions.
library/strategy_base.py (5 changes: 4 additions & 1 deletion)
@@ -74,6 +74,9 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError

def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
"""
raise NotImplementedError

def _get_weighted_input_ids(
@@ -303,7 +306,7 @@ def encode_tokens(
:return: list of output embeddings for each architecture
"""
raise NotImplementedError

def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
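For orientation, tokenize_with_weights is the abstract hook each tokenize strategy now documents: it returns one token tensor and one aligned per-token weight tensor per text encoder. A rough, self-contained sketch of what a concrete implementation could look like, assuming the familiar "(word:1.2)" emphasis syntax and a Hugging Face style tokenizer; the actual parser in sd-scripts handles more cases (nesting, de-emphasis, chunking into 75-token blocks):

import re
from typing import List, Tuple

import torch


def tokenize_with_weights_sketch(tokenizer, text: str, max_length: int = 77) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Split the caption into (span, weight) pairs; plain text gets weight 1.0,
    # "(span:1.3)" gets weight 1.3. This is a toy parser, not the sd-scripts one.
    pairs, pos = [], 0
    for m in re.finditer(r"\(([^():]+):([0-9.]+)\)", text):
        if m.start() > pos:
            pairs.append((text[pos:m.start()], 1.0))
        pairs.append((m.group(1), float(m.group(2))))
        pos = m.end()
    if pos < len(text):
        pairs.append((text[pos:], 1.0))

    token_ids: List[int] = []
    weights: List[float] = []
    for span, w in pairs:
        ids = tokenizer.encode(span, add_special_tokens=False)
        token_ids.extend(ids)
        weights.extend([w] * len(ids))

    # Truncate, then wrap with BOS/EOS the same way the plain tokenize() path would.
    token_ids = token_ids[: max_length - 2]
    weights = weights[: max_length - 2]
    ids_tensor = torch.tensor([tokenizer.bos_token_id] + token_ids + [tokenizer.eos_token_id])
    weight_tensor = torch.tensor([1.0] + weights + [1.0])
    # One (tokens, weights) pair per tokenizer; the SDXL strategy returns two of each.
    return [ids_tensor], [weight_tensor]

Calling it as tokenize_with_weights_sketch(CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"), "a (red:1.3) ball") would give the tokens of "red" a weight of 1.3 and everything else 1.0.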
library/strategy_sdxl.py (3 changes: 2 additions & 1 deletion)
@@ -174,7 +174,8 @@ def encode_tokens(
"""
Args:
tokenize_strategy: TokenizeStrategy
- models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
+ models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
+     If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
tokens: List of tokens, for text_encoder1 and text_encoder2
"""
if len(models) == 2:
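The docstring change above documents the optional third entry in models. A call-site sketch of how the weighted path passes it, mirroring the sdxl_train.py change further down; the strategy method names come from this commit, while the wrapper function itself is only illustrative:

def encode_weighted_sdxl_batch(accelerator, tokenize_strategy, text_encoding_strategy,
                               text_encoder1, text_encoder2, captions):
    # tokenize_with_weights returns parallel token/weight lists for both encoders.
    tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
    # text_encoder2 may be wrapped by accelerate (DDP/DeepSpeed), so the unwrapped
    # copy is passed as the optional third model, as the docstring now requires.
    return text_encoding_strategy.encode_tokens_with_weights(
        tokenize_strategy,
        [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
        tokens_list,
        weights_list,
    )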
sdxl_train.py (38 changes: 20 additions & 18 deletions)
@@ -104,8 +104,8 @@ def train(args):
setup_logging(args, reset=True)

assert (
-     not args.weighted_captions
- ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
+     not args.weighted_captions or not args.cache_text_encoder_outputs
+ ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -660,22 +660,24 @@ def optimizer_hook(parameter: torch.Tensor):
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
- # TODO support weighted captions
- # if args.weighted_captions:
- # encoder_hidden_states = get_weighted_text_embeddings(
- # tokenizer,
- # text_encoder,
- # batch["captions"],
- # accelerator.device,
- # args.max_token_length // 75 if args.max_token_length else 1,
- # clip_skip=args.clip_skip,
- # )
- # else:
- input_ids1 = input_ids1.to(accelerator.device)
- input_ids2 = input_ids2.to(accelerator.device)
- encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
-     tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
- )
+ if args.weighted_captions:
+     input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
+     encoder_hidden_states1, encoder_hidden_states2, pool2 = (
+         text_encoding_strategy.encode_tokens_with_weights(
+             tokenize_strategy,
+             [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
+             input_ids_list,
+             weights_list,
+         )
+     )
+ else:
+     input_ids1 = input_ids1.to(accelerator.device)
+     input_ids2 = input_ids2.to(accelerator.device)
+     encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
+         tokenize_strategy,
+         [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
+         [input_ids1, input_ids2],
+     )
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
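For background on what the weights do: prompt-weighting implementations such as the get_weighted_text_embeddings path referenced in the removed comment typically scale each token's hidden state by its weight and then restore the original overall magnitude, so only the relative emphasis between tokens changes. A minimal sketch of that idea; the exact math inside encode_tokens_with_weights may differ:

import torch


def apply_token_weights(hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, num_tokens, dim), weights: (batch, num_tokens)
    previous_mean = hidden_states.float().mean(dim=[-2, -1], keepdim=True)
    weighted = hidden_states * weights.unsqueeze(-1)
    current_mean = weighted.float().mean(dim=[-2, -1], keepdim=True)
    # Rescale so the overall magnitude is unchanged and only emphasis shifts.
    return (weighted * (previous_mean / current_mean)).to(hidden_states.dtype)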
sdxl_train_control_net.py (7 changes: 2 additions & 5 deletions)
@@ -12,24 +12,21 @@

init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from accelerate import init_empty_weights
- from diffusers import DDPMScheduler, ControlNetModel
+ from diffusers import DDPMScheduler
from diffusers.utils.torch_utils import is_compiled_module
from safetensors.torch import load_file
from library import (
deepspeed_utils,
sai_model_spec,
sdxl_model_util,
sdxl_original_unet,
sdxl_train_util,
strategy_base,
strategy_sd,
strategy_sdxl,
)

import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
@@ -264,7 +261,7 @@ def unwrap_model(model):
trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
trainable_params.append({"params": unet_params, "lr": args.learning_rate})
all_params = ctrlnet_params + unet_params

logger.info(f"trainable params count: {len(all_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")

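The logging change above appears to replace a count of parameter tensors with a count of individual trainable weights. A quick illustration of the difference between the two f-strings, using a single linear layer as a stand-in for the real ControlNet/U-Net parameters:

import torch.nn as nn

layer = nn.Linear(320, 640)                 # stand-in for the real trainable modules
all_params = list(layer.parameters())       # weight + bias -> 2 tensors
print(f"trainable params count: {len(all_params)}")                              # 2
print(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")   # 320*640 + 640 = 205440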
train_network.py (27 changes: 17 additions & 10 deletions)
@@ -1123,14 +1123,21 @@ def remove_model(old_ckpt_name):
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
- # SD only
- encoded_text_encoder_conds = get_weighted_text_embeddings(
-     tokenizers[0],
-     text_encoder,
-     batch["captions"],
-     accelerator.device,
-     args.max_token_length // 75 if args.max_token_length else 1,
-     clip_skip=args.clip_skip,
+ # # SD only
+ # encoded_text_encoder_conds = get_weighted_text_embeddings(
+ # tokenizers[0],
+ # text_encoder,
+ # batch["captions"],
+ # accelerator.device,
+ # args.max_token_length // 75 if args.max_token_length else 1,
+ # clip_skip=args.clip_skip,
+ # )
+ input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
+ encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
+     tokenize_strategy,
+     self.get_models_for_text_encoding(args, accelerator, text_encoders),
+     input_ids_list,
+     weights_list,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
@@ -1139,8 +1146,8 @@ def remove_model(old_ckpt_name):
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
-     if args.full_fp16:
-         encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
+ if args.full_fp16:
+     encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]

# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
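The net effect of the train_network.py changes: weighted captions now go through the same strategy objects as the plain path, for whatever encoders get_models_for_text_encoding supplies, and the full_fp16 cast appears to move out of the else branch so that both paths are covered. A condensed sketch with simplified names, not a drop-in replacement for the actual training-step code:

def encode_for_step(args, batch, tokenize_strategy, text_encoding_strategy, models, device, weight_dtype):
    # Condensed view of the restructured block in train_network.py.
    if args.weighted_captions:
        tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
        conds = text_encoding_strategy.encode_tokens_with_weights(
            tokenize_strategy, models, tokens_list, weights_list
        )
    else:
        input_ids = [ids.to(device) for ids in batch["input_ids_list"]]
        conds = text_encoding_strategy.encode_tokens(tokenize_strategy, models, input_ids)
    # Cast applied after the branch, so weighted captions are cast to the training dtype too.
    if args.full_fp16:
        conds = [c.to(weight_dtype) for c in conds]
    return conds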
