Skip to content

Commit

Permalink
Merge branch 'kohya-ss:dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds authored Apr 7, 2024
2 parents 6c7c9f2 + d30ebb2 commit 1fd8b27
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 52 deletions.
114 changes: 93 additions & 21 deletions README.md

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
Expand All @@ -368,7 +368,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
Expand All @@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
Expand Down Expand Up @@ -520,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
138 changes: 134 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,7 +1898,7 @@ def __init__(
subset.image_dir,
False,
None,
subset.caption_extension,
subset.caption_extension,
subset.cache_info,
subset.num_repeats,
subset.shuffle_caption,
Expand Down Expand Up @@ -3277,6 +3277,27 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
)
parser.add_argument(
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
)
parser.add_argument(
"--huber_schedule",
type=str,
default="exponential",
choices=["constant", "exponential", "snr"],
help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is exponential"
+ " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトはexponential",
)
parser.add_argument(
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)

parser.add_argument(
"--lowram",
Expand Down Expand Up @@ -3399,6 +3420,60 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
)


# verify command line args for training
def verify_command_line_training_args(args: argparse.Namespace):
# if wandb is enabled, the command line is exposed to the public
# check whether sensitive options are included in the command line arguments
# if so, warn or inform the user to move them to the configuration file
# wandbが有効な場合、コマンドラインが公開される
# 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、
# 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する

wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb"
if not wandb_enabled:
return

sensitive_args = ["wandb_api_key", "huggingface_token"]
sensitive_path_args = [
"pretrained_model_name_or_path",
"vae",
"tokenizer_cache_dir",
"train_data_dir",
"conditioning_data_dir",
"reg_data_dir",
"output_dir",
"logging_dir",
]

for arg in sensitive_args:
if getattr(args, arg, None) is not None:
logger.warning(
f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
)

# if path is absolute, it may include sensitive information
for arg in sensitive_path_args:
if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)):
logger.info(
f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path."
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。"
)

if getattr(args, "config_file", None) is not None:
logger.info(
f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path."
+ f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。"
)

# other sensitive options
if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public":
logger.info(
f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
+ f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
)


def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect highvram option to global variable
Expand Down Expand Up @@ -4882,6 +4957,38 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)


def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):

# TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way

if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timestep = timesteps.item()

if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")

timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long()

return timesteps, huber_c


def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
Expand All @@ -4901,8 +5008,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
timesteps = timesteps.long()
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand All @@ -4915,7 +5021,31 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

return noise, noisy_latents, timesteps
return noise, noisy_latents, timesteps, huber_c


# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):

if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
return loss


def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
Expand Down
7 changes: 4 additions & 3 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

Expand All @@ -600,7 +600,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand All @@ -616,7 +616,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
Expand Down Expand Up @@ -812,6 +812,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
5 changes: 3 additions & 2 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

Expand All @@ -458,7 +458,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down Expand Up @@ -612,6 +612,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
5 changes: 3 additions & 2 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

Expand All @@ -426,7 +426,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down Expand Up @@ -580,6 +580,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
1 change: 1 addition & 0 deletions sdxl_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

trainer = SdxlNetworkTrainer()
Expand Down
1 change: 1 addition & 0 deletions sdxl_train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

trainer = SdxlTextualInversionTrainer()
Expand Down
12 changes: 4 additions & 8 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,8 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
Expand Down Expand Up @@ -457,7 +452,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down Expand Up @@ -617,6 +612,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
5 changes: 3 additions & 2 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def train(args):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
Expand All @@ -358,7 +358,7 @@ def train(args):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand Down Expand Up @@ -523,6 +523,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
Loading

0 comments on commit 1fd8b27

Please sign in to comment.