Skip to content

Commit

Permalink
move import position
Browse files Browse the repository at this point in the history
  • Loading branch information
xyyimian committed Aug 21, 2024
1 parent a0b2358 commit c9f3d4f
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time
from pathlib import Path

from nexa.gguf.sd.stable_diffusion import StableDiffusion
from nexa.general import pull_model
from nexa.constants import (
DEFAULT_IMG_GEN_PARAMS,
Expand Down Expand Up @@ -44,7 +43,7 @@ class NexaImageInference:
streamlit (bool): Run the inference in Streamlit UI.
"""
from nexa.gguf.sd.stable_diffusion import StableDiffusion


def __init__(self, model_path, **kwargs):
self.model_path = None
Expand Down Expand Up @@ -85,6 +84,7 @@ def __init__(self, model_path, **kwargs):
@SpinningCursorAnimation()
def _load_model(self, model_path: str):
with suppress_stdout_stderr():
from nexa.gguf.sd.stable_diffusion import StableDiffusion
self.model = StableDiffusion(
model_path=self.downloaded_path,
lora_model_dir=self.params.get("lora_dir", ""),
Expand Down
3 changes: 2 additions & 1 deletion nexa/gguf/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class NexaTextInference:
top_k (int): Top-k sampling parameter.
top_p (float): Top-p sampling parameter
"""
from nexa.gguf.llama.llama import Llama

def __init__(self, model_path, stop_words=None, **kwargs):
self.params = DEFAULT_TEXT_GEN_PARAMS
self.params.update(kwargs)
Expand Down Expand Up @@ -110,6 +110,7 @@ def _load_model(self):
logging.debug(f"Loading model from {self.downloaded_path}")
start_time = time.time()
with suppress_stdout_stderr():
from nexa.gguf.llama.llama import Llama
self.model = Llama(
model_path=self.downloaded_path,
verbose=self.profiling,
Expand Down
4 changes: 2 additions & 2 deletions nexa/gguf/nexa_inference_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from nexa.general import pull_model
from nexa.gguf.lib_utils import is_gpu_available
from nexa.gguf.llama.llama import Llama
from nexa.gguf.llama.llama_chat_format import (
Llava15ChatHandler,
Llava16ChatHandler,
Expand Down Expand Up @@ -79,7 +78,7 @@ class NexaVLMInference:
top_k (int): Top-k sampling parameter.
top_p (float): Top-p sampling parameter
"""
from nexa.gguf.llama.llama import Llama


def __init__(self, model_path, stop_words=None, **kwargs):
self.params = DEFAULT_TEXT_GEN_PARAMS
Expand Down Expand Up @@ -151,6 +150,7 @@ def _load_model(self):
if self.projector_downloaded_path
else None
)
from nexa.gguf.llama.llama import Llama
self.model = Llama(
model_path=self.downloaded_path,
chat_handler=self.projector,
Expand Down
3 changes: 1 addition & 2 deletions nexa/gguf/nexa_inference_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_VOICE, DEFAULT_VOICE_GEN_PARAMS
from nexa.general import pull_model
from nexa.utils import nexa_prompt
from faster_whisper import WhisperModel
from nexaai.utils import nexa_prompt, SpinningCursorAnimation, suppress_stdout_stderr
from nexa.utils import nexa_prompt, SpinningCursorAnimation, suppress_stdout_stderr

logging.basicConfig(level=logging.INFO)

Expand Down

0 comments on commit c9f3d4f

Please sign in to comment.