Skip to content

Commit

Permalink
Merge pull request #32 from NexaAI/ethan/bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyuan8 authored Aug 22, 2024
2 parents ae75950 + cae6702 commit d623721
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Example:
`docker run -v /home/ubuntu/.cache/nexa/hub/official:/model -it nexa4ai/sdk:latest nexa gen-text /model/Phi-3-mini-128k-instruct/q4_0.gguf`

This will create an interactive session with text generation.
```


## Nexa CLI commands

Expand Down
8 changes: 5 additions & 3 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
DEFAULT_IMG_GEN_PARAMS_LCM,
DEFAULT_IMG_GEN_PARAMS_TURBO,
)
from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr
from nexa.utils import SpinningCursorAnimation, nexa_prompt
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr

from streamlit.web import cli as stcli

logging.basicConfig(
Expand Down Expand Up @@ -142,7 +144,7 @@ def txt2img(self,
)
return images

def run_txt2img(self):
def loop_txt2img(self):
while True:
try:
prompt = nexa_prompt("Enter your prompt: ")
Expand Down Expand Up @@ -313,4 +315,4 @@ def run_streamlit(self, model_path: str):
if args.img2img:
inference.run_img2img()
else:
inference.run_txt2img()
inference.loop_txt2img()
4 changes: 3 additions & 1 deletion nexa/gguf/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
)
from nexa.general import pull_model
from nexa.gguf.lib_utils import is_gpu_available
from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr
from nexa.utils import SpinningCursorAnimation, nexa_prompt
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr


logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
Expand Down
3 changes: 2 additions & 1 deletion nexa/gguf/nexa_inference_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
Llava16ChatHandler,
NanoLlavaChatHandler,
)
from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr
from nexa.utils import SpinningCursorAnimation, nexa_prompt
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
Expand Down
5 changes: 4 additions & 1 deletion nexa/gguf/nexa_inference_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
NEXA_RUN_MODEL_MAP_VOICE,
)
from nexa.general import pull_model
from nexa.utils import nexa_prompt, SpinningCursorAnimation, suppress_stdout_stderr
from nexa.utils import nexa_prompt, SpinningCursorAnimation
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr


logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -75,6 +77,7 @@ def _load_model(self):

logging.debug(f"Loading model from: {self.downloaded_path}")
with suppress_stdout_stderr():
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
self.model = WhisperModel(
self.downloaded_path,
device="cpu",
Expand Down
5 changes: 3 additions & 2 deletions nexa/onnx/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX
from nexa.general import pull_model
from nexa.utils import nexa_prompt
from nexa.utils import nexa_prompt, SpinningCursorAnimation

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -73,6 +73,7 @@ def run(self):
# Step 3: Enter dialogue mode
self._dialogue_mode()

@SpinningCursorAnimation()
def _load_model(self, model_path):
"""
Load the model from the given model path using the appropriate pipeline.
Expand Down Expand Up @@ -147,7 +148,7 @@ def generate_images(self, prompt, negative_prompt):
images = self.pipeline(**pipeline_kwargs).images
return images



def _save_images(self, images):
"""
Expand Down
3 changes: 2 additions & 1 deletion nexa/onnx/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from nexa.constants import NEXA_RUN_MODEL_MAP_ONNX
from nexa.general import pull_model
from nexa.utils import nexa_prompt
from nexa.utils import nexa_prompt, SpinningCursorAnimation

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -51,6 +51,7 @@ def __init__(self, model_path, **kwargs):
self.timings = kwargs.get("timings", False)
self.device = "cpu"

@SpinningCursorAnimation()
def _load_model_and_tokenizer(self) -> Tuple[Any, Any, Any, bool]:
logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
start_time = time.time()
Expand Down
3 changes: 0 additions & 3 deletions nexa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from prompt_toolkit.styles import Style

from nexa.constants import EXIT_COMMANDS, EXIT_REMINDER
from nexa.gguf.llama._utils_transformers import (
suppress_stdout_stderr,
) # re-import, don't comment out


def is_package_installed(package_name: str) -> bool:
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "nexaai"
version = "0.0.2.dev"
version = "0.0.5"
description = "Nexa AI SDK"
readme = "README.md"
license = { text = "MIT" }
Expand All @@ -16,8 +16,6 @@ dependencies = [
"diskcache>=5.6.1",
"jinja2>=2.11.3",
"librosa>=0.8.0",
"boto3>=1.34.148",
"botocore>=1.34.148",
"fastapi",
"uvicorn",
"pydantic",
Expand All @@ -38,7 +36,7 @@ classifiers = [
[project.optional-dependencies]
onnx = [
"librosa",
"optimum[onnxruntime]>=1.7.3", # for CPU version
"optimum[onnxruntime]", # for CPU version
"diffusers", # required for image generation
"optuna",
"pydantic",
Expand Down

0 comments on commit d623721

Please sign in to comment.