From 053d32ac652977dc2019743fb5faa2259d8d4875 Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Fri, 25 Aug 2023 16:49:28 -0700 Subject: [PATCH] Add prediction type to inference (#62) --- diffusion/inference/inference_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index c4d95b83..aecf07c5 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -35,10 +35,16 @@ class StableDiffusionInference(): Default: ``None``. """ - def __init__(self, pretrained: bool = False): + def __init__(self, pretrained: bool = False, prediction_type: str = 'epsilon'): self.device = torch.cuda.current_device() - model = stable_diffusion_2(pretrained=pretrained, encode_latents_in_fp16=True, fsdp=False) + model = stable_diffusion_2( + pretrained=pretrained, + prediction_type=prediction_type, + encode_latents_in_fp16=True, + fsdp=False, + ) + if not pretrained: state_dict = torch.load(LOCAL_CHECKPOINT_PATH) for key in list(state_dict['state']['model'].keys()):