Skip to content

Commit

Permalink
Refactor code to ensure args.guidance_scale is always a float kohya-s…
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Aug 29, 2024
1 parent 930d709 commit 8ecf0fc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flux_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,8 @@ def optimizer_hook(parameter: torch.Tensor):
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

# get guidance
guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
Expand Down

0 comments on commit 8ecf0fc

Please sign in to comment.