From 053433530292a3aa66d22bcbab2544dd2e3e8f42 Mon Sep 17 00:00:00 2001
From: Landan Seguin
Date: Fri, 25 Aug 2023 16:53:45 -0700
Subject: [PATCH 1/6] Add model name

---
 diffusion/inference/inference_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index aecf07c5..ed411773 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -35,10 +35,11 @@ class StableDiffusionInference():
             Default: ``None``.
     """

-    def __init__(self, pretrained: bool = False, prediction_type: str = 'epsilon'):
+    def __init__(self, model_name: str = 'stabilityai/stable-diffusion-2-base', pretrained: bool = False, prediction_type: str = 'epsilon'):
         self.device = torch.cuda.current_device()

         model = stable_diffusion_2(
+            model_name=model_name,
             pretrained=pretrained,
             prediction_type=prediction_type,
             encode_latents_in_fp16=True,

From 9b0922be1063d88b94a64602c538b32f35f613a6 Mon Sep 17 00:00:00 2001
From: Landan Seguin
Date: Mon, 28 Aug 2023 12:36:29 -0700
Subject: [PATCH 2/6] Check type of input

---
 diffusion/inference/inference_model.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index ed411773..07845542 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -75,6 +75,8 @@ def predict(self, model_requests: List[Dict[str, Any]]):
                 prompts.append(inputs['prompt'])
                 if 'negative_prompt' in req:
                     negative_prompts.append(inputs['negative_prompt'])
+            else:
+                raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}')

             generate_kwargs = req['parameters']


From 19ce31a3c679225785ff01a76b9ce0e7bbc0e276 Mon Sep 17 00:00:00 2001
From: Landan Seguin
Date: Mon, 28 Aug 2023 12:52:55 -0700
Subject: [PATCH 3/6] Debug

---
 diffusion/inference/inference_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index 07845542..925df293 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -59,6 +59,7 @@ def predict(self, model_requests: List[Dict[str, Any]]):
         prompts = []
         negative_prompts = []
         generate_kwargs = {}
+        print(model_requests)

         # assumes the same generate_kwargs across all samples
         for req in model_requests:

From 1d0c0acb1de842d05754754d79d7d6d2c3b732e1 Mon Sep 17 00:00:00 2001
From: Landan Seguin
Date: Mon, 28 Aug 2023 13:36:00 -0700
Subject: [PATCH 4/6] Typo :(

---
 diffusion/inference/inference_model.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index 925df293..3d1ad50d 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -59,7 +59,6 @@ def predict(self, model_requests: List[Dict[str, Any]]):
         prompts = []
         negative_prompts = []
         generate_kwargs = {}
-        print(model_requests)

         # assumes the same generate_kwargs across all samples
         for req in model_requests:
@@ -70,7 +69,7 @@ def predict(self, model_requests: List[Dict[str, Any]]):
             # Prompts and negative prompts if available
             if isinstance(inputs, str):
                 prompts.append(inputs)
-            elif isinstance(input, Dict):
+            elif isinstance(inputs, Dict):
                 if 'prompt' not in req:
                     raise RuntimeError('"prompt" must be provided to generate call if using a dict as input')
                 prompts.append(inputs['prompt'])

From a66f97e3b619a6051085be8b316b7de2693067a1 Mon Sep 17 00:00:00 2001
From: Landan Seguin
Date: Mon, 28 Aug 2023 14:26:15 -0700
Subject: [PATCH 5/6] Fix pt. 2

---
 diffusion/inference/inference_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index 3d1ad50d..9259b2f5 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -70,10 +70,10 @@ def predict(self, model_requests: List[Dict[str, Any]]):
             if isinstance(inputs, str):
                 prompts.append(inputs)
             elif isinstance(inputs, Dict):
-                if 'prompt' not in req:
+                if 'prompt' not in inputs:
                     raise RuntimeError('"prompt" must be provided to generate call if using a dict as input')
                 prompts.append(inputs['prompt'])
-                if 'negative_prompt' in req:
+                if 'negative_prompt' in inputs:
                     negative_prompts.append(inputs['negative_prompt'])
             else:
                 raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}')

From acc019543d151589e0d282cb161932472e13b327 Mon Sep 17 00:00:00 2001
From: Landan Seguin
Date: Tue, 29 Aug 2023 09:54:50 -0700
Subject: [PATCH 6/6] STYLE

---
 diffusion/inference/inference_model.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index 9259b2f5..6576e390 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -35,7 +35,10 @@ class StableDiffusionInference():
             Default: ``None``.
     """

-    def __init__(self, model_name: str = 'stabilityai/stable-diffusion-2-base', pretrained: bool = False, prediction_type: str = 'epsilon'):
+    def __init__(self,
+                 model_name: str = 'stabilityai/stable-diffusion-2-base',
+                 pretrained: bool = False,
+                 prediction_type: str = 'epsilon'):
         self.device = torch.cuda.current_device()

         model = stable_diffusion_2(
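Note: taken together, this series makes predict() accept either a bare string or a dict per request, and reject anything else. A minimal sketch of request payloads that would exercise each branch after PATCH 6 -- the 'input' key holding each request's payload is an assumption (only req['parameters'] and the str-vs-dict branching on `inputs` are visible in these hunks):

    # Hypothetical model_requests payloads for StableDiffusionInference.predict().
    # Assumed: each request carries its prompt payload under 'input' (not shown in
    # the diffs); 'parameters' matches generate_kwargs = req['parameters'] above.
    model_requests = [
        # str input -> appended to prompts directly
        {'input': 'a photo of a corgi', 'parameters': {'guidance_scale': 7.5}},
        # dict input -> must contain 'prompt'; 'negative_prompt' is optional
        {'input': {'prompt': 'a watercolor castle', 'negative_prompt': 'blurry'},
         'parameters': {'guidance_scale': 7.5}},
        # any other type, e.g. {'input': 42, ...}, now raises RuntimeError
    ]

PATCH 5 is the behavioral fix that makes the dict branch usable: before it, the 'prompt'/'negative_prompt' membership checks ran against the outer request dict (req) rather than the inner payload (inputs), so a well-formed dict input would still be rejected.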