Skip to content

Commit

Permalink
use correct download path for flux
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyuan8 committed Sep 16, 2024
1 parent dda74a8 commit 05e0053
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,14 @@ def __init__(self, model_path, local_path=None, **kwargs):
self.ae_downloaded_path, _ = pull_model(self.ae_path)
if self.clip_l_path:
self.clip_l_downloaded_path, _ = pull_model(self.clip_l_path)

if "lcm-dreamshaper" in self.model_path or "flux" in self.model_path:
self.params = DEFAULT_IMG_GEN_PARAMS_LCM # both lcm-dreamshaper and flux use the same params
self.params = DEFAULT_IMG_GEN_PARAMS_LCM.copy() # both lcm-dreamshaper and flux use the same params
elif "sdxl-turbo" in self.model_path:
self.params = DEFAULT_IMG_GEN_PARAMS_TURBO
self.params = DEFAULT_IMG_GEN_PARAMS_TURBO.copy()
else:
self.params = DEFAULT_IMG_GEN_PARAMS

self.params.update(kwargs)
self.params = DEFAULT_IMG_GEN_PARAMS.copy()

self.params.update({k: v for k, v in kwargs.items() if v is not None})
if not kwargs.get("streamlit", False):
self._load_model(model_path)
if self.model is None:
Expand All @@ -111,12 +109,12 @@ def __init__(self, model_path, local_path=None, **kwargs):
def _load_model(self, model_path: str):
with suppress_stdout_stderr():
from nexa.gguf.sd.stable_diffusion import StableDiffusion
if self.t5xxl_path and self.ae_path and self.clip_l_path:
if self.t5xxl_downloaded_path and self.ae_downloaded_path and self.clip_l_downloaded_path:
self.model = StableDiffusion(
diffusion_model_path=self.downloaded_path,
clip_l_path=self.clip_l_path,
t5xxl_path=self.t5xxl_path,
vae_path=self.ae_path,
clip_l_path=self.clip_l_downloaded_path,
t5xxl_path=self.t5xxl_downloaded_path,
vae_path=self.ae_downloaded_path,
n_threads=self.params.get("n_threads", multiprocessing.cpu_count()),
wtype=self.params.get(
"wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default")
Expand Down

0 comments on commit 05e0053

Please sign in to comment.