Skip to content

Commit

Permalink
update for resolution shift
Browse files — browse the repository at this point in the history
  • Loading branch information
sdbds committed Nov 7, 2024
1 parent 9bf4b34 commit 142ed5d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
21 changes: 17 additions & 4 deletions library/sd3_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)
parser.add_argument(
"--resolution_shift",
action="store_true",
help="use flux resolution shift for training timestep distribution adjustment / 訓練タイムステップ分布調整のためにflux解像度シフトを使用する",
)


def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
Expand Down Expand Up @@ -992,7 +997,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):


def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
bsz, _, h, w = latents.shape

# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
Expand All @@ -1005,10 +1010,18 @@ def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) ->
)
t_min = args.min_timestep if args.min_timestep is not None else 0
t_max = args.max_timestep if args.max_timestep is not None else 1000
shift = args.training_shift

# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)
if args.resolution_shift:
mu = flux_train_utils.get_lin_function(
y1=0.5,
y2=1.15,
)((h // 2) * (w // 2))
u = flux_train_utils.time_shift(mu, 1.0, u)
else:
shift = args.training_shift

# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)

indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)
Expand Down
9 changes: 7 additions & 2 deletions sd3_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,13 @@ def get_noise_pred_and_target(
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input

# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

# flow matching loss
target = latents
Expand Down

0 comments on commit 142ed5d

Please sign in to comment.