diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index d674b6e4..9bb4b2bb 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -91,16 +91,14 @@ def __init__(self, model_path, local_path=None, **kwargs): self.ae_downloaded_path, _ = pull_model(self.ae_path) if self.clip_l_path: self.clip_l_downloaded_path, _ = pull_model(self.clip_l_path) - if "lcm-dreamshaper" in self.model_path or "flux" in self.model_path: - self.params = DEFAULT_IMG_GEN_PARAMS_LCM # both lcm-dreamshaper and flux use the same params + self.params = DEFAULT_IMG_GEN_PARAMS_LCM.copy() # both lcm-dreamshaper and flux use the same params elif "sdxl-turbo" in self.model_path: - self.params = DEFAULT_IMG_GEN_PARAMS_TURBO + self.params = DEFAULT_IMG_GEN_PARAMS_TURBO.copy() else: - self.params = DEFAULT_IMG_GEN_PARAMS - - self.params.update(kwargs) + self.params = DEFAULT_IMG_GEN_PARAMS.copy() + self.params.update({k: v for k, v in kwargs.items() if v is not None}) if not kwargs.get("streamlit", False): self._load_model(model_path) if self.model is None: @@ -111,12 +109,12 @@ def __init__(self, model_path, local_path=None, **kwargs): def _load_model(self, model_path: str): with suppress_stdout_stderr(): from nexa.gguf.sd.stable_diffusion import StableDiffusion - if self.t5xxl_path and self.ae_path and self.clip_l_path: + if self.t5xxl_downloaded_path and self.ae_downloaded_path and self.clip_l_downloaded_path: self.model = StableDiffusion( diffusion_model_path=self.downloaded_path, - clip_l_path=self.clip_l_path, - t5xxl_path=self.t5xxl_path, - vae_path=self.ae_path, + clip_l_path=self.clip_l_downloaded_path, + t5xxl_path=self.t5xxl_downloaded_path, + vae_path=self.ae_downloaded_path, n_threads=self.params.get("n_threads", multiprocessing.cpu_count()), wtype=self.params.get( "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default")