support streaming
JoyboyBrian committed Dec 17, 2024
1 parent 476cb1a commit a6e4a87
Showing 1 changed file with 18 additions and 17 deletions.
nexa/gguf/nexa_inference_vlm_omni.py: 35 changes (18 additions & 17 deletions)
@@ -127,21 +127,23 @@ def run(self):
         try:
             image_path = nexa_prompt("Image Path (required): ")
             if not os.path.exists(image_path):
-                print(f"Image path: {image_path} not found, running omni VLM without image input.")
+                print(f"Image path: {image_path} not found, exiting...")
+                exit(1)
             # Skip user input for OCR version
             user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt()
 
             stop_event, spinner_thread = start_spinner(
                 style="default",
                 message=""
             )
 
-            # response = self.inference(user_input, image_path)
-            response = self.inference_streaming(user_input, image_path)
-
-            stop_spinner(stop_event, spinner_thread)
-
-            print(f"\nResponse: {response}")
+            first_chunk = True
+            for chunk in self.inference_streaming(user_input, image_path):
+                if first_chunk:
+                    stop_spinner(stop_event, spinner_thread)
+                    first_chunk = False
+                if chunk == '\n':
+                    chunk = ''
+                print(chunk, end='', flush=True)
 
         except KeyboardInterrupt:
             print("\nExiting...")
@@ -163,19 +165,18 @@ def inference(self, prompt: str, image_path: str):
         return decoded_response
 
     def inference_streaming(self, prompt: str, image_path: str):
-        response = ""
         with suppress_stdout_stderr():
             prompt = ctypes.c_char_p(prompt.encode("utf-8"))
             image_path = ctypes.c_char_p(image_path.encode("utf-8"))
             oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path)
-            res = 0
-            while res >= 0:
-                res = omni_vlm_cpp.sample(oss)
-                res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
-                print(res_str, flush=True)
-                response += res_str
-        return response
 
+        res = 0
+        while res >= 0:
+            res = omni_vlm_cpp.sample(oss)
+            res_str = omni_vlm_cpp.get_str(oss).decode('utf-8')
+            if '<|im_start|>' in res_str or '</s>' in res_str:
+                continue
+            yield res_str
 
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
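The core of the commit is that inference_streaming no longer buffers the whole response and returns it; it is now a generator that yields each chunk as the native sampler produces it, while run() consumes the generator, stops the spinner when the first chunk arrives, and prints chunks incrementally. A minimal, self-contained sketch of that pattern follows; every name in it is an illustrative stand-in, not the repository's actual API, which in the real code drives omni_vlm_cpp.sample()/get_str() and nexa's spinner utilities:

def fake_token_stream():
    # Stand-in for the native sampling loop: yield chunks one at a time
    # instead of accumulating them into a single string.
    for piece in ["The ", "image ", "shows ", "a ", "cat.", "\n"]:
        if piece in ("<|im_start|>", "</s>"):
            # Skip special tokens, mirroring the filter added in this commit.
            continue
        yield piece

def consume_stream():
    first_chunk = True
    for chunk in fake_token_stream():
        if first_chunk:
            # The real run() stops its spinner here, so the spinner runs only
            # until the first token arrives (time-to-first-token feedback).
            first_chunk = False
        if chunk == "\n":
            # Drop bare newline chunks, matching the commit's behavior.
            chunk = ""
        print(chunk, end="", flush=True)
    print()  # final newline once the stream is exhausted

if __name__ == "__main__":
    consume_stream()

Note also why the diff moves the sampling loop outside the suppress_stdout_stderr() block: a generator's body runs lazily, so if it yielded from inside that context manager, stdout would presumably still be redirected while the generator is suspended, and the caller's print of each chunk would be swallowed.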
