diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 0b38fb79..e28d565f 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -24,6 +24,8 @@ NEXA_RUN_CHAT_TEMPLATE_MAP, NEXA_RUN_MODEL_MAP_VLM, NEXA_RUN_PROJECTOR_MAP, + NEXA_RUN_OMNI_VLM_MAP, + NEXA_RUN_OMNI_VLM_PROJECTOR_MAP, NEXA_RUN_COMPLETION_TEMPLATE_MAP, NEXA_RUN_MODEL_PRECISION_MAP, NEXA_RUN_MODEL_MAP_FUNCTION_CALLING, @@ -38,6 +40,7 @@ from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr from nexa.general import pull_model from nexa.gguf.llama.llama import Llama +from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference from nexa.gguf.sd.stable_diffusion import StableDiffusion from faster_whisper import WhisperModel import argparse @@ -231,9 +234,15 @@ async def load_model(): raise ValueError("Multimodal and Audio models are not supported for Hugging Face") downloaded_path, _ = pull_model(model_path, hf=is_huggingface, ms=is_modelscope) else: - if model_path in NEXA_RUN_MODEL_MAP_VLM: # for Multimodal models - downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path]) - projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path]) + if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP: + if model_path in NEXA_RUN_OMNI_VLM_MAP: + logging.info(f"Path is OmniVLM model: {model_path}") + downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path]) + else: + logging.info(f"Path is in NEXA_RUN_MODEL_MAP_VLM: {model_path}") + downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path]) model_type = "Multimodal" else: downloaded_path, model_type = pull_model(model_path) @@ -321,35 +330,50 @@ async def load_model(): logging.info(f"model loaded as {model}") elif model_type == "Multimodal": with suppress_stdout_stderr(): - projector_handler = NEXA_PROJECTOR_HANDLER_MAP.get(model_path, Llava15ChatHandler) - projector = (projector_handler( - clip_model_path=projector_downloaded_path, verbose=False - ) if projector_downloaded_path else None) - - chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) - try: - model = Llama( - model_path=downloaded_path, - chat_handler=projector, - verbose=False, - chat_format=chat_format, - n_ctx=n_ctx, - n_gpu_layers=-1 if is_gpu_available() else 0, - ) - except Exception as e: - logging.error( - f"Failed to load model: {e}. Falling back to CPU.", - exc_info=True, - ) - model = Llama( - model_path=downloaded_path, - chat_handler=projector, - verbose=False, - chat_format=chat_format, - n_ctx=n_ctx, - n_gpu_layers=0, # hardcode to use CPU - ) - + if 'omni' in model_path.lower(): + try: + model = NexaOmniVlmInference( + model_path=model_path, + device="gpu" if is_gpu_available() else "cpu" + ) + except Exception as e: + logging.error( + f"Failed to load OmniVLM model: {e}. Falling back to CPU.", + exc_info=True, + ) + model = NexaOmniVlmInference( + model_path=model_path, + device="cpu" + ) + else: + projector_handler = NEXA_PROJECTOR_HANDLER_MAP.get(model_path, Llava15ChatHandler) + projector = (projector_handler( + clip_model_path=projector_downloaded_path, verbose=False + ) if projector_downloaded_path else None) + + chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) + try: + model = Llama( + model_path=downloaded_path, + chat_handler=projector, + verbose=False, + chat_format=chat_format, + n_ctx=n_ctx, + n_gpu_layers=-1 if is_gpu_available() else 0, + ) + except Exception as e: + logging.error( + f"Failed to load model: {e}. Falling back to CPU.", + exc_info=True, + ) + model = Llama( + model_path=downloaded_path, + chat_handler=projector, + verbose=False, + chat_format=chat_format, + n_ctx=n_ctx, + n_gpu_layers=0, # hardcode to use CPU + ) logging.info(f"Model loaded as {model}") elif model_type == "Audio": with suppress_stdout_stderr(): @@ -754,7 +778,7 @@ async def text_chat_completions(request: ChatCompletionRequest): async def multimodal_chat_completions(request: VLMChatCompletionRequest): """Endpoint for multimodal chat completions using VLM models""" try: - if model_type != "Multimodal": + if model_type != "Multimodal" or 'omni' in model_path.lower(): raise HTTPException( status_code=400, detail="The model that is loaded is not a Multimodal model. Please use a Multimodal model (e.g. nanollava) for VLM." @@ -770,6 +794,7 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest): elif isinstance(item, ImageUrlContent): try: image_data_uri = process_image_input(item.image_url) + logging.info(f"BRIAN: image_data_uri: {image_data_uri}") processed_content.append({ "type": "image_url", "image_url": {"url": image_data_uri} @@ -800,6 +825,125 @@ async def multimodal_chat_completions(request: VLMChatCompletionRequest): logging.error(f"Error in multimodal chat completions: {e}") raise HTTPException(status_code=500, detail=str(e)) +async def _resp_omnivlm_async_generator(model, prompt: str, image_path: str): + _id = str(uuid.uuid4()) + try: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + + for token in model.inference_streaming(prompt, image_path): + chunk = { + "id": _id, + "object": "chat.completion.chunk", + "created": time.time(), + "choices": [{ + "delta": {"content": token}, + "index": 0, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logging.error(f"Error in OmniVLM streaming: {e}") + raise + +@app.post("/v1/omnivlm/chat/completions", tags=["Multimodal"]) +async def omnivlm_chat_completions(request: VLMChatCompletionRequest): + """Endpoint for Multimodal chat completions using OmniVLM models""" + temp_file = None + image_path = None + + try: + if model_type != "Multimodal" or 'omni' not in model_path.lower(): + raise HTTPException( + status_code=400, + detail="Please use an OmniVLM model for this endpoint." + ) + + prompt = "" + last_message = request.messages[-1] + + if isinstance(last_message.content, list): + for item in last_message.content: + if isinstance(item, TextContent): + prompt = item.text + elif isinstance(item, ImageUrlContent): + try: + base64_image = process_image_input(item.image_url) + base64_data = base64_image.split(',')[1] + image_data = base64.b64decode(base64_data) + + temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) + temp_file.write(image_data) + temp_file.flush() + os.fsync(temp_file.fileno()) + temp_file.close() + + image_path = temp_file.name + + if not os.path.exists(image_path): + raise ValueError(f"Failed to create temporary file at {image_path}") + + except Exception as e: + if temp_file and os.path.exists(temp_file.name): + os.unlink(temp_file.name) + raise ValueError(f"Failed to process image: {str(e)}") + else: + raise ValueError("Either url or path must be provided for image") + else: + prompt = last_message.content + + if not image_path: + raise HTTPException( + status_code=400, + detail="Image is required for OmniVLM inference" + ) + + if request.stream: + async def stream_with_cleanup(): + try: + async for chunk in _resp_omnivlm_async_generator(model, prompt, image_path): + yield chunk + finally: + if image_path and os.path.exists(image_path): + try: + os.unlink(image_path) + except Exception as e: + logging.error(f"Error cleaning up file {image_path}: {e}") + + return StreamingResponse( + stream_with_cleanup(), + media_type="text/event-stream" + ) + else: + try: + response = model.inference(prompt, image_path) + return { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": time.time(), + "choices": [{ + "message": {"role": "assistant", "content": response}, + "index": 0, + "finish_reason": "stop" + }], + } + finally: + if image_path and os.path.exists(image_path): + os.unlink(image_path) + + except Exception as e: + if image_path and os.path.exists(image_path): + try: + os.unlink(image_path) + except Exception as cleanup_error: + logging.error(f"Error cleaning up file {image_path}: {cleanup_error}") + + if isinstance(e, HTTPException): + raise e + logging.error(f"Error in OmniVLM chat completions: {e}") + raise HTTPException(status_code=500, detail=str(e)) @app.post("/v1/function-calling", tags=["NLP"]) async def function_call(request: FunctionCallRequest):