diff --git a/dependency/llama.cpp b/dependency/llama.cpp
index 5859733c..868ceb69 160000
--- a/dependency/llama.cpp
+++ b/dependency/llama.cpp
@@ -1 +1 @@
-Subproject commit 5859733cb86dbd61091a8045203b33df056e536a
+Subproject commit 868ceb6965faa43de7ab9bc3efd97104185406c8
diff --git a/nexa/gguf/llama/omni_vlm_cpp.py b/nexa/gguf/llama/omni_vlm_cpp.py
index ab7a5699..e6d4312a 100644
--- a/nexa/gguf/llama/omni_vlm_cpp.py
+++ b/nexa/gguf/llama/omni_vlm_cpp.py
@@ -71,9 +71,33 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
 _lib.omnivlm_inference.restype = omni_char_p
 
 
+def omnivlm_inference_streaming(prompt: omni_char_p, image_path: omni_char_p):
+    return _lib.omnivlm_inference_streaming(prompt, image_path)
+
+
+_lib.omnivlm_inference_streaming.argtypes = [omni_char_p, omni_char_p]
+_lib.omnivlm_inference_streaming.restype = ctypes.c_void_p
+
+
+def sample(omni_vlm_streaming: ctypes.c_void_p):
+    return _lib.sample(omni_vlm_streaming)
+
+
+_lib.sample.argtypes = [ctypes.c_void_p]
+_lib.sample.restype = ctypes.c_int32
+
+
+def get_str(omni_vlm_streaming: ctypes.c_void_p):
+    return _lib.get_str(omni_vlm_streaming)
+
+
+_lib.get_str.argtypes = [ctypes.c_void_p]
+_lib.get_str.restype = ctypes.c_char_p
+
+
 def omnivlm_free():
     return _lib.omnivlm_free()
 
 
 _lib.omnivlm_free.argtypes = []
-_lib.omnivlm_free.restype = None
\ No newline at end of file
+_lib.omnivlm_free.restype = None
diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py
index 4a76a4eb..2f111847 100644
--- a/nexa/gguf/nexa_inference_vlm_omni.py
+++ b/nexa/gguf/nexa_inference_vlm_omni.py
@@ -1,3 +1,4 @@
+import time
 import ctypes
 import logging
 import os
@@ -126,7 +127,8 @@ def run(self):
             try:
                 image_path = nexa_prompt("Image Path (required): ")
                 if not os.path.exists(image_path):
-                    print(f"Image path: {image_path} not found, running omni VLM without image input.")
+                    print(f"Image path: {image_path} not found, exiting...")
+                    exit(1)
 
                 # Skip user input for OCR version
                 user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt()
@@ -134,12 +136,15 @@ def run(self):
                     style="default",
                     message=""
                 )
 
+                first_chunk = True
+                for chunk in self.inference_streaming(user_input, image_path):
+                    if first_chunk:
+                        stop_spinner(stop_event, spinner_thread)
+                        first_chunk = False
+                    if chunk == '\n':
+                        chunk = ''
+                    print(chunk, end='', flush=True)
-                response = self.inference(user_input, image_path)
-
-                stop_spinner(stop_event, spinner_thread)
-
-                print(f"\nResponse: {response}")
             except KeyboardInterrupt:
                 print("\nExiting...")
                 break
@@ -159,6 +164,20 @@ def inference(self, prompt: str, image_path: str):
 
         return decoded_response
 
+    def inference_streaming(self, prompt: str, image_path: str):
+        with suppress_stdout_stderr():
+            prompt = ctypes.c_char_p(prompt.encode("utf-8"))
+            image_path = ctypes.c_char_p(image_path.encode("utf-8"))
+            oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path)
+
+        res = 0
+        while res >= 0:
+            res = omni_vlm_cpp.sample(oss)
+            res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
+            if '<|im_start|>' in res_str or '</s>' in res_str:
+                continue
+            yield res_str
+
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
 
@@ -218,4 +237,4 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj
     if args.streamlit:
         inference.run_streamlit(model_path)
     else:
-        inference.run()
\ No newline at end of file
+        inference.run()
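Usage note: the new bindings are driven as a sample-then-read loop. `omnivlm_inference_streaming` returns an opaque handle, `sample` advances generation and returns a negative value once it finishes, and `get_str` reads back the text produced by the last step. The sketch below follows the same pattern as the `inference_streaming` generator added above; it is only an illustration and assumes the omni-vlm model has already been initialized through the existing load path (not part of this diff). The prompt and image path are placeholder values.

```python
import ctypes
from nexa.gguf.llama import omni_vlm_cpp


def stream_response(prompt: str, image_path: str):
    """Yield decoded text chunks until sample() signals the end of generation.

    Sketch only: assumes the omni-vlm model has already been loaded elsewhere.
    """
    c_prompt = ctypes.c_char_p(prompt.encode("utf-8"))
    c_image = ctypes.c_char_p(image_path.encode("utf-8"))
    handle = omni_vlm_cpp.omnivlm_inference_streaming(c_prompt, c_image)

    res = 0
    while res >= 0:
        # sample() produces the next token; a negative return ends the loop
        res = omni_vlm_cpp.sample(handle)
        chunk = omni_vlm_cpp.get_str(handle).decode("utf-8")
        # Skip chat-template control tokens, as inference_streaming does
        if "<|im_start|>" in chunk or "</s>" in chunk:
            continue
        yield chunk


# Placeholder arguments; real callers pass user input and a path from nexa_prompt()
for piece in stream_response("Describe this image.", "example.png"):
    print(piece, end="", flush=True)
print()
```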