diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py
index 422696eb..6683a4c6 100644
--- a/nexa/cli/entry.py
+++ b/nexa/cli/entry.py
@@ -15,27 +15,32 @@ def run_ggml_inference(args):
 
     stop_words = kwargs.pop("stop_words", [])
 
-    if run_type == "NLP":
-        from nexa.gguf.nexa_inference_text import NexaTextInference
-        inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
-    elif run_type == "Computer Vision":
-        from nexa.gguf.nexa_inference_image import NexaImageInference
-        inference = NexaImageInference(model_path=model_path, local_path=local_path, **kwargs)
-        if hasattr(args, 'streamlit') and args.streamlit:
-            inference.run_streamlit(model_path)
-        elif args.img2img:
-            inference.run_img2img()
+    try:
+        if run_type == "NLP":
+            from nexa.gguf.nexa_inference_text import NexaTextInference
+            inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
+        elif run_type == "Computer Vision":
+            from nexa.gguf.nexa_inference_image import NexaImageInference
+            inference = NexaImageInference(model_path=model_path, local_path=local_path, **kwargs)
+            if hasattr(args, 'streamlit') and args.streamlit:
+                inference.run_streamlit(model_path)
+            elif args.img2img:
+                inference.run_img2img()
+            else:
+                inference.run_txt2img()
+            return
+        elif run_type == "Multimodal":
+            from nexa.gguf.nexa_inference_vlm import NexaVLMInference
+            inference = NexaVLMInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
+        elif run_type == "Audio":
+            from nexa.gguf.nexa_inference_voice import NexaVoiceInference
+            inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs)
         else:
-            inference.run_txt2img()
+            print(f"Unknown task: {run_type}. Skipping inference.")
+            return
+    except Exception as e:
+        print(f"Error loading GGUF models: {e}. Please refer to our docs to install the nexaai package: https://docs.nexaai.com/getting-started/installation")
         return
-    elif run_type == "Multimodal":
-        from nexa.gguf.nexa_inference_vlm import NexaVLMInference
-        inference = NexaVLMInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs)
-    elif run_type == "Audio":
-        from nexa.gguf.nexa_inference_voice import NexaVoiceInference
-        inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs)
-    else:
-        raise ValueError(f"Unknown task: {run_type}")
 
     if hasattr(args, 'streamlit') and args.streamlit:
         inference.run_streamlit(model_path)
@@ -49,20 +54,25 @@ def run_onnx_inference(args):
     from nexa.general import pull_model
     local_path, run_type = pull_model(model_path)
 
-    if run_type == "NLP":
-        from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference
-        inference = NexaTextOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    elif run_type == "Computer Vision":
-        from nexa.onnx.nexa_inference_image import NexaImageInference as NexaImageOnnxInference
-        inference = NexaImageOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    elif run_type == "Audio":
-        from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference
-        inference = NexaVoiceOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    elif run_type == "TTS":
-        from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference
-        inference = NexaTTSOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
-    else:
-        raise ValueError(f"Unknown task: {run_type}")
+    try:
+        if run_type == "NLP":
+            from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference
+            inference = NexaTextOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        elif run_type == "Computer Vision":
+            from nexa.onnx.nexa_inference_image import NexaImageInference as NexaImageOnnxInference
+            inference = NexaImageOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        elif run_type == "Audio":
+            from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference
+            inference = NexaVoiceOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        elif run_type == "TTS":
+            from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference
+            inference = NexaTTSOnnxInference(model_path=model_path, local_path=local_path, **kwargs)
+        else:
+            print(f"Unknown task: {run_type}. Skipping inference.")
+            return
+    except Exception as e:
+        print(f"Error loading ONNX models: {e}. Please refer to our docs to install the nexaai[onnx] package: https://docs.nexaai.com/getting-started/installation")
+        return
 
     if hasattr(args, 'streamlit') and args.streamlit:
         inference.run_streamlit(model_path)
