Skip to content

Commit

Permalink
Merge pull request #92 from NexaAI/zack-main
Browse files Browse the repository at this point in the history
Support Flux
  • Loading branch information
zhiyuan8 authored Sep 16, 2024
2 parents 022fa02 + 05e0053 commit 353087d
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 166 deletions.
2 changes: 1 addition & 1 deletion nexa/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.8.2"
__version__ = "0.0.8.3"
46 changes: 28 additions & 18 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,14 @@ def __init__(self, model_path, local_path=None, **kwargs):
self.ae_downloaded_path, _ = pull_model(self.ae_path)
if self.clip_l_path:
self.clip_l_downloaded_path, _ = pull_model(self.clip_l_path)

if "lcm-dreamshaper" in self.model_path:
self.params = DEFAULT_IMG_GEN_PARAMS_LCM
if "lcm-dreamshaper" in self.model_path or "flux" in self.model_path:
self.params = DEFAULT_IMG_GEN_PARAMS_LCM.copy() # both lcm-dreamshaper and flux use the same params
elif "sdxl-turbo" in self.model_path:
self.params = DEFAULT_IMG_GEN_PARAMS_TURBO
self.params = DEFAULT_IMG_GEN_PARAMS_TURBO.copy()
else:
self.params = DEFAULT_IMG_GEN_PARAMS

self.params.update(kwargs)
self.params = DEFAULT_IMG_GEN_PARAMS.copy()

self.params.update({k: v for k, v in kwargs.items() if v is not None})
if not kwargs.get("streamlit", False):
self._load_model(model_path)
if self.model is None:
Expand All @@ -111,17 +109,29 @@ def __init__(self, model_path, local_path=None, **kwargs):
def _load_model(self, model_path: str):
    """Load the Stable Diffusion context for either a FLUX-style split model
    or a single-file checkpoint.

    If the t5xxl/ae/clip_l components were all downloaded (set in __init__),
    the model is loaded FLUX-style from its separate diffusion/clip_l/t5xxl/vae
    files; otherwise it is loaded as a single-file SD/SDXL checkpoint.

    The previous unconditional construction that preceded the branch was diff
    residue from the pre-Flux version and has been removed: it loaded the
    whole model once and then immediately replaced it, doubling load time.

    Args:
        model_path: Model identifier used to look up the weight precision in
            NEXA_RUN_MODEL_PRECISION_MAP (falls back to "default").
    """
    with suppress_stdout_stderr():
        from nexa.gguf.sd.stable_diffusion import StableDiffusion

        # NOTE(review): constructions kept inside the suppression context so
        # native-library load logs stay silenced — confirm against upstream.
        if self.t5xxl_downloaded_path and self.ae_downloaded_path and self.clip_l_downloaded_path:
            # FLUX-style model split across component files.
            self.model = StableDiffusion(
                diffusion_model_path=self.downloaded_path,
                clip_l_path=self.clip_l_downloaded_path,
                t5xxl_path=self.t5xxl_downloaded_path,
                vae_path=self.ae_downloaded_path,
                n_threads=self.params.get("n_threads", multiprocessing.cpu_count()),
                wtype=self.params.get(
                    "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default")
                ),  # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
                verbose=False,
            )
        else:
            # Single-file checkpoint (SD / SDXL style).
            self.model = StableDiffusion(
                model_path=self.downloaded_path,
                lora_model_dir=self.params.get("lora_dir", ""),
                n_threads=self.params.get("n_threads", multiprocessing.cpu_count()),
                wtype=self.params.get(
                    "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default")
                ),  # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
                control_net_path=self.params.get("control_net_path", ""),
                verbose=False,
            )

