From 40df0eb8edd9a18271a60704c4ddb9fcc8f16273 Mon Sep 17 00:00:00 2001 From: Satya Ortiz-Gagne Date: Wed, 4 Sep 2024 23:09:34 -0400 Subject: [PATCH] Fix diffusion --- benchmarks/diffusion/benchfile.py | 6 +++++- benchmarks/diffusion/main.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/diffusion/benchfile.py b/benchmarks/diffusion/benchfile.py index 2458070ce..7a2f85f7c 100644 --- a/benchmarks/diffusion/benchfile.py +++ b/benchmarks/diffusion/benchfile.py @@ -1,3 +1,4 @@ +import os from milabench.pack import Package from milabench.commands import AccelerateAllNodes @@ -26,9 +27,12 @@ async def prepare(self): def build_run_plan(self): from milabench.commands import PackCommand + if "HF_TOKEN" in os.environ or "MILABENCH_HF_TOKEN" in os.environ: + os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN", os.environ["MILABENCH_HF_TOKEN"]) + main = self.dirs.code / self.main_script plan = PackCommand(self, *self.argv, lazy=True) - + if False: plan = VoirCommand(plan, cwd=main.parent) diff --git a/benchmarks/diffusion/main.py b/benchmarks/diffusion/main.py index 09513f606..750994eac 100755 --- a/benchmarks/diffusion/main.py +++ b/benchmarks/diffusion/main.py @@ -21,7 +21,7 @@ @dataclass class Arguments: - model: str = "runwayml/stable-diffusion-v1-5" + model: str = "benjamin-paine/stable-diffusion-v1-5" dataset: str = "lambdalabs/naruto-blip-captions" batch_size: int = 16 num_workers: int = 8