Skip to content

Commit

Permalink
Fix diffusers repo id naming (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil authored May 17, 2024
1 parent 67f10ae commit 14fc8ac
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions optimum_benchmark/backends/pytorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,9 @@ def load_model_from_pretrained(self) -> None:
elif self.config.library == "diffusers":
self.logger.info("\t+ Loading Diffusion Pipeline")
self.pretrained_model = self.automodel_class.from_pretrained(
pretrained_model_name_or_path=self.config.model,
pretrained_model_or_path=self.config.model,
# pretrained_model_name_or_path=self.config.model,
# pretrained_model_or_path=self.config.model,
self.config.model,
device_map=self.config.device_map,
**self.config.model_kwargs,
**self.automodel_kwargs,
Expand Down
7 changes: 4 additions & 3 deletions optimum_benchmark/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
"inpainting": "AutoPipelineForInpainting",
"text-to-image": "AutoPipelineForText2Image",
"image-to-image": "AutoPipelineForImage2Image",
"stable-diffusion": "StableDiffusionPipeline", # legacy
"stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline", # legacy
"stable-diffusion": "StableDiffusionPipeline", # should be deprecated
"stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline", # should be deprecated
}
_TIMM_TASKS_TO_MODEL_LOADERS = {
"image-classification": "create_model",
Expand Down Expand Up @@ -146,6 +146,8 @@ def infer_task_from_model_name_or_path(model_name_or_path: str, revision: Option
inferred_task_name = "text-to-image"
elif "image-to-image" in model_info.tags:
inferred_task_name = "image-to-image"
elif "inpainting" in model_info.tags:
inferred_task_name = "inpainting"
else:
class_name = model_info.config["diffusers"]["class_name"]
inferred_task_name = "stable-diffusion-xl" if "XL" in class_name else "stable-diffusion"
Expand All @@ -165,7 +167,6 @@ def infer_task_from_model_name_or_path(model_name_or_path: str, revision: Option
if class_name_for_task == auto_model_class_name:
inferred_task_name = task_name
break

inferred_task_name = None

else:
Expand Down

0 comments on commit 14fc8ac

Please sign in to comment.