diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index d6c0ad1ae..be0f9bdb0 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -292,6 +292,11 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.0, help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", ) + parser.add_argument( + "--resolution_shift", + action="store_true", + help="use flux resolution shift for training timestep distribution adjustment / 訓練タイムステップ分布調整のためにflux解像度シフトを使用する", + ) def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): @@ -992,7 +997,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz = latents.shape[0] + bsz, _, h, w = latents.shape # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -1005,10 +1010,18 @@ def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> ) t_min = args.min_timestep if args.min_timestep is not None else 0 t_max = args.max_timestep if args.max_timestep is not None else 1000 - shift = args.training_shift - # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) - u = (u * shift) / (1 + (shift - 1) * u) + if args.resolution_shift: + mu = flux_train_utils.get_lin_function( + y1=0.5, + y2=1.15, + )((h // 2) * (w // 2)) + u = flux_train_utils.time_shift(mu, 1.0, u) + else: + shift = args.training_shift + + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) indices = (u * (t_max - t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) diff --git a/sd3_train_network.py b/sd3_train_network.py index 3c9b70579..bb02c7ac7 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -328,8 +328,13 @@ def get_noise_pred_and_target( # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = latents