Skip to content

Commit

Permalink
Merge branch 'kohya-ss:dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds authored Sep 3, 2024
2 parents 1009a2c + d5c076c commit f392a30
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

### Work in progress

- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with the optimizer step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
Expand Down
11 changes: 9 additions & 2 deletions networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
weights_sd = torch.load(file, map_location="cpu")

# if keys are Diffusers based, convert to SAI based
convert_diffusers_to_sai_if_needed(weights_sd)
if is_sdxl:
convert_diffusers_to_sai_if_needed(weights_sd)

# get dim/alpha mapping
modules_dim = {}
Expand All @@ -840,7 +841,13 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
module_class = LoRAInfModule if for_inference else LoRAModule

network = LoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
text_encoder,
unet,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
is_sdxl=is_sdxl,
)

# block lr
Expand Down
8 changes: 6 additions & 2 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,11 @@ def optimizer_hook(parameter: torch.Tensor):
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)

target = noise
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise

if (
args.min_snr_gamma
Expand All @@ -720,7 +724,7 @@ def optimizer_hook(parameter: torch.Tensor):
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
Expand Down

0 comments on commit f392a30

Please sign in to comment.