diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py index 9f3dbcb33..266cceeeb 100644 --- a/sdxl_train_controlnet.py +++ b/sdxl_train_controlnet.py @@ -272,12 +272,14 @@ def train(args): ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") controlnet.to(weight_dtype) + unet.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") controlnet.to(weight_dtype) + unet.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(