Nexa service now supports OmniVLM, with streaming support
JoyboyBrian committed Dec 18, 2024
1 parent ede2827 · commit eae77c9
Showing 1 changed file with 177 additions and 33 deletions.
210 changes: 177 additions & 33 deletions nexa/gguf/server/nexa_service.py
@@ -24,6 +24,8 @@
NEXA_RUN_CHAT_TEMPLATE_MAP,
NEXA_RUN_MODEL_MAP_VLM,
NEXA_RUN_PROJECTOR_MAP,
NEXA_RUN_OMNI_VLM_MAP,
NEXA_RUN_OMNI_VLM_PROJECTOR_MAP,
NEXA_RUN_COMPLETION_TEMPLATE_MAP,
NEXA_RUN_MODEL_PRECISION_MAP,
NEXA_RUN_MODEL_MAP_FUNCTION_CALLING,
@@ -38,6 +40,7 @@
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
from nexa.general import pull_model
from nexa.gguf.llama.llama import Llama
from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference
from nexa.gguf.sd.stable_diffusion import StableDiffusion
from faster_whisper import WhisperModel
import argparse
@@ -231,9 +234,15 @@ async def load_model():
raise ValueError("Multimodal and Audio models are not supported for Hugging Face")
downloaded_path, _ = pull_model(model_path, hf=is_huggingface, ms=is_modelscope)
else:
if model_path in NEXA_RUN_MODEL_MAP_VLM: # for Multimodal models
downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path])
if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP:
if model_path in NEXA_RUN_OMNI_VLM_MAP:
logging.info(f"Path is OmniVLM model: {model_path}")
downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path])
else:
logging.info(f"Path is in NEXA_RUN_MODEL_MAP_VLM: {model_path}")
downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path])
model_type = "Multimodal"
else:
downloaded_path, model_type = pull_model(model_path)
@@ -321,35 +330,50 @@ async def load_model():
logging.info(f"model loaded as {model}")
elif model_type == "Multimodal":
with suppress_stdout_stderr():
projector_handler = NEXA_PROJECTOR_HANDLER_MAP.get(model_path, Llava15ChatHandler)
projector = (projector_handler(
clip_model_path=projector_downloaded_path, verbose=False
) if projector_downloaded_path else None)

chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None)
try:
model = Llama(
model_path=downloaded_path,
chat_handler=projector,
verbose=False,
chat_format=chat_format,
n_ctx=n_ctx,
n_gpu_layers=-1 if is_gpu_available() else 0,
)
except Exception as e:
logging.error(
f"Failed to load model: {e}. Falling back to CPU.",
exc_info=True,
)
model = Llama(
model_path=downloaded_path,
chat_handler=projector,
verbose=False,
chat_format=chat_format,
n_ctx=n_ctx,
n_gpu_layers=0, # hardcode to use CPU
)

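# OmniVLM models are handled by the dedicated NexaOmniVlmInference runner below;
# all other multimodal models fall through to llama.cpp (Llama) with a CLIP projector chat handler.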
if 'omni' in model_path.lower():
try:
model = NexaOmniVlmInference(
model_path=model_path,
device="gpu" if is_gpu_available() else "cpu"
)
except Exception as e:
logging.error(
f"Failed to load OmniVLM model: {e}. Falling back to CPU.",
exc_info=True,
)
model = NexaOmniVlmInference(
model_path=model_path,
device="cpu"
)
else:
projector_handler = NEXA_PROJECTOR_HANDLER_MAP.get(model_path, Llava15ChatHandler)
projector = (projector_handler(
clip_model_path=projector_downloaded_path, verbose=False
) if projector_downloaded_path else None)

chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None)
try:
model = Llama(
model_path=downloaded_path,
chat_handler=projector,
verbose=False,
chat_format=chat_format,
n_ctx=n_ctx,
n_gpu_layers=-1 if is_gpu_available() else 0,
)
except Exception as e:
logging.error(
f"Failed to load model: {e}. Falling back to CPU.",
exc_info=True,
)
model = Llama(
model_path=downloaded_path,
chat_handler=projector,
verbose=False,
chat_format=chat_format,
n_ctx=n_ctx,
n_gpu_layers=0, # hardcode to use CPU
)
logging.info(f"Model loaded as {model}")
elif model_type == "Audio":
with suppress_stdout_stderr():
@@ -754,7 +778,7 @@ async def text_chat_completions(request: ChatCompletionRequest):
async def multimodal_chat_completions(request: VLMChatCompletionRequest):
"""Endpoint for multimodal chat completions using VLM models"""
try:
if model_type != "Multimodal":
if model_type != "Multimodal" or 'omni' in model_path.lower():
raise HTTPException(
status_code=400,
detail="The model that is loaded is not a Multimodal model. Please use a Multimodal model (e.g. nanollava) for VLM."
@@ -770,6 +794,7 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest):
elif isinstance(item, ImageUrlContent):
try:
image_data_uri = process_image_input(item.image_url)
logging.info(f"BRIAN: image_data_uri: {image_data_uri}")
processed_content.append({
"type": "image_url",
"image_url": {"url": image_data_uri}
@@ -800,6 +825,125 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest):
logging.error(f"Error in multimodal chat completions: {e}")
raise HTTPException(status_code=500, detail=str(e))

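# Streams OmniVLM output as OpenAI-style server-sent events: one "chat.completion.chunk"
# JSON object per generated token, terminated by a "data: [DONE]" sentinel.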
async def _resp_omnivlm_async_generator(model, prompt: str, image_path: str):
_id = str(uuid.uuid4())
try:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")

for token in model.inference_streaming(prompt, image_path):
chunk = {
"id": _id,
"object": "chat.completion.chunk",
"created": time.time(),
"choices": [{
"delta": {"content": token},
"index": 0,
"finish_reason": None
}]
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logging.error(f"Error in OmniVLM streaming: {e}")
raise

@app.post("/v1/omnivlm/chat/completions", tags=["Multimodal"])
async def omnivlm_chat_completions(request: VLMChatCompletionRequest):
"""Endpoint for Multimodal chat completions using OmniVLM models"""
temp_file = None
image_path = None

try:
if model_type != "Multimodal" or 'omni' not in model_path.lower():
raise HTTPException(
status_code=400,
detail="Please use an OmniVLM model for this endpoint."
)

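# Pull the prompt text and image out of the last message; the image is resolved to a
# base64 data URI by process_image_input, decoded, and written to a temporary .jpg file
# that is passed to the model and cleaned up after inference.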
prompt = ""
last_message = request.messages[-1]

if isinstance(last_message.content, list):
for item in last_message.content:
if isinstance(item, TextContent):
prompt = item.text
elif isinstance(item, ImageUrlContent):
try:
base64_image = process_image_input(item.image_url)
base64_data = base64_image.split(',')[1]
image_data = base64.b64decode(base64_data)

temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False)
temp_file.write(image_data)
temp_file.flush()
os.fsync(temp_file.fileno())
temp_file.close()

image_path = temp_file.name

if not os.path.exists(image_path):
raise ValueError(f"Failed to create temporary file at {image_path}")

except Exception as e:
if temp_file and os.path.exists(temp_file.name):
os.unlink(temp_file.name)
raise ValueError(f"Failed to process image: {str(e)}")
else:
raise ValueError("Either url or path must be provided for image")
else:
prompt = last_message.content

if not image_path:
raise HTTPException(
status_code=400,
detail="Image is required for OmniVLM inference"
)

if request.stream:
async def stream_with_cleanup():
try:
async for chunk in _resp_omnivlm_async_generator(model, prompt, image_path):
yield chunk
finally:
if image_path and os.path.exists(image_path):
try:
os.unlink(image_path)
except Exception as e:
logging.error(f"Error cleaning up file {image_path}: {e}")

return StreamingResponse(
stream_with_cleanup(),
media_type="text/event-stream"
)
else:
try:
response = model.inference(prompt, image_path)
return {
"id": str(uuid.uuid4()),
"object": "chat.completion",
"created": time.time(),
"choices": [{
"message": {"role": "assistant", "content": response},
"index": 0,
"finish_reason": "stop"
}],
}
finally:
if image_path and os.path.exists(image_path):
os.unlink(image_path)

except Exception as e:
if image_path and os.path.exists(image_path):
try:
os.unlink(image_path)
except Exception as cleanup_error:
logging.error(f"Error cleaning up file {image_path}: {cleanup_error}")

if isinstance(e, HTTPException):
raise e
logging.error(f"Error in OmniVLM chat completions: {e}")
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/function-calling", tags=["NLP"])
async def function_call(request: FunctionCallRequest):

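For reference, here is a minimal client sketch (not part of this commit) that exercises the new /v1/omnivlm/chat/completions endpoint in streaming mode. The server port, the local image filename, and the exact field layout accepted for image_url are assumptions inferred from the request handling in the diff, not confirmed by it.

import base64
import json

import requests  # third-party HTTP client, assumed available

# Read a local image and encode it as a base64 data URI.
with open("cat.jpg", "rb") as f:  # hypothetical local image path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
        ],
    }],
    "stream": True,
}

# Port 8000 is an assumption; adjust to wherever nexa_service is running.
with requests.post(
    "http://localhost:8000/v1/omnivlm/chat/completions",
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)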