Commit 806a623: minor fixes

kohya-ss committed Feb 18, 2024
1 parent 9b0e532 commit 806a623
Showing 2 changed files with 21 additions and 1 deletion.

README.md (1 addition, 1 deletion)
````diff
@@ -36,7 +36,7 @@ The code for training the Text Encoder is also written, but it is untested.
 ### Command line sample
 
 ```batch
-accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt
+accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight
 ```
````
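
The only README change is the new `--adaptive_loss_weight` flag at the end of the sample command. As background, adaptive loss weighting (as described for Stable Cascade) tracks the typical loss magnitude per noise level and scales each sample's loss by the inverse of that running estimate, so inherently noisy timesteps stop dominating the gradient. The sketch below illustrates the idea only; the class and parameter names are invented and are not this repository's implementation.

```python
import torch

class AdaptiveLossWeighter:
    """Illustrative only: per-timestep-bucket EMA of the loss, used as an inverse weight."""

    def __init__(self, buckets: int = 300, ema: float = 0.99, max_weight: float = 5.0):
        self.bucket_losses = torch.ones(buckets)  # running loss estimate per bucket
        self.ema = ema
        self.max_weight = max_weight

    def _bucketize(self, timesteps: torch.Tensor) -> torch.Tensor:
        # map continuous timesteps in [0, 1] to integer bucket indices
        n = self.bucket_losses.numel()
        return (timesteps.clamp(0.0, 1.0) * (n - 1)).long()

    def weight(self, timesteps: torch.Tensor) -> torch.Tensor:
        # large running loss -> small weight, clamped to a sane range
        idx = self._bucketize(timesteps)
        return (1.0 / self.bucket_losses[idx]).clamp(1e-2, self.max_weight)

    def update(self, timesteps: torch.Tensor, losses: torch.Tensor) -> None:
        # EMA update of each touched bucket with the observed per-sample loss
        idx = self._bucketize(timesteps)
        for i, l in zip(idx.tolist(), losses.detach().cpu().tolist()):
            self.bucket_losses[i] = self.ema * self.bucket_losses[i] + (1.0 - self.ema) * l
```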

### About the dataset for fine tuning

stable_cascade_gen_img.py (20 additions, 0 deletions)
```diff
@@ -59,6 +59,13 @@ def main(args):
     stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
     stage_a.eval().requires_grad_(False)
 
+    # previewer
+    if args.previewer_checkpoint_path is not None:
+        previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
+        previewer.eval().requires_grad_(False)
+    else:
+        previewer = None
+
     # mysterious class gdf
     gdf_c = sc.GDF(
         schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
```
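
Loading is conditional: the previewer is only pulled in when `--previewer_checkpoint_path` is supplied, so existing invocations keep working unchanged. A rough invocation with the new flag might look like the line below; the checkpoint paths are reused from the README sample above, other required model and prompt arguments are omitted, and the exact flag set is an assumption.

```batch
python stable_cascade_gen_img.py --bf16 --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --outdir ../output
```
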
```diff
@@ -221,6 +228,18 @@ def main(args):
         conditions_b["effnet"] = sampled_c
         unconditions_b["effnet"] = torch.zeros_like(sampled_c)
 
+        if previewer is not None:
+            with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
+                preview = previewer(sampled_c)
+                preview = preview.clamp(0, 1)
+                preview = preview.permute(0, 2, 3, 1).squeeze(0)
+                preview = preview.detach().float().cpu().numpy()
+                preview = Image.fromarray((preview * 255).astype(np.uint8))
+
+            timestamp_str = time.strftime("%Y%m%d_%H%M%S")
+            os.makedirs(args.outdir, exist_ok=True)
+            preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png"))
+
         if args.lowvram:
             generator_c = generator_c.to(loading_device)
             device_utils.clean_memory_on_device(device)
```
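
The added block turns the sampled Stage C latents into an RGB preview image and writes it with a timestamped filename. Since the tensor-to-PIL conversion is easy to get wrong, here is the same chain as a self-contained sketch, assuming a `(1, 3, H, W)` float tensor in `[0, 1]` (the helper name is invented):

```python
import numpy as np
import torch
from PIL import Image

def to_pil(preview: torch.Tensor) -> Image.Image:
    preview = preview.clamp(0, 1)                     # clip to the valid pixel range
    preview = preview.permute(0, 2, 3, 1).squeeze(0)  # NCHW -> HWC, drop the batch dim
    array = preview.detach().float().cpu().numpy()    # back to float32 on the host
    return Image.fromarray((array * 255).astype(np.uint8))  # scale to 8-bit RGB
```
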
```diff
@@ -274,6 +293,7 @@ def main(args):
     sc_utils.add_stage_a_arguments(parser)
     sc_utils.add_stage_b_arguments(parser)
     sc_utils.add_stage_c_arguments(parser)
+    sc_utils.add_previewer_arguments(parser)
     sc_utils.add_text_model_arguments(parser)
     parser.add_argument("--bf16", action="store_true")
     parser.add_argument("--fp16", action="store_true")
```
