From 838310a29189c7bbd1a0d45eeecf92cec62ab2a0 Mon Sep 17 00:00:00 2001 From: Zack Zhiyuan Li Date: Wed, 28 Aug 2024 02:09:53 -0700 Subject: [PATCH] add retry --- nexa/gguf/nexa_inference_image.py | 98 +++++++++++++++++++------------ 1 file changed, 60 insertions(+), 38 deletions(-) diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 8a0202dd..e6d606dc 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -22,6 +22,9 @@ logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) +RETRY_ATTEMPTS = ( + 3 # a temporary fix for the issue of segmentation fault for stable-diffusion-cpp +) class NexaImageInference: @@ -48,7 +51,6 @@ class NexaImageInference: """ - def __init__(self, model_path, local_path=None, **kwargs): self.model_path = model_path self.downloaded_path = local_path @@ -81,6 +83,7 @@ def __init__(self, model_path, local_path=None, **kwargs): def _load_model(self, model_path: str): with suppress_stdout_stderr(): from nexa.gguf.sd.stable_diffusion import StableDiffusion + self.model = StableDiffusion( model_path=self.downloaded_path, lora_model_dir=self.params.get("lora_dir", ""), @@ -105,16 +108,28 @@ def _save_images(self, images): image.save(file_path) print(f"\nImage {i+1} saved to: {os.path.abspath(file_path)}") - def txt2img(self, - prompt, - negative_prompt="", - cfg_scale=7.5, - width=512, - height=512, - sample_steps=20, - seed=0, - control_cond="", - control_strength=0.9): + def _retry(self, func, *args, **kwargs): + for attempt in range(RETRY_ATTEMPTS): + try: + return func(*args, **kwargs) + except Exception as e: + logging.error(f"Attempt {attempt + 1} failed with error: {e}") + time.sleep(1) + logging.error("All retry attempts failed.") + return None + + def txt2img( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + width=512, + height=512, + sample_steps=20, + seed=0, + control_cond="", + control_strength=0.9, + ): """ Used for SDK. Generate images from text. @@ -125,7 +140,8 @@ def txt2img(self, Returns: list: List of generated images. """ - images = self.model.txt_to_img( + images = self._retry( + self.model.txt_to_img, prompt=prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale, @@ -157,7 +173,8 @@ def run_txt2img(self): control_cond=self.params.get("control_image_path", ""), control_strength=self.params.get("control_strength", 0.9), ) - self._save_images(images) + if images: + self._save_images(images) except Exception as e: logging.error(f"Error during text to image generation: {e}") except KeyboardInterrupt: @@ -165,17 +182,19 @@ def run_txt2img(self): except Exception as e: logging.error(f"Error during generation: {e}", exc_info=True) - def img2img(self, - image_path, - prompt, - negative_prompt="", - cfg_scale=7.5, - width=512, - height=512, - sample_steps=20, - seed=0, - control_cond="", - control_strength=0.9): + def img2img( + self, + image_path, + prompt, + negative_prompt="", + cfg_scale=7.5, + width=512, + height=512, + sample_steps=20, + seed=0, + control_cond="", + control_strength=0.9, + ): """ Used for SDK. Generate images from an image. @@ -187,7 +206,8 @@ def img2img(self, Returns: list: List of generated images. """ - images = self.model.img_to_img( + images = self._retry( + self.model.img_to_img, image=image_path, prompt=prompt, negative_prompt=negative_prompt, @@ -209,19 +229,21 @@ def run_img2img(self): negative_prompt = nexa_prompt( "Enter your negative prompt (press Enter to skip): " ) - images = self.img2img(image_path, - prompt, - negative_prompt, - cfg_scale=self.params["guidance_scale"], - width=self.params["width"], - height=self.params["height"], - sample_steps=self.params["num_inference_steps"], - seed=self.params["random_seed"], - control_cond=self.params.get("control_image_path", ""), - control_strength=self.params.get("control_strength", 0.9), - ) + images = self.img2img( + image_path, + prompt, + negative_prompt, + cfg_scale=self.params["guidance_scale"], + width=self.params["width"], + height=self.params["height"], + sample_steps=self.params["num_inference_steps"], + seed=self.params["random_seed"], + control_cond=self.params.get("control_image_path", ""), + control_strength=self.params.get("control_strength", 0.9), + ) - self._save_images(images) + if images: + self._save_images(images) except KeyboardInterrupt: print(EXIT_REMINDER) except Exception as e: @@ -309,4 +331,4 @@ def run_streamlit(self, model_path: str): if args.img2img: inference.run_img2img() else: - inference.run_txt2img() \ No newline at end of file + inference.run_txt2img()