Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add retry logic for stable diffusion #55

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 60 additions & 38 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
RETRY_ATTEMPTS = (
3 # a temporary fix for the issue of segmentation fault for stable-diffusion-cpp
)


class NexaImageInference:
Expand All @@ -48,7 +51,6 @@ class NexaImageInference:

"""


def __init__(self, model_path, local_path=None, **kwargs):
self.model_path = model_path
self.downloaded_path = local_path
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(self, model_path, local_path=None, **kwargs):
def _load_model(self, model_path: str):
with suppress_stdout_stderr():
from nexa.gguf.sd.stable_diffusion import StableDiffusion

self.model = StableDiffusion(
model_path=self.downloaded_path,
lora_model_dir=self.params.get("lora_dir", ""),
Expand All @@ -105,16 +108,28 @@ def _save_images(self, images):
image.save(file_path)
print(f"\nImage {i+1} saved to: {os.path.abspath(file_path)}")

def txt2img(self,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9):
def _retry(self, func, *args, **kwargs):
for attempt in range(RETRY_ATTEMPTS):
try:
return func(*args, **kwargs)
except Exception as e:
logging.error(f"Attempt {attempt + 1} failed with error: {e}")
time.sleep(1)
logging.error("All retry attempts failed.")
return None

def txt2img(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9,
):
"""
Used for SDK. Generate images from text.

Expand All @@ -125,7 +140,8 @@ def txt2img(self,
Returns:
list: List of generated images.
"""
images = self.model.txt_to_img(
images = self._retry(
self.model.txt_to_img,
prompt=prompt,
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
Expand Down Expand Up @@ -157,25 +173,28 @@ def run_txt2img(self):
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)
self._save_images(images)
if images:
self._save_images(images)
except Exception as e:
logging.error(f"Error during text to image generation: {e}")
except KeyboardInterrupt:
print(EXIT_REMINDER)
except Exception as e:
logging.error(f"Error during generation: {e}", exc_info=True)

def img2img(self,
image_path,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9):
def img2img(
self,
image_path,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9,
):
"""
Used for SDK. Generate images from an image.

Expand All @@ -187,7 +206,8 @@ def img2img(self,
Returns:
list: List of generated images.
"""
images = self.model.img_to_img(
images = self._retry(
self.model.img_to_img,
image=image_path,
prompt=prompt,
negative_prompt=negative_prompt,
Expand All @@ -209,19 +229,21 @@ def run_img2img(self):
negative_prompt = nexa_prompt(
"Enter your negative prompt (press Enter to skip): "
)
images = self.img2img(image_path,
prompt,
negative_prompt,
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)
images = self.img2img(
image_path,
prompt,
negative_prompt,
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)

self._save_images(images)
if images:
self._save_images(images)
except KeyboardInterrupt:
print(EXIT_REMINDER)
except Exception as e:
Expand Down Expand Up @@ -309,4 +331,4 @@ def run_streamlit(self, model_path: str):
if args.img2img:
inference.run_img2img()
else:
inference.run_txt2img()
inference.run_txt2img()
Loading