[Flax SD finetune] Fix dtype (huggingface#1038)
fix jnp dtype
duongna21 authored Oct 28, 2022
1 parent fb38bb1 commit 1e07b6b
Showing 1 changed file with 3 additions and 3 deletions:
examples/text_to_image/train_text_to_image_flax.py
@@ -371,11 +371,11 @@ def collate_fn(examples):
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
     )

-    weight_dtype = torch.float32
+    weight_dtype = jnp.float32
     if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
+        weight_dtype = jnp.float16
     elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+        weight_dtype = jnp.bfloat16

     # Load models and create wrapper for stable diffusion
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
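The fix swaps the `torch.*` dtypes for their `jnp.*` equivalents, since this Flax training script passes `weight_dtype` into JAX/Flax code, where a PyTorch dtype object would not be accepted. A minimal sketch of the corrected selection logic (the helper name `resolve_weight_dtype` is hypothetical; the script assigns `weight_dtype` inline rather than through a function):

```python
import jax.numpy as jnp

def resolve_weight_dtype(mixed_precision: str):
    # Map a --mixed_precision flag value to the matching JAX dtype.
    # (Illustrative helper, not part of the actual script.)
    if mixed_precision == "fp16":
        return jnp.float16
    if mixed_precision == "bf16":
        return jnp.bfloat16
    # Default: full float32 precision.
    return jnp.float32
```

The same three-way mapping appears inline in the diff above; the key point of the commit is only that the values come from `jnp`, not `torch`.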