def _save_images(self, images):
"""
Expand Down
88 changes: 57 additions & 31 deletions nexa/gguf/sd/_internals_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
from contextlib import ExitStack

from nexa.gguf.sd._utils_diffusion import suppress_stdout_stderr

import nexa.gguf.sd.stable_diffusion_cpp as sd_cpp

Expand Down Expand Up @@ -59,6 +62,7 @@ def __init__(
self.keep_control_net_cpu = keep_control_net_cpu
self.keep_vae_on_cpu = keep_vae_on_cpu
self.verbose = verbose
self._exit_stack = ExitStack()

self.model = None

Expand All @@ -75,39 +79,50 @@ def __init__(
raise ValueError(f"Diffusion model path does not exist: {diffusion_model_path}")

if model_path or diffusion_model_path:
# Load the Stable Diffusion model ctx
self.model = sd_cpp.new_sd_ctx(
self.model_path.encode("utf-8"),
self.clip_l_path.encode("utf-8"),
self.t5xxl_path.encode("utf-8"),
self.diffusion_model_path.encode("utf-8"),
self.vae_path.encode("utf-8"),
self.taesd_path.encode("utf-8"),
self.control_net_path.encode("utf-8"),
self.lora_model_dir.encode("utf-8"),
self.embed_dir.encode("utf-8"),
self.stacked_id_embed_dir.encode("utf-8"),
self.vae_decode_only,
self.vae_tiling,
self.free_params_immediately,
self.n_threads,
self.wtype,
self.rng_type,
self.schedule,
self.keep_clip_on_cpu,
self.keep_control_net_cpu,
self.keep_vae_on_cpu,
)
with suppress_stdout_stderr(disable=verbose):
# Load the Stable Diffusion model ctx
self.model = sd_cpp.new_sd_ctx(
self.model_path.encode("utf-8"),
self.clip_l_path.encode("utf-8"),
self.t5xxl_path.encode("utf-8"),
self.diffusion_model_path.encode("utf-8"),
self.vae_path.encode("utf-8"),
self.taesd_path.encode("utf-8"),
self.control_net_path.encode("utf-8"),
self.lora_model_dir.encode("utf-8"),
self.embed_dir.encode("utf-8"),
self.stacked_id_embed_dir.encode("utf-8"),
self.vae_decode_only,
self.vae_tiling,
self.free_params_immediately,
self.n_threads,
self.wtype,
self.rng_type,
self.schedule,
self.keep_clip_on_cpu,
self.keep_control_net_cpu,
self.keep_vae_on_cpu,
)

# Check if the model was loaded successfully
if self.model is None:
raise ValueError(f"Failed to load model from file: {model_path}")

def free_ctx():
"""Free the model from memory."""
if self.model is not None and self._free_sd_ctx is not None:
self._free_sd_ctx(self.model)
self.model = None

self._exit_stack.callback(free_ctx)

def close(self):
    """Release native resources by unwinding the exit stack.

    Running the stack invokes the free_ctx callback registered during
    construction. Subsequent calls are no-ops, because ExitStack empties
    itself when closed.
    """
    stack = self._exit_stack
    stack.close()

def __del__(self):
    """Free memory when the object is deleted.

    Delegates to close(), which unwinds the exit stack; the registered
    free_ctx callback frees the native sd_ctx at most once (it checks
    self.model for None before freeing). The scraped diff had merged the
    old direct-free body with this delegation; only the delegation is kept.
    """
    self.close()


# ============================================
Expand All @@ -132,6 +147,7 @@ def __init__(
self.n_threads = n_threads
self.wtype = wtype
self.verbose = verbose
self._exit_stack = ExitStack()

self.upscaler = None

Expand All @@ -151,8 +167,18 @@ def __init__(
if self.upscaler is None:
raise ValueError(f"Failed to load upscaler model from file: {upscaler_path}")

def free_ctx():
"""Free the model from memory."""
if self.upscaler is not None and self._free_upscaler_ctx is not None:
self._free_upscaler_ctx(self.upscaler)
self.upscaler = None

self._exit_stack.callback(free_ctx)

def close(self):
    """Closes the exit stack, ensuring all context managers are exited."""
    # Runs the free_ctx callback registered in __init__, which frees the
    # native upscaler context. Idempotent: ExitStack empties on close.
    self._exit_stack.close()

def __del__(self):
    """Free memory when the object is deleted.

    Delegates to close(); the free_ctx callback registered on the exit
    stack frees the native upscaler context at most once (it checks
    self.upscaler for None before freeing). The scraped diff had merged
    the old direct-free body with this delegation; only the delegation
    is kept.
    """
    self.close()
2 changes: 0 additions & 2 deletions nexa/gguf/sd/_logger_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import ctypes
import logging

import nexa.gguf.sd.stable_diffusion_cpp as stable_diffusion_cpp

# enum sd_log_level_t {
# SD_LOG_DEBUG = 0,
# SD_LOG_INFO = 1,
Expand Down
59 changes: 59 additions & 0 deletions nexa/gguf/sd/_utils_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import sys

# Null sinks opened once at import time and kept open for the life of the
# process. Opening them eagerly avoids "LookupError: unknown encoding: ascii"
# when open() would otherwise be called from inside a destructor during
# interpreter shutdown.
outnull_file = open(os.devnull, "w")
errnull_file = open(os.devnull, "w")

# POSIX file-descriptor numbers for the process's standard output and error.
STDOUT_FILENO = 1
STDERR_FILENO = 2


class suppress_stdout_stderr(object):
    """
    Stops all output to stdout and stderr when used as a context manager (GGML will otherwise still print logs).
    Source: https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/_utils.py
    """

    # NOTE: these must be "saved" here to avoid exceptions when using
    # this context manager inside of a __del__ method
    sys = sys
    os = os

    def __init__(self, disable: bool = True):
        # disable=True makes the context manager a no-op (nothing suppressed).
        self.disable = disable

    # Oddly enough this works better than the contextlib version
    def __enter__(self):
        if self.disable:
            return self

        # fd numbers of the real stdout/stderr that will be redirected.
        self.old_stdout_fileno_undup = STDOUT_FILENO
        self.old_stderr_fileno_undup = STDERR_FILENO

        # Duplicate fds 1/2 so the originals can be restored in __exit__.
        self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
        self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)

        # Save the Python-level stream objects as well.
        self.old_stdout = self.sys.stdout
        self.old_stderr = self.sys.stderr

        # Point fds 1/2 at the null sinks so C-level (GGML) writes are
        # silenced, not just Python-level ones.
        self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
        self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)

        # Redirect Python-level writes to the null sinks too.
        self.sys.stdout = outnull_file
        self.sys.stderr = errnull_file
        return self

    def __exit__(self, *_):
        if self.disable:
            return

        # Restore the Python-level stream objects first...
        self.sys.stdout = self.old_stdout
        self.sys.stderr = self.old_stderr

        # ...then point fds 1/2 back at the saved duplicates...
        self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        # ...and close the duplicates to avoid leaking file descriptors.
        self.os.close(self.old_stdout_fileno)
        self.os.close(self.old_stderr_fileno)
Loading

0 comments on commit 353087d

Please sign in to comment.