From 5a8cd1ecd68c3502a68e47ac90bb018667712cf7 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Fri, 19 Jul 2024 13:43:59 -0400 Subject: [PATCH] Add diffusion prepare --- benchmarks/diffusion/main.py | 21 +++++++++++---------- benchmarks/diffusion/prepare.py | 28 +++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/benchmarks/diffusion/main.py b/benchmarks/diffusion/main.py index 6a174b903..cff53451a 100644 --- a/benchmarks/diffusion/main.py +++ b/benchmarks/diffusion/main.py @@ -45,9 +45,6 @@ class Arguments: lr_warmup_steps: int = 500 epochs: int = 10 -def step(): - pass - def models(accelerator, args: Arguments): encoder = CLIPTextModel.from_pretrained( @@ -135,6 +132,7 @@ def collate_fn(examples): collate_fn=collate_fn, batch_size=args.batch_size, num_workers=args.num_workers, + persistent_workers=True, ) def train(args: Arguments): @@ -214,16 +212,19 @@ def batch_size(x): lr_scheduler.step() optimizer.zero_grad() - - def main(): - from argklass import ArgumentParser - parser = ArgumentParser() - parser.add_arguments(Arguments) - config, _ = parser.parse_known_args() + from benchmate.metrics import StopProgram + + try: + from argklass import ArgumentParser + parser = ArgumentParser() + parser.add_arguments(Arguments) + config, _ = parser.parse_known_args() - train(config) + train(config) + except StopProgram: + pass if __name__ == "__main__": diff --git a/benchmarks/diffusion/prepare.py b/benchmarks/diffusion/prepare.py index 9be06e708..6ee03ad0d 100755 --- a/benchmarks/diffusion/prepare.py +++ b/benchmarks/diffusion/prepare.py @@ -2,12 +2,16 @@ from dataclasses import dataclass import os +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler from datasets import load_dataset @dataclass class TrainingConfig: - dataset_name: str = "huggan/smithsonian_butterflies_subset" + model: str = "runwayml/stable-diffusion-v1-5" + dataset: str = "lambdalabs/naruto-blip-captions" def main(): @@ -15,9 +19,27 @@ def main(): parser = ArgumentParser() parser.add_arguments(TrainingConfig) - config, _ = parser.parse_known_args() + args, _ = parser.parse_known_args() + + _ = load_dataset(args.dataset) + + _ = CLIPTextModel.from_pretrained( + args.model, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + + _ = AutoencoderKL.from_pretrained( + args.model, subfolder="vae", revision=args.revision, variant=args.variant + ) + + _ = UNet2DConditionModel.from_pretrained( + args.model, subfolder="unet", revision=args.revision, variant=args.variant + ) + + _ = CLIPTokenizer.from_pretrained( + args.model, subfolder="tokenizer", revision=args.revision + ) - _ = load_dataset(config.dataset_name, split="train") + _ = DDPMScheduler.from_pretrained(args.model, subfolder="scheduler") if __name__ == "__main__":