diff --git a/nexa/gguf/__init__.py b/nexa/gguf/__init__.py
index 0001ab54..6ab29ece 100644
--- a/nexa/gguf/__init__.py
+++ b/nexa/gguf/__init__.py
@@ -1,11 +1,11 @@
-# from .nexa_inference_image import NexaImageInference
-# from .nexa_inference_text import NexaTextInference
-# from .nexa_inference_vlm import NexaVLMInference
-# from .nexa_inference_voice import NexaVoiceInference
+from .nexa_inference_image import NexaImageInference
+from .nexa_inference_text import NexaTextInference
+from .nexa_inference_vlm import NexaVLMInference
+from .nexa_inference_voice import NexaVoiceInference
 
-# __all__ = [
-#     "NexaImageInference",
-#     "NexaTextInference",
-#     "NexaVLMInference",
-#     "NexaVoiceInference",
-# ]
\ No newline at end of file
+__all__ = [
+    "NexaImageInference",
+    "NexaTextInference",
+    "NexaVLMInference",
+    "NexaVoiceInference",
+]
\ No newline at end of file
diff --git a/nexa/gguf/nexa_inference_vlm.py b/nexa/gguf/nexa_inference_vlm.py
index 1e2ab005..63061852 100644
--- a/nexa/gguf/nexa_inference_vlm.py
+++ b/nexa/gguf/nexa_inference_vlm.py
@@ -219,7 +219,14 @@ def run(self):
                 logging.error(f"Error during generation: {e}", exc_info=True)
         print("\n")
 
-    def create_chat_completion(self, messages, temperature, max_tokens, top_k, top_p, stream, stop):
+    def create_chat_completion(self,
+                               messages,
+                               max_tokens: int = 2048,
+                               temperature: float = 0.2,
+                               top_p: float = 0.95,
+                               top_k: int = 40,
+                               stream=False,
+                               stop=[]):
         """
         Generate text completion for a given chat prompt.
 
@@ -231,6 +238,34 @@ def create_chat_completion(self, messages, temperature, max_tokens, top_k, top_p
             top_p (float): Top-p sampling parameter.
             stream (bool): Stream the output.
             stop (list): List of stop words for early stopping.
+
+        Returns:
+            Iterator: An iterator of the generated text completion.
+            Return format:
+            {
+                "choices": [
+                    {
+                        "finish_reason": "stop",
+                        "index": 0,
+                        "message": {
+                            "content": "The 2020 World Series was played in Texas at Globe Life Field in Arlington.",
+                            "role": "assistant"
+                        },
+                        "logprobs": null
+                    }
+                ],
+                "created": 1677664795,
+                "id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW",
+                "model": "gpt-4o-mini",
+                "object": "chat.completion",
+                "usage": {
+                    "completion_tokens": 17,
+                    "prompt_tokens": 57,
+                    "total_tokens": 74
+                }
+            }
+            Usage: message = completion.choices[0].message.content
+
         """
         return self.model.create_chat_completion(
             messages=messages,
diff --git a/tests/test_vlm.py b/tests/test_vlm.py
index 17400bd4..d8977a68 100644
--- a/tests/test_vlm.py
+++ b/tests/test_vlm.py
@@ -1,8 +1,7 @@
 import base64
 import os
 
-from nexa.gguf.llama import llama
-from nexa.gguf.llama.llama_chat_format import NanoLlavaChatHandler
+from nexa.gguf import NexaVLMInference
 from tests.utils import download_model
 from nexa.gguf.lib_utils import is_gpu_available
 import tempfile
@@ -23,18 +22,10 @@ def test_image_generation():
     model_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/model-fp16.gguf"
     mmproj_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/projector-fp16.gguf"
 
-    model_path = download_model(model_url, temp_dir)
-    mmproj_path = download_model(mmproj_url, temp_dir)
-    chat_handler = NanoLlavaChatHandler(clip_model_path=mmproj_path)
-
-    llm = llama.Llama(
-        model_path=model_path,
-        chat_handler=chat_handler,
-        n_ctx=2048,  # n_ctx should be increased to accommodate the image embedding
-        n_gpu_layers=-1 if is_gpu_available() else 0,  # Uncomment to use GPU acceleration
-        verbose=False,
+    model = NexaVLMInference(
+        model_path="nanollava",
     )
-    output = llm.create_chat_completion(
+    output = model.create_chat_completion(
         messages=[
             {
                 "role": "system",