diff --git a/nexa/constants.py b/nexa/constants.py
index 1c7b90ec..a6699d28 100644
--- a/nexa/constants.py
+++ b/nexa/constants.py
@@ -120,7 +120,7 @@ NEXA_RUN_PROJECTOR_MAP = {
     "nanollava": "nanoLLaVA:projector-fp16",
-    "nanoLLaVA:fp16": "nanoLLaVA:project-fp16",
+    "nanoLLaVA:fp16": "nanoLLaVA:projector-fp16",
     "llava-phi3": "llava-phi-3-mini:projector-q4_0",
     "llava-phi-3-mini:q4_0": "llava-phi-3-mini:projector-q4_0",
     "llava-phi-3-mini:fp16": "llava-phi-3-mini:projector-fp16",
diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py
index e6d606dc..1deebb03 100644
--- a/nexa/gguf/nexa_inference_image.py
+++ b/nexa/gguf/nexa_inference_image.py
@@ -115,7 +115,7 @@ def _retry(self, func, *args, **kwargs):
             except Exception as e:
                 logging.error(f"Attempt {attempt + 1} failed with error: {e}")
                 time.sleep(1)
-        logging.error("All retry attempts failed.")
+        print("All retry attempts failed because of an Out of Memory error. Try using smaller models...")
         return None
 
     def txt2img(
diff --git a/nexa/gguf/streamlit/streamlit_image_chat.py b/nexa/gguf/streamlit/streamlit_image_chat.py
index 5f4bbca6..39fa8a79 100644
--- a/nexa/gguf/streamlit/streamlit_image_chat.py
+++ b/nexa/gguf/streamlit/streamlit_image_chat.py
@@ -4,6 +4,7 @@
 from nexa.general import pull_model
 import streamlit as st
 from nexa.gguf.nexa_inference_image import NexaImageInference
+import io
 
 default_model = sys.argv[1]
 
@@ -106,5 +107,16 @@ def generate_images(nexa_model: NexaImageInference, prompt: str, negative_prompt
                 st.session_state.nexa_model, prompt, negative_prompt
             )
             st.success("Images generated successfully!")
-            for image in images:
-                st.image(image, caption="Generated Image", use_column_width=True)
+            for i, image in enumerate(images):
+                st.image(image, caption=f"Generated Image {i+1}", use_column_width=True)
+
+                img_byte_arr = io.BytesIO()
+                image.save(img_byte_arr, format='PNG')
+                img_byte_arr = img_byte_arr.getvalue()
+
+                st.download_button(
+                    label=f"Download Image {i+1}",
+                    data=img_byte_arr,
+                    file_name=f"generated_image_{i+1}.png",
+                    mime="image/png"
+                )
diff --git a/nexa/onnx/streamlit/streamlit_image_chat.py b/nexa/onnx/streamlit/streamlit_image_chat.py
index 67112265..56c9c140 100644
--- a/nexa/onnx/streamlit/streamlit_image_chat.py
+++ b/nexa/onnx/streamlit/streamlit_image_chat.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 import streamlit as st
-
+from optimum.onnxruntime import ORTLatentConsistencyModelPipeline
 from nexa.general import pull_model
 from nexa.onnx.nexa_inference_image import NexaImageInference
 
@@ -15,11 +15,11 @@ def load_model(model_path):
     local_path, run_type = pull_model(model_path)
     nexa_model = NexaImageInference(model_path=model_path, local_path=local_path)
 
-    if nexa_model.downloaded_onnx_folder is None:
+    if nexa_model.download_onnx_folder is None:
         st.error("Failed to download the model. Please check the model path.")
         return None
 
-    nexa_model._load_model(nexa_model.downloaded_onnx_folder)
+    nexa_model._load_model(nexa_model.download_onnx_folder)
     return nexa_model
 
 
@@ -30,17 +30,21 @@ def generate_images(nexa_model: NexaImageInference, prompt, negative_prompt):
 
     generator = np.random.RandomState(nexa_model.params["random_seed"])
 
-    images = nexa_model.pipeline(
-        prompt=prompt,
-        negative_prompt=negative_prompt if negative_prompt else None,
-        num_inference_steps=nexa_model.params["num_inference_steps"],
-        num_images_per_prompt=nexa_model.params["num_images_per_prompt"],
-        height=nexa_model.params["height"],
-        width=nexa_model.params["width"],
-        generator=generator,
-        guidance_scale=nexa_model.params["guidance_scale"],
-    ).images
+    is_lcm_pipeline = isinstance(nexa_model.pipeline, ORTLatentConsistencyModelPipeline)
+
+    pipeline_kwargs = {
+        "prompt": prompt,
+        "num_inference_steps": nexa_model.params["num_inference_steps"],
+        "num_images_per_prompt": nexa_model.params["num_images_per_prompt"],
+        "height": nexa_model.params["height"],
+        "width": nexa_model.params["width"],
+        "generator": generator,
+        "guidance_scale": nexa_model.params["guidance_scale"],
+    }
+    if not is_lcm_pipeline and negative_prompt:
+        pipeline_kwargs["negative_prompt"] = negative_prompt
+    images = nexa_model.pipeline(**pipeline_kwargs).images
 
     return images
diff --git a/pyproject.toml b/pyproject.toml
index f54a8e27..53bcc4b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,10 +24,9 @@ dependencies = [
     "prompt_toolkit",
     "tqdm", # Shared dependencies
     "tabulate",
-    "streamlit",
+    "streamlit>=1.37.1",
     "streamlit-audiorec",
    "python-multipart",
-    "streamlit-audiorec",
     "cmake",
 ]
 classifiers = [