From 396f876c67c13c3ce100140ba5ced6d4f5b0cedc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=BA?=
Date: Tue, 17 Dec 2024 13:05:50 +0800
Subject: [PATCH 1/3] omni-vlm streaming support

---
 dependency/llama.cpp                 |  2 +-
 nexa/gguf/llama/omni_vlm_cpp.py      | 18 +++++++++++++++++-
 nexa/gguf/nexa_inference_vlm_omni.py | 19 +++++++++++++++++--
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/dependency/llama.cpp b/dependency/llama.cpp
index 5859733c..8c782fad 160000
--- a/dependency/llama.cpp
+++ b/dependency/llama.cpp
@@ -1 +1 @@
-Subproject commit 5859733cb86dbd61091a8045203b33df056e536a
+Subproject commit 8c782fad0ac22c186ce19ff4e50d98d1eeffb38d
diff --git a/nexa/gguf/llama/omni_vlm_cpp.py b/nexa/gguf/llama/omni_vlm_cpp.py
index ab7a5699..2a766d09 100644
--- a/nexa/gguf/llama/omni_vlm_cpp.py
+++ b/nexa/gguf/llama/omni_vlm_cpp.py
@@ -71,9 +71,25 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
 _lib.omnivlm_inference.restype = omni_char_p
 
 
+def omnivlm_inference_streaming(prompt: omni_char_p, image_path: omni_char_p):
+    return _lib.omnivlm_inference_streaming(prompt, image_path)
+
+
+_lib.omnivlm_inference_streaming.argtypes = [omni_char_p, omni_char_p]
+_lib.omnivlm_inference_streaming.restype = ctypes.c_void_p
+
+
+def sample(omni_vlm_streaming: ctypes.c_void_p):
+    return _lib.sample(omni_vlm_streaming)
+
+
+_lib.sample.argtypes = [ctypes.c_void_p]
+_lib.sample.restype = ctypes.c_int32
+
+
 def omnivlm_free():
     return _lib.omnivlm_free()
 
 
 _lib.omnivlm_free.argtypes = []
-_lib.omnivlm_free.restype = None
\ No newline at end of file
+_lib.omnivlm_free.restype = None
diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py
index 4a76a4eb..b9f39666 100644
--- a/nexa/gguf/nexa_inference_vlm_omni.py
+++ b/nexa/gguf/nexa_inference_vlm_omni.py
@@ -135,11 +135,13 @@ def run(self):
                     message=""
                 )
 
-                response = self.inference(user_input, image_path)
+                # response = self.inference(user_input, image_path)
+                response = self.inference_streaming(user_input, image_path)
 
                 stop_spinner(stop_event, spinner_thread)
 
                 print(f"\nResponse: {response}")
+
             except KeyboardInterrupt:
                 print("\nExiting...")
                 break
@@ -159,6 +161,19 @@ def inference(self, prompt: str, image_path: str):
 
         return decoded_response
 
+    def inference_streaming(self, prompt: str, image_path: str):
+        response = []
+        with suppress_stdout_stderr():
+            prompt = ctypes.c_char_p(prompt.encode("utf-8"))
+            image_path = ctypes.c_char_p(image_path.encode("utf-8"))
+            oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path)
+
+            res = 0
+            while res >= 0:
+                res = omni_vlm_cpp.sample(oss)
+                response.append(res)
+        return response
+
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
 
@@ -218,4 +233,4 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj
     if args.streamlit:
         inference.run_streamlit(model_path)
     else:
-        inference.run()
\ No newline at end of file
+        inference.run()

From 476cb1a18931add828656389db0e51eeb3f5430a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=BA?=
Date: Tue, 17 Dec 2024 14:03:33 +0800
Subject: [PATCH 2/3] add get_str() interface

---
 dependency/llama.cpp                 | 2 +-
 nexa/gguf/llama/omni_vlm_cpp.py      | 8 ++++++++
 nexa/gguf/nexa_inference_vlm_omni.py | 7 +++++--
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/dependency/llama.cpp b/dependency/llama.cpp
index 8c782fad..868ceb69 160000
--- a/dependency/llama.cpp
+++ b/dependency/llama.cpp
@@ -1 +1 @@
-Subproject commit 8c782fad0ac22c186ce19ff4e50d98d1eeffb38d
+Subproject commit 868ceb6965faa43de7ab9bc3efd97104185406c8
diff --git a/nexa/gguf/llama/omni_vlm_cpp.py b/nexa/gguf/llama/omni_vlm_cpp.py
index 2a766d09..e6d4312a 100644
--- a/nexa/gguf/llama/omni_vlm_cpp.py
+++ b/nexa/gguf/llama/omni_vlm_cpp.py
@@ -87,6 +87,14 @@ def sample(omni_vlm_streaming: ctypes.c_void_p):
 _lib.sample.restype = ctypes.c_int32
 
 
+def get_str(omni_vlm_streaming: ctypes.c_void_p):
+    return _lib.get_str(omni_vlm_streaming)
+
+
+_lib.get_str.argtypes = [ctypes.c_void_p]
+_lib.get_str.restype = ctypes.c_char_p
+
+
 def omnivlm_free():
     return _lib.omnivlm_free()
 
diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py
index b9f39666..d5a0221b 100644
--- a/nexa/gguf/nexa_inference_vlm_omni.py
+++ b/nexa/gguf/nexa_inference_vlm_omni.py
@@ -1,3 +1,4 @@
+import time
 import ctypes
 import logging
 import os
@@ -162,7 +163,7 @@ def inference(self, prompt: str, image_path: str):
         return decoded_response
 
     def inference_streaming(self, prompt: str, image_path: str):
-        response = []
+        response = ""
         with suppress_stdout_stderr():
             prompt = ctypes.c_char_p(prompt.encode("utf-8"))
             image_path = ctypes.c_char_p(image_path.encode("utf-8"))
@@ -171,7 +172,9 @@ def inference_streaming(self, prompt: str, image_path: str):
             res = 0
             while res >= 0:
                 res = omni_vlm_cpp.sample(oss)
-                response.append(res)
+                res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
+                print(res_str, flush=True)
+                response += res_str
         return response
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()

From a6e4a875bfebf3e93de5c336fd35866db8653cbe Mon Sep 17 00:00:00 2001
From: JoyboyBrian
Date: Tue, 17 Dec 2024 07:11:26 +0000
Subject: [PATCH 3/3] support streaming

---
 nexa/gguf/nexa_inference_vlm_omni.py | 35 ++++++++++++++++++-----------------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py
index d5a0221b..2f111847 100644
--- a/nexa/gguf/nexa_inference_vlm_omni.py
+++ b/nexa/gguf/nexa_inference_vlm_omni.py
@@ -127,7 +127,8 @@ def run(self):
             try:
                 image_path = nexa_prompt("Image Path (required): ")
                 if not os.path.exists(image_path):
-                    print(f"Image path: {image_path} not found, running omni VLM without image input.")
+                    print(f"Image path: {image_path} not found, exiting...")
+                    exit(1)
 
                 # Skip user input for OCR version
                 user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt()
@@ -135,13 +136,14 @@ def run(self):
                     style="default",
                     message=""
                 )
-
-                # response = self.inference(user_input, image_path)
-                response = self.inference_streaming(user_input, image_path)
-
-                stop_spinner(stop_event, spinner_thread)
-
-                print(f"\nResponse: {response}")
+                first_chunk = True
+                for chunk in self.inference_streaming(user_input, image_path):
+                    if first_chunk:
+                        stop_spinner(stop_event, spinner_thread)
+                        first_chunk = False
+                    if chunk == '\n':
+                        chunk = ''
+                    print(chunk, end='', flush=True)
             except KeyboardInterrupt:
                 print("\nExiting...")
                 break
@@ -163,19 +165,18 @@ def inference(self, prompt: str, image_path: str):
         return decoded_response
 
     def inference_streaming(self, prompt: str, image_path: str):
-        response = ""
         with suppress_stdout_stderr():
            prompt = ctypes.c_char_p(prompt.encode("utf-8"))
            image_path = ctypes.c_char_p(image_path.encode("utf-8"))
            oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path)
-
-            res = 0
-            while res >= 0:
-                res = omni_vlm_cpp.sample(oss)
-                res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
-                print(res_str, flush=True)
-                response += res_str
-        return response
+
+        res = 0
+        while res >= 0:
+            res = omni_vlm_cpp.sample(oss)
+            res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
+            if '<|im_start|>' in res_str or '</s>' in res_str:
+                continue
+            yield res_str
 
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
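
Usage note: taken together, the three patches expose streaming at two layers -- raw ctypes bindings (omnivlm_inference_streaming, sample, get_str) in nexa/gguf/llama/omni_vlm_cpp.py, and a Python generator, inference_streaming(), that yields decoded token strings with chat-template control tokens filtered out. The sketch below shows how a caller might consume that generator; the NexaOmniVlmInference class name and its constructor arguments are assumptions about the surrounding nexa-sdk module, not something defined by these patches.

    # Hypothetical consumer of the generator added in PATCH 3/3 (not part of the patches).
    from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference  # assumed class name

    vlm = NexaOmniVlmInference(model_path="path/to/omni-vlm-gguf")  # placeholder arguments

    # Each chunk is one decoded token string; printing without a trailing newline
    # reproduces the incremental console output that run() implements above.
    for chunk in vlm.inference_streaming("Describe this image.", "example.png"):
        print(chunk, end="", flush=True)
    print()

Because inference_streaming() is a generator, tokens reach the caller as soon as sample() and get_str() produce them, rather than after the full response has been assembled -- which is what enables the incremental printing in run().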