Merge pull request #322 from NexaAI/brian/siglip
Brian/siglip
JoyboyBrian authored Dec 17, 2024
2 parents 88f04dd + ede2827 commit b4f651b
Showing 3 changed files with 52 additions and 9 deletions.
26 changes: 25 additions & 1 deletion nexa/gguf/llama/omni_vlm_cpp.py
@@ -71,9 +71,33 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
_lib.omnivlm_inference.restype = omni_char_p


def omnivlm_inference_streaming(prompt: omni_char_p, image_path: omni_char_p):
return _lib.omnivlm_inference_streaming(prompt, image_path)


_lib.omnivlm_inference_streaming.argtypes = [omni_char_p, omni_char_p]
_lib.omnivlm_inference_streaming.restype = ctypes.c_void_p


def sample(omni_vlm_streaming: ctypes.c_void_p):
return _lib.sample(omni_vlm_streaming)


_lib.sample.argtypes = [ctypes.c_void_p]
_lib.sample.restype = ctypes.c_int32


def get_str(omni_vlm_streaming: ctypes.c_void_p):
return _lib.get_str(omni_vlm_streaming)


_lib.get_str.argtypes = [ctypes.c_void_p]
_lib.get_str.restype = ctypes.c_char_p


def omnivlm_free():
return _lib.omnivlm_free()


_lib.omnivlm_free.argtypes = []
_lib.omnivlm_free.restype = None
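Taken together, the new bindings describe a pull-style streaming interface: omnivlm_inference_streaming returns an opaque handle, sample advances decoding and returns an int32 (the wrapper added below treats a negative value as end of stream), and get_str returns the text produced at the current step as bytes. A minimal sketch of driving the bindings directly, assuming the model has already been initialized through the package's existing loading path:

import ctypes

from nexa.gguf.llama import omni_vlm_cpp


def stream_pieces(prompt: str, image_path: str):
    # Sketch only: assumes the omnivlm model was already loaded elsewhere
    # in the package before this is called.
    handle = omni_vlm_cpp.omnivlm_inference_streaming(
        ctypes.c_char_p(prompt.encode("utf-8")),
        ctypes.c_char_p(image_path.encode("utf-8")),
    )
    # The wrapper in nexa_inference_vlm_omni.py treats a negative return
    # from sample() as the end of the stream.
    while omni_vlm_cpp.sample(handle) >= 0:
        piece = omni_vlm_cpp.get_str(handle)  # bytes, since restype is c_char_p
        if piece:
            yield piece.decode("utf-8")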
33 changes: 26 additions & 7 deletions nexa/gguf/nexa_inference_vlm_omni.py
@@ -1,3 +1,4 @@
import time
import ctypes
import logging
import os
@@ -126,20 +127,24 @@ def run(self):
try:
image_path = nexa_prompt("Image Path (required): ")
if not os.path.exists(image_path):
print(f"Image path: {image_path} not found, running omni VLM without image input.")
print(f"Image path: {image_path} not found, exiting...")
exit(1)
# Skip user input for OCR version
user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt()

stop_event, spinner_thread = start_spinner(
style="default",
message=""
)
first_chunk = True
for chunk in self.inference_streaming(user_input, image_path):
if first_chunk:
stop_spinner(stop_event, spinner_thread)
first_chunk = False
if chunk == '\n':
chunk = ''
print(chunk, end='', flush=True)

response = self.inference(user_input, image_path)

stop_spinner(stop_event, spinner_thread)

print(f"\nResponse: {response}")
except KeyboardInterrupt:
print("\nExiting...")
break
@@ -159,6 +164,20 @@ def inference(self, prompt: str, image_path: str):

return decoded_response

def inference_streaming(self, prompt: str, image_path: str):
with suppress_stdout_stderr():
prompt = ctypes.c_char_p(prompt.encode("utf-8"))
image_path = ctypes.c_char_p(image_path.encode("utf-8"))
oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path)

res = 0
while res >= 0:
res = omni_vlm_cpp.sample(oss)
res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
if '<|im_start|>' in res_str or '</s>' in res_str:
continue
yield res_str

def __del__(self):
omni_vlm_cpp.omnivlm_free()

@@ -218,4 +237,4 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj
if args.streamlit:
inference.run_streamlit(model_path)
else:
inference.run()
