Skip to content

Commit

Permalink
Add diffusion prepare
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jul 19, 2024
1 parent 3065756 commit 5a8cd1e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
21 changes: 11 additions & 10 deletions benchmarks/diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ class Arguments:
lr_warmup_steps: int = 500
epochs: int = 10

def step():
    """Placeholder hook that performs no work and returns ``None``."""


def models(accelerator, args: Arguments):
encoder = CLIPTextModel.from_pretrained(
Expand Down Expand Up @@ -135,6 +132,7 @@ def collate_fn(examples):
collate_fn=collate_fn,
batch_size=args.batch_size,
num_workers=args.num_workers,
persistent_workers=True,
)

def train(args: Arguments):
Expand Down Expand Up @@ -214,16 +212,19 @@ def batch_size(x):
lr_scheduler.step()
optimizer.zero_grad()




def main():
    """Parse the benchmark arguments and run the diffusion training once.

    Command-line options are generated from the ``Arguments`` dataclass;
    options the parser does not recognise are ignored because
    ``parse_known_args`` is used.  A ``StopProgram`` exception raised during
    training is caught so the process exits normally.
    """
    # Kept as function-scope imports, matching the original block.
    from argklass import ArgumentParser
    from benchmate.metrics import StopProgram

    # Build the configuration exactly once, before entering the guarded
    # region, so only the training call sits inside the ``try``.
    parser = ArgumentParser()
    parser.add_arguments(Arguments)
    config, _ = parser.parse_known_args()

    try:
        train(config)
    except StopProgram:
        # NOTE(review): presumably benchmate raises StopProgram to end the
        # benchmark early once enough measurements exist -- confirm.
        pass


if __name__ == "__main__":
Expand Down
28 changes: 25 additions & 3 deletions benchmarks/diffusion/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,44 @@
from dataclasses import dataclass
import os

from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from datasets import load_dataset


@dataclass
class TrainingConfig:
    """Command-line options for the diffusion ``prepare`` script.

    ``main`` passes ``model``, ``revision`` and ``variant`` to the
    ``from_pretrained`` calls and ``dataset`` to ``load_dataset``; every
    attribute read there must therefore be declared here.
    """

    # Earlier dataset option; kept so the field remains available.
    dataset_name: str = "huggan/smithsonian_butterflies_subset"
    # Hub identifier of the pretrained pipeline whose sub-models are fetched.
    model: str = "runwayml/stable-diffusion-v1-5"
    # Hub identifier of the training dataset.
    dataset: str = "lambdalabs/naruto-blip-captions"
    # Forwarded as ``revision=`` / ``variant=``; None leaves the choice to the
    # library default.
    # NOTE(review): annotated as plain ``str`` on the assumption that argklass
    # derives the CLI argument type from the annotation -- confirm.
    revision: str = None
    variant: str = None


def main():
    """Load the dataset and every pretrained component used by the benchmark.

    Options come from ``TrainingConfig`` via argklass; unknown command-line
    options are ignored (``parse_known_args``).  All results are discarded --
    NOTE(review): presumably the calls are made only to populate the local
    Hugging Face cache ahead of the benchmark run; confirm.
    """
    from argklass import ArgumentParser

    parser = ArgumentParser()
    parser.add_arguments(TrainingConfig)
    args, _ = parser.parse_known_args()

    # Read the optional selectors defensively: fall back to None when the
    # parsed namespace has no such attribute instead of raising
    # AttributeError.
    revision = getattr(args, "revision", None)
    variant = getattr(args, "variant", None)

    _ = load_dataset(args.dataset)

    _ = CLIPTextModel.from_pretrained(
        args.model, subfolder="text_encoder", revision=revision, variant=variant
    )

    _ = AutoencoderKL.from_pretrained(
        args.model, subfolder="vae", revision=revision, variant=variant
    )

    _ = UNet2DConditionModel.from_pretrained(
        args.model, subfolder="unet", revision=revision, variant=variant
    )

    # The tokenizer call takes no ``variant`` in the original code either.
    _ = CLIPTokenizer.from_pretrained(
        args.model, subfolder="tokenizer", revision=revision
    )

    _ = DDPMScheduler.from_pretrained(args.model, subfolder="scheduler")


if __name__ == "__main__":
Expand Down

0 comments on commit 5a8cd1e

Please sign in to comment.