diff --git a/dependency/llama.cpp b/dependency/llama.cpp index 5859733c..47c4cad9 160000 --- a/dependency/llama.cpp +++ b/dependency/llama.cpp @@ -1 +1 @@ -Subproject commit 5859733cb86dbd61091a8045203b33df056e536a +Subproject commit 47c4cad9fb6569a821e93b451e652b51faf25305 diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index 59449bde..5ec322ad 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -152,7 +152,7 @@ def run_ggml_server(args): from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer kwargs = {k: v for k, v in vars(args).items() if v is not None} - model_path = kwargs.pop("model_path") + model_path = kwargs.pop("model_path", None) is_local_path = kwargs.pop("local_path", False) model_type = kwargs.pop("model_type", None) hf = kwargs.pop('huggingface', False) @@ -272,6 +272,15 @@ def run_eval_tasks(args): print("Please run: pip install 'nexaai[eval]'") return +def run_siglip_server(args): + from nexa.siglip.nexa_siglip_server import run_nexa_ai_siglip_service + run_nexa_ai_siglip_service( + image_dir=args.image_dir, + host=args.host, + port=args.port, + reload=args.reload + ) + def run_embedding_generation(args): kwargs = {k: v for k, v in vars(args).items() if v is not None} model_path = kwargs.pop("model_path") @@ -556,8 +565,8 @@ def main(): quantization_parser.add_argument("--keep_split", action="store_true", help="Quantize to the same number of shards") # GGML server parser - server_parser = subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") - server_parser.add_argument("model_path", type=str, nargs='?', help="Path or identifier for the model in Nexa Model Hub") + server_parser = subparsers.add_parser("server", help="Run the Nexa AI local service") + server_parser.add_argument("--model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") server_parser.add_argument("-lp", "--local_path", action="store_true", help="Indicate that the model path provided is the local path") server_parser.add_argument("-mt", "--model_type", type=str, choices=[e.name for e in ModelType], help="Indicate the model running type, must be used with -lp, -hf or -ms") server_parser.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub") @@ -599,6 +608,13 @@ def main(): perf_eval_group.add_argument("--device", type=str, help="Device to run performance evaluation on, choose from 'cpu', 'cuda', 'mps'", default="cpu") perf_eval_group.add_argument("--new_tokens", type=int, help="Number of new tokens to evaluate", default=100) + # Siglip Server + siglip_parser = subparsers.add_parser("siglip", help="Run the Nexa AI SigLIP Service") + siglip_parser.add_argument("--image_dir", type=str, help="Directory of images to load") + siglip_parser.add_argument("--host", type=str, default="localhost", help="Host to bind the server to") + siglip_parser.add_argument("--port", type=int, default=8100, help="Port to bind the server to") + siglip_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") + args = parser.parse_args() if args.command == "run": @@ -627,6 +643,8 @@ def main(): run_onnx_inference(args) elif args.command == "eval": run_eval_tasks(args) + elif args.command == "siglip": + run_siglip_server(args) elif args.command == "embed": run_embedding_generation(args) elif args.command == "pull": diff --git a/nexa/constants.py b/nexa/constants.py index 5f6beb69..fe3c051e 100644 --- a/nexa/constants.py +++ b/nexa/constants.py @@ -66,6 +66,7 @@ class ModelType(Enum): "phi2": "Phi-2:q4_0", "phi3": "Phi-3-mini-128k-instruct:q4_0", "phi3.5": "Phi-3.5-mini-instruct:q4_0", + "phi4": "Phi:q4_0", "llama2-uncensored": "Llama2-7b-chat-uncensored:q4_0", "llama3-uncensored": "Llama3-8B-Lexi-Uncensored:q4_K_M", "openelm": "OpenELM-3B:q4_K_M", @@ -413,6 +414,7 @@ class ModelType(Enum): "Phi-3-mini-128k-instruct": ModelType.NLP, "Phi-3-mini-4k-instruct": ModelType.NLP, "Phi-3.5-mini-instruct": ModelType.NLP, + "Phi-4": ModelType.NLP, "CodeQwen1.5-7B-Instruct": ModelType.NLP, "Qwen2-0.5B-Instruct": ModelType.NLP, "Qwen2-1.5B-Instruct": ModelType.NLP, diff --git a/nexa/eval/nexa_eval.py b/nexa/eval/nexa_eval.py index 7f24829d..14815be8 100644 --- a/nexa/eval/nexa_eval.py +++ b/nexa/eval/nexa_eval.py @@ -12,7 +12,6 @@ from nexa.eval import evaluator from nexa.eval.nexa_task.task_manager import TaskManager from nexa.eval.utils import make_table, handle_non_serializable -from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer from nexa.constants import NEXA_MODEL_EVAL_RESULTS_PATH, NEXA_RUN_MODEL_MAP from nexa.eval.nexa_perf import ( Benchmark, diff --git a/nexa/gguf/llama/audio_lm_cpp.py b/nexa/gguf/llama/audio_lm_cpp.py index 88db2a33..249cd1e8 100644 --- a/nexa/gguf/llama/audio_lm_cpp.py +++ b/nexa/gguf/llama/audio_lm_cpp.py @@ -93,6 +93,21 @@ def process_full(ctx: omni_context_p, params: omni_context_params_p, is_qwen: bo _lib = _lib_qwen if is_qwen else _lib_omni return _lib.omni_process_full(ctx, params) + +def process_streaming(ctx: omni_context_p, params: omni_context_params_p, is_qwen: bool = True): + _lib = _lib_qwen if is_qwen else _lib_omni + return _lib.omni_process_streaming(ctx, params) + + +def sample(omni_streaming: ctypes.c_void_p, is_qwen: bool = True): + _lib = _lib_qwen if is_qwen else _lib_omni + return _lib.sample(omni_streaming) + + +def get_str(omni_streaming: ctypes.c_void_p, is_qwen: bool = True): + _lib = _lib_qwen if is_qwen else _lib_omni + return _lib.get_str(omni_streaming) + # OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni); def free(ctx: omni_context_p, is_qwen: bool = True): _lib = _lib_qwen if is_qwen else _lib_omni @@ -111,6 +126,15 @@ def free(ctx: omni_context_p, is_qwen: bool = True): lib.omni_process_full.argtypes = [omni_context_p, omni_context_params_p] lib.omni_process_full.restype = ctypes.c_char_p + lib.omni_process_streaming.argtypes = [omni_context_p, omni_context_params_p] + lib.omni_process_streaming.restype = ctypes.c_void_p + + lib.sample.argtypes = [ctypes.c_void_p] + lib.sample.restype = ctypes.c_int32 + + lib.get_str.argtypes = [ctypes.c_void_p] + lib.get_str.restype = ctypes.c_char_p + # Configure free lib.omni_free.argtypes = [omni_context_p] lib.omni_free.restype = None diff --git a/nexa/gguf/llama/omni_vlm_cpp.py b/nexa/gguf/llama/omni_vlm_cpp.py index ab7a5699..e6d4312a 100644 --- a/nexa/gguf/llama/omni_vlm_cpp.py +++ b/nexa/gguf/llama/omni_vlm_cpp.py @@ -71,9 +71,33 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p): _lib.omnivlm_inference.restype = omni_char_p +def omnivlm_inference_streaming(prompt: omni_char_p, image_path: omni_char_p): + return _lib.omnivlm_inference_streaming(prompt, image_path) + + +_lib.omnivlm_inference_streaming.argtypes = [omni_char_p, omni_char_p] +_lib.omnivlm_inference_streaming.restype = ctypes.c_void_p + + +def sample(omni_vlm_streaming: ctypes.c_void_p): + return _lib.sample(omni_vlm_streaming) + + +_lib.sample.argtypes = [ctypes.c_void_p] +_lib.sample.restype = ctypes.c_int32 + + +def get_str(omni_vlm_streaming: ctypes.c_void_p): + return _lib.get_str(omni_vlm_streaming) + + +_lib.get_str.argtypes = [ctypes.c_void_p] +_lib.get_str.restype = ctypes.c_char_p + + def omnivlm_free(): return _lib.omnivlm_free() _lib.omnivlm_free.argtypes = [] -_lib.omnivlm_free.restype = None \ No newline at end of file +_lib.omnivlm_free.restype = None diff --git a/nexa/gguf/nexa_inference_audio_lm.py b/nexa/gguf/nexa_inference_audio_lm.py index 64d87e7e..9f9f8311 100644 --- a/nexa/gguf/nexa_inference_audio_lm.py +++ b/nexa/gguf/nexa_inference_audio_lm.py @@ -160,12 +160,21 @@ def run(self): ) try: - with suppress_stdout_stderr(): - response = self.inference(audio_path, user_input) + # with suppress_stdout_stderr(): + # response = self.inference(audio_path, user_input) + first_chunk = True + for chunk in self.inference_streaming(audio_path, user_input): + if first_chunk: + stop_spinner(stop_event, spinner_thread) + first_chunk = False + if chunk == '\n': + chunk = '' + # print("FUCK") + print(chunk, end='', flush=True) + print() # '\n' finally: stop_spinner(stop_event, spinner_thread) - print(f"{response}") self.cleanup() except KeyboardInterrupt: @@ -216,6 +225,42 @@ def inference(self, audio_path: str, prompt: str = "") -> str: except Exception as e: raise RuntimeError(f"Error during inference: {str(e)}") + def inference_streaming(self, audio_path: str, prompt: str = "") -> str: + """ + Perform a single inference with the audio language model. + """ + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + try: + # Ensure audio is at 16kHz before processing + audio_path = self._ensure_16khz(audio_path) + + self.ctx_params.file = ctypes.c_char_p(audio_path.encode("utf-8")) + self.ctx_params.prompt = ctypes.c_char_p(prompt.encode("utf-8")) + + with suppress_stdout_stderr(): + self.context = audio_lm_cpp.init_context( + ctypes.byref(self.ctx_params), is_qwen=self.is_qwen + ) + if not self.context: + raise RuntimeError("Failed to load audio language model") + logging.debug("Model loaded successfully") + + oss = audio_lm_cpp.process_streaming( + self.context, ctypes.byref(self.ctx_params), is_qwen=self.is_qwen + ) + res = 0 + while res >= 0: + res = audio_lm_cpp.sample(oss) + res_str = audio_lm_cpp.get_str(oss).decode('utf-8') + + if '<|im_start|>' in res_str or '' in res_str: + continue + yield res_str + except Exception as e: + raise RuntimeError(f"Error during inference: {str(e)}") + def cleanup(self): """ Explicitly cleanup resources diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py index 4a76a4eb..2f111847 100644 --- a/nexa/gguf/nexa_inference_vlm_omni.py +++ b/nexa/gguf/nexa_inference_vlm_omni.py @@ -1,3 +1,4 @@ +import time import ctypes import logging import os @@ -126,7 +127,8 @@ def run(self): try: image_path = nexa_prompt("Image Path (required): ") if not os.path.exists(image_path): - print(f"Image path: {image_path} not found, running omni VLM without image input.") + print(f"Image path: {image_path} not found, exiting...") + exit(1) # Skip user input for OCR version user_input = "" if self.omni_vlm_version == "vlm-81-ocr" else nexa_prompt() @@ -134,12 +136,15 @@ def run(self): style="default", message="" ) + first_chunk = True + for chunk in self.inference_streaming(user_input, image_path): + if first_chunk: + stop_spinner(stop_event, spinner_thread) + first_chunk = False + if chunk == '\n': + chunk = '' + print(chunk, end='', flush=True) - response = self.inference(user_input, image_path) - - stop_spinner(stop_event, spinner_thread) - - print(f"\nResponse: {response}") except KeyboardInterrupt: print("\nExiting...") break @@ -159,6 +164,20 @@ def inference(self, prompt: str, image_path: str): return decoded_response + def inference_streaming(self, prompt: str, image_path: str): + with suppress_stdout_stderr(): + prompt = ctypes.c_char_p(prompt.encode("utf-8")) + image_path = ctypes.c_char_p(image_path.encode("utf-8")) + oss = omni_vlm_cpp.omnivlm_inference_streaming(prompt, image_path) + + res = 0 + while res >= 0: + res = omni_vlm_cpp.sample(oss) + res_str = omni_vlm_cpp.get_str(oss).decode('utf-8') + if '<|im_start|>' in res_str or '' in res_str: + continue + yield res_str + def __del__(self): omni_vlm_cpp.omnivlm_free() @@ -218,4 +237,4 @@ def run_streamlit(self, model_path: str, is_local_path = False, hf = False, proj if args.streamlit: inference.run_streamlit(model_path) else: - inference.run() \ No newline at end of file + inference.run() diff --git a/nexa/gguf/nexa_inference_voice.py b/nexa/gguf/nexa_inference_voice.py index b3659776..9be8f257 100644 --- a/nexa/gguf/nexa_inference_voice.py +++ b/nexa/gguf/nexa_inference_voice.py @@ -12,6 +12,7 @@ from nexa.general import pull_model from nexa.utils import nexa_prompt, SpinningCursorAnimation from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr +import numpy as np logging.basicConfig(level=logging.INFO) @@ -58,6 +59,12 @@ def __init__(self, model_path=None, local_path=None, **kwargs): exit(1) self.params.update(kwargs) + + # for StreamASRProcessor + self.audio_buffer = np.array([], dtype=np.float32) + self.commited = [] + self.buffer_time_offset = 0.0 + self.model = None if not kwargs.get("streamlit", False): @@ -80,6 +87,50 @@ def _load_model(self): ) logging.debug("Model loaded successfully") + # for StreamASRProcessor + def live_transcribe(self, audio, prompt=""): + segments, info = self.model.transcribe(audio, language=self.params["language"], beam_size=self.params["beam_size"], word_timestamps=True, condition_on_previous_text=True, initial_prompt=prompt) + return list(segments) + + def ts_words(self, segments): + # return a list of (start, end, "word") for each word + words = [] + for seg in segments: + if seg.no_speech_prob > 0.9: + continue + for w in seg.words: + words.append((w.start, w.end, w.word)) + return words + + def insert_audio_chunk(self, audio): + self.audio_buffer = np.append(self.audio_buffer, audio) + + def process_iter(self): + # Transcribe the current buffer + if len(self.audio_buffer) == 0: + return (None, None, "") + res = self.live_transcribe(self.audio_buffer) + tsw = self.ts_words(res) + if len(tsw) == 0: + return (None, None, "") + + # We'll consider all words as committed for simplicity + self.commited = tsw + # return the entire transcription so far + text = " ".join([w[2] for w in self.commited]) + beg = self.commited[0][0] + self.buffer_time_offset + end = self.commited[-1][1] + self.buffer_time_offset + return (beg, end, text) + + def finish(self): + # Final flush when done + if len(self.commited) == 0: + return (None, None, "") + text = " ".join([w[2] for w in self.commited]) + beg = self.commited[0][0] + self.buffer_time_offset + end = self.commited[-1][1] + self.buffer_time_offset + return (beg, end, text) + def run(self): from nexa.gguf.llama._utils_spinner import start_spinner, stop_spinner @@ -180,6 +231,54 @@ def transcribe(self, audio, **kwargs): audio, **kwargs, ) + + def stream_transcription(self, audio_path, chunk_duration=1.0): + """ + Simulate streaming by processing the audio in small increments of time. + Yields partial transcripts as they become available. + """ + import librosa + SAMPLING_RATE = 16000 + audio, sr = librosa.load(audio_path, sr=SAMPLING_RATE, dtype=np.float32) + duration = len(audio) / SAMPLING_RATE + + start = time.time() + beg = 0.0 + while beg < duration: + now = time.time() - start + # Simulate waiting for real-time + if now < beg + chunk_duration: + time.sleep((beg + chunk_duration) - now) + + end = time.time() - start + if end > duration: + end = duration + + chunk_samples = int((end - beg)*SAMPLING_RATE) + chunk_audio = audio[int(beg*SAMPLING_RATE):int(beg*SAMPLING_RATE)+chunk_samples] + beg = end + + # Process incrementally + self.insert_audio_chunk(chunk_audio) + o = self.process_iter() + if o[0] is not None: + yield { + "emission_time_ms": (time.time()-start)*1000, + "segment_start_ms": o[0]*1000, + "segment_end_ms": o[1]*1000, + "text": o[2] + } + + # Final flush + o = self.finish() + if o[0] is not None: + yield { + "emission_time_ms": (time.time()-start)*1000, + "segment_start_ms": o[0]*1000, + "segment_end_ms": o[1]*1000, + "text": o[2], + "final": True + } def _transcribe_audio(self, audio_path): logging.debug(f"Transcribing audio from: {audio_path}") diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 73772e31..6b354daa 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -18,16 +18,20 @@ from io import BytesIO from PIL import Image import base64 -import re from urllib.parse import urlparse from nexa.constants import ( NEXA_RUN_CHAT_TEMPLATE_MAP, NEXA_RUN_MODEL_MAP_VLM, NEXA_RUN_PROJECTOR_MAP, + NEXA_RUN_OMNI_VLM_MAP, + NEXA_RUN_OMNI_VLM_PROJECTOR_MAP, + NEXA_RUN_MODEL_MAP_AUDIO_LM, + NEXA_RUN_AUDIO_LM_PROJECTOR_MAP, NEXA_RUN_COMPLETION_TEMPLATE_MAP, NEXA_RUN_MODEL_PRECISION_MAP, NEXA_RUN_MODEL_MAP_FUNCTION_CALLING, + NEXA_MODEL_LIST_PATH ) from nexa.gguf.lib_utils import is_gpu_available from nexa.gguf.llama.llama_chat_format import ( @@ -38,9 +42,15 @@ from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr from nexa.general import pull_model from nexa.gguf.llama.llama import Llama +from nexa.gguf.nexa_inference_vlm_omni import NexaOmniVlmInference +from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference from nexa.gguf.sd.stable_diffusion import StableDiffusion from faster_whisper import WhisperModel +import numpy as np import argparse +import soundfile as sf +import librosa +import io logging.basicConfig(level=logging.INFO) @@ -72,25 +82,28 @@ ) model = None +whisper_model = None chat_format = None completion_template = None hostname = socket.gethostname() chat_completion_system_prompt = [{"role": "system", "content": "You are a helpful assistant"}] function_call_system_prompt = [{"role": "system", "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"}] model_path = None +whisper_model_path = "faster-whisper-tiny" # by default, use tiny whisper model n_ctx = None is_local_path = False model_type = None is_huggingface = False is_modelscope = False projector_path = None +SAMPLING_RATE = 16000 # Request Classes class GenerationRequest(BaseModel): prompt: str = "Tell me a story" - temperature: float = 1.0 + temperature: float = 0.8 max_new_tokens: int = 128 - top_k: int = 50 - top_p: float = 1.0 + top_k: int = 40 + top_p: float = 0.95 stop_words: Optional[List[str]] = [] logprobs: Optional[int] = None stream: Optional[bool] = False @@ -101,7 +114,10 @@ class TextContent(BaseModel): class ImageUrlContent(BaseModel): type: Literal["image_url"] = "image_url" - image_url: Dict[str, Union[HttpUrl, str]] + image_url: Dict[str, Union[HttpUrl, str, None]] = Field( + default={"url": None, "path": None}, + description="Either url or path must be provided" + ) ContentItem = Union[str, TextContent, ImageUrlContent] @@ -117,7 +133,7 @@ class ChatCompletionRequest(BaseModel): messages: List[Message] = [ {"role": "user", "content": "Tell me a story"}] max_tokens: Optional[int] = 128 - temperature: Optional[float] = 0.1 + temperature: Optional[float] = 0.2 stream: Optional[bool] = False stop_words: Optional[List[str]] = [] logprobs: Optional[bool] = False @@ -125,6 +141,23 @@ class ChatCompletionRequest(BaseModel): top_k: Optional[int] = 40 top_p: Optional[float] = 0.95 +class VLMChatCompletionRequest(BaseModel): + messages: List[Message] = [ + {"role": "user", "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + }} + ] + } + ] + max_tokens: Optional[int] = 128 + temperature: Optional[float] = 0.2 + stream: Optional[bool] = False + stop_words: Optional[List[str]] = [] + top_k: Optional[int] = 40 + top_p: Optional[float] = 0.95 + class FunctionDefinitionRequestClass(BaseModel): type: str = "function" function: Dict[str, Any] @@ -175,6 +208,82 @@ class EmbeddingRequest(BaseModel): normalize: Optional[bool] = False truncate: Optional[bool] = True +class LoadModelRequest(BaseModel): + model_path: str = "llama3.2" + model_type: Optional[str] = None + is_local_path: Optional[bool] = False + is_huggingface: Optional[bool] = False + is_modelscope: Optional[bool] = False + local_projector_path: Optional[str] = None + + model_config = { + "protected_namespaces": () + } +class LoadWhisperModelRequest(BaseModel): + whisper_model_path: str = "faster-whisper-tiny" + +class DownloadModelRequest(BaseModel): + model_path: str = "llama3.2" + + model_config = { + "protected_namespaces": () + } + +class StreamASRProcessor: + def __init__(self, asr, task, language): + self.asr = asr + self.task = task + self.language = None if language == "auto" else language + self.audio_buffer = np.array([], dtype=np.float32) + self.commited = [] + self.buffer_time_offset = 0.0 + + def insert_audio_chunk(self, audio): + self.audio_buffer = np.append(self.audio_buffer, audio) + + def process_iter(self): + if len(self.audio_buffer) == 0: + return (None, None, "") + res = self.transcribe(self.audio_buffer) + tsw = self.ts_words(res) + if len(tsw) == 0: + return (None, None, "") + + self.commited = tsw + text = " ".join([w[2] for w in self.commited]) + beg = self.commited[0][0] + self.buffer_time_offset + end = self.commited[-1][1] + self.buffer_time_offset + return (beg, end, text) + + def finish(self): + if len(self.commited) == 0: + return (None, None, "") + text = " ".join([w[2] for w in self.commited]) + beg = self.commited[0][0] + self.buffer_time_offset + end = self.commited[-1][1] + self.buffer_time_offset + return (beg, end, text) + + def transcribe(self, audio, prompt=""): + segments, info = self.asr.transcribe( + audio, + language=self.language, + task=self.task, + beam_size=5, + word_timestamps=True, + condition_on_previous_text=True, + initial_prompt=prompt + ) + return list(segments) + + def ts_words(self, segments): + words = [] + for seg in segments: + if seg.no_speech_prob > 0.9: + continue + for w in seg.words: + words.append((w.start, w.end, w.word)) + return words + # helper functions async def load_model(): global model, chat_format, completion_template, model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path @@ -192,10 +301,16 @@ async def load_model(): raise ValueError("Multimodal and Audio models are not supported for Hugging Face") downloaded_path, _ = pull_model(model_path, hf=is_huggingface, ms=is_modelscope) else: - if model_path in NEXA_RUN_MODEL_MAP_VLM: # for Multimodal models - downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path]) - projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path]) - model_type = "Multimodal" + if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP or model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: + if model_path in NEXA_RUN_OMNI_VLM_MAP: + downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path]) + elif model_path in NEXA_RUN_MODEL_MAP_VLM: + downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path]) + elif model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: + downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[model_path]) else: downloaded_path, model_type = pull_model(model_path) @@ -282,47 +397,88 @@ async def load_model(): logging.info(f"model loaded as {model}") elif model_type == "Multimodal": with suppress_stdout_stderr(): - projector_handler = NEXA_PROJECTOR_HANDLER_MAP.get(model_path, Llava15ChatHandler) - projector = (projector_handler( - clip_model_path=projector_downloaded_path, verbose=False - ) if projector_downloaded_path else None) - - chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) - try: - model = Llama( - model_path=downloaded_path, - chat_handler=projector, - verbose=False, - chat_format=chat_format, - n_ctx=n_ctx, - n_gpu_layers=-1 if is_gpu_available() else 0, + if 'omni' in model_path.lower(): + try: + model = NexaOmniVlmInference( + model_path=model_path, + device="gpu" if is_gpu_available() else "cpu" + ) + except Exception as e: + logging.error( + f"Failed to load OmniVLM model: {e}. Falling back to CPU.", + exc_info=True, + ) + model = NexaOmniVlmInference( + model_path=model_path, + device="cpu" + ) + else: + projector_handler = NEXA_PROJECTOR_HANDLER_MAP.get(model_path, Llava15ChatHandler) + projector = (projector_handler( + clip_model_path=projector_downloaded_path, verbose=False + ) if projector_downloaded_path else None) + + chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) + try: + model = Llama( + model_path=downloaded_path, + chat_handler=projector, + verbose=False, + chat_format=chat_format, + n_ctx=n_ctx, + n_gpu_layers=-1 if is_gpu_available() else 0, + ) + except Exception as e: + logging.error( + f"Failed to load model: {e}. Falling back to CPU.", + exc_info=True, + ) + model = Llama( + model_path=downloaded_path, + chat_handler=projector, + verbose=False, + chat_format=chat_format, + n_ctx=n_ctx, + n_gpu_layers=0, # hardcode to use CPU + ) + logging.info(f"Model loaded as {model}") + elif model_type == "AudioLM": + with suppress_stdout_stderr(): + try: + model = NexaAudioLMInference( + model_path=model_path, + device="gpu" if is_gpu_available() else "cpu" ) except Exception as e: logging.error( f"Failed to load model: {e}. Falling back to CPU.", exc_info=True, ) - model = Llama( - model_path=downloaded_path, - chat_handler=projector, - verbose=False, - chat_format=chat_format, - n_ctx=n_ctx, - n_gpu_layers=0, # hardcode to use CPU + model = NexaAudioLMInference( + model_path=model_path, + device="cpu" ) + logging.info(f"model loaded as {model}") + else: + raise ValueError(f"Model {model_path} not found in Model Hub. If you are using local path, be sure to add --local_path and --model_type flags.") - logging.info(f"Model loaded as {model}") - elif model_type == "Audio": +async def load_whisper_model(custom_whisper_model_path=None): + global whisper_model, whisper_model_path + try: + if custom_whisper_model_path: + whisper_model_path = custom_whisper_model_path + downloaded_path, _ = pull_model(whisper_model_path) with suppress_stdout_stderr(): - model = WhisperModel( + whisper_model = WhisperModel( downloaded_path, device="cpu", # only support cpu for now because cuDNN needs to be installed on user's machine compute_type="default" ) - logging.info(f"model loaded as {model}") - else: - raise ValueError(f"Model {model_path} not found in Model Hub. If you are using local path, be sure to add --local_path and --model_type flags.") - + logging.info(f"whisper model loaded as {whisper_model}") + except Exception as e: + logging.error(f"Error loading Whisper model: {e}") + raise ValueError(f"Failed to load Whisper model: {str(e)}") + def nexa_run_text_generation( prompt, temperature, stop_words, max_new_tokens, top_k, top_p, logprobs=None, stream=False, is_chat_completion=True ) -> Dict[str, Any]: @@ -477,17 +633,20 @@ def is_url(s: Union[str, AnyUrl]) -> bool: except ValueError: return False -def process_image_input(image_input: Union[str, AnyUrl]) -> str: - """Process image input, returning a data URI for both URL and base64 inputs.""" - if isinstance(image_input, AnyUrl) or is_url(image_input): - return image_url_to_base64(str(image_input)) - elif is_base64(image_input): - if image_input.startswith('data:image'): - return image_input - else: - return f"data:image/png;base64,{image_input}" +def process_image_input(image_data: Dict[str, Union[HttpUrl, str, None]]) -> str: + """Process image input from either URL or file path, returning a data URI.""" + url = image_data.get("url") + path = image_data.get("path") + if url: + if isinstance(url, str) and (url.startswith('data:image') or is_base64(url)): + return url if url.startswith('data:image') else f"data:image/png;base64,{url}" + return image_url_to_base64(str(url)) + elif path: + if not os.path.exists(path): + raise ValueError(f"Image file not found: {path}") + return image_path_to_base64(path) else: - raise ValueError("Invalid image input. Must be a URL or base64 encoded image.") + raise ValueError("Either 'url' or 'path' must be provided in image_url") def image_url_to_base64(image_url: str) -> str: response = requests.get(image_url) @@ -496,6 +655,19 @@ def image_url_to_base64(image_url: str) -> str: img.save(buffered, format="PNG") return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" +def image_path_to_base64(file_path): + if file_path and os.path.exists(file_path): + with open(file_path, "rb") as img_file: + base64_data = base64.b64encode(img_file.read()).decode("utf-8") + return f"data:image/png;base64,{base64_data}" + return None + +def load_audio_from_bytes(audio_bytes: bytes): + buffer = io.BytesIO(audio_bytes) + a, sr = sf.read(buffer, dtype='float32') + if sr != SAMPLING_RATE: + a = librosa.resample(a, orig_sr=sr, target_sr=SAMPLING_RATE) + return a def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, modelscope=False, projector_local_path_arg=None, **kwargs): global model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path @@ -511,31 +683,23 @@ def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type model_path = os.path.abspath(model_path_arg) if is_local_path_arg else model_path_arg model_type = model_type_arg else: - model_path = model_path_arg or "gemma" + model_path = model_path_arg model_type = None - os.environ["MODEL_PATH"] = model_path - os.environ["IS_LOCAL_PATH"] = str(is_local_path_arg) - os.environ["MODEL_TYPE"] = model_type if model_type else "" - os.environ["HUGGINGFACE"] = str(huggingface) - os.environ["MODELSCOPE"] = str(modelscope) - os.environ["PROJECTOR_PATH"] = projector_path if projector_path else "" n_ctx = kwargs.get("nctx", 2048) host = kwargs.get("host", "localhost") port = kwargs.get("port", 8000) reload = kwargs.get("reload", False) + uvicorn.run(app, host=host, port=port, reload=reload) # Endpoints @app.on_event("startup") async def startup_event(): - global model_path, is_local_path, model_type, is_huggingface, is_modelscope, projector_path - model_path = os.getenv("MODEL_PATH", "gemma") - is_local_path = os.getenv("IS_LOCAL_PATH", "False").lower() == "true" - model_type = os.getenv("MODEL_TYPE", None) - is_huggingface = os.getenv("HUGGINGFACE", "False").lower() == "true" - is_modelscope = os.getenv("MODELSCOPE", "False").lower() == "true" - projector_path = os.getenv("PROJECTOR_PATH", None) - await load_model() + global model_path + if model_path: + await load_model() + else: + logging.info("No model path provided. Server started without loading a model.") @app.get("/", response_class=HTMLResponse, tags=["Root"]) @@ -557,11 +721,118 @@ def _resp_async_generator(streamer): yield f"data: {json.dumps(chunk)}\n\n" yield "data: [DONE]\n\n" +@app.post("/v1/download_model", tags=["Model"]) +async def download_model(request: DownloadModelRequest): + """Download a model from the model hub""" + try: + if request.model_path in NEXA_RUN_MODEL_MAP_VLM or request.model_path in NEXA_RUN_OMNI_VLM_MAP or request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: # models and projectors + if request.model_path in NEXA_RUN_MODEL_MAP_VLM: + downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path]) + elif request.model_path in NEXA_RUN_OMNI_VLM_MAP: + downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[request.model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path]) + elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: + downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path]) + projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path]) + return { + "status": "success", + "message": "Successfully downloaded model and projector", + "model_path": request.model_path, + "model_local_path": downloaded_path, + "projector_local_path": projector_downloaded_path, + "model_type": model_type + } + else: + downloaded_path, model_type = pull_model(request.model_path) + return { + "status": "success", + "message": "Successfully downloaded model", + "model_path": request.model_path, + "model_local_path": downloaded_path, + "model_type": model_type + } + + except Exception as e: + logging.error(f"Error downloading model: {e}") + raise HTTPException( + status_code=500, + detail=f"Failed to download model: {str(e)}" + ) + +@app.post("/v1/load_model", tags=["Model"]) +async def load_different_model(request: LoadModelRequest): + """Load a different model while maintaining the global model state""" + try: + global model_path, is_local_path, model_type, is_huggingface, is_modelscope, projector_path + + # Update global variables with new configuration + model_path = request.model_path + is_local_path = request.is_local_path + model_type = request.model_type + is_huggingface = request.is_huggingface + is_modelscope = request.is_modelscope + projector_path = request.local_projector_path + + # Load the new model + await load_model() + + return { + "status": "success", + "message": f"Successfully loaded model: {model_path}", + "model_type": model_type + } + + except Exception as e: + logging.error(f"Error loading model: {e}") + raise HTTPException( + status_code=500, + detail=f"Failed to load model: {str(e)}" + ) + +@app.post("/v1/load_whisper_model", tags=["Model"]) +async def load_different_whisper_model(request: LoadWhisperModelRequest): + """Load a different Whisper model while maintaining the global model state""" + try: + global whisper_model_path + whisper_model_path = request.whisper_model_path + await load_whisper_model(custom_whisper_model_path=whisper_model_path) + + return { + "status": "success", + "message": f"Successfully loaded Whisper model: {whisper_model_path}", + "model_type": "Audio", + } + except Exception as e: + logging.error(f"Error loading Whisper model: {e}") + raise HTTPException( + status_code=500, + detail=f"Failed to load Whisper model: {str(e)}" + ) + +@app.get("/v1/list_models", tags=["Model"]) +async def list_models(): + """List all models available in the model hub""" + try: + if NEXA_MODEL_LIST_PATH.exists(): + with open(NEXA_MODEL_LIST_PATH, "r") as f: + model_list = json.load(f) + else: + model_list = {} + return JSONResponse(content=model_list) + except Exception as e: + logging.error(f"Error listing models: {e}") + raise HTTPException(status_code=500, detail=str(e)) + @app.post("/v1/completions", tags=["NLP"]) async def generate_text(request: GenerationRequest): try: + if model_type != "NLP": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not an NLP model. Please use an NLP model for text generation." + ) generation_kwargs = request.dict() - if request.stream: # Run the generation and stream the response streamer = nexa_run_text_generation(is_chat_completion=False, **generation_kwargs) @@ -587,85 +858,226 @@ async def generate_text(request: GenerationRequest): @app.post("/v1/chat/completions", tags=["NLP"]) -async def chat_completions(request: ChatCompletionRequest): +async def text_chat_completions(request: ChatCompletionRequest): + """Endpoint for text-only chat completions using NLP models""" try: - is_vlm = any(isinstance(msg.content, list) for msg in request.messages) + if model_type != "NLP": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not an NLP model. Please use an NLP model for text chat completion." + ) + + generation_kwargs = GenerationRequest( + prompt="" if len(request.messages) == 0 else request.messages[-1].content, + temperature=request.temperature, + max_new_tokens=request.max_tokens, + stop_words=request.stop_words, + logprobs=request.logprobs, + top_logprobs=request.top_logprobs, + stream=request.stream, + top_k=request.top_k, + top_p=request.top_p + ).dict() + + if request.stream: + streamer = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs) + return StreamingResponse(_resp_async_generator(streamer), media_type="application/x-ndjson") - if is_vlm: - if model_type != "Multimodal": - raise HTTPException(status_code=400, detail="The model that is loaded is not a Multimodal model. Please use a Multimodal model (e.g. llava1.6-vicuna) for VLM.") - # Process VLM request - processed_messages = [] - for msg in request.messages: - if isinstance(msg.content, list): - processed_content = [] - for item in msg.content: - if isinstance(item, TextContent): - processed_content.append({"type": "text", "text": item.text}) - elif isinstance(item, ImageUrlContent): - try: - image_input = item.image_url["url"] - image_data_uri = process_image_input(image_input) - processed_content.append({"type": "image_url", "image_url": {"url": image_data_uri}}) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - processed_messages.append({"role": msg.role, "content": processed_content}) + result = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs) + return { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": time.time(), + "choices": [{ + "message": Message(role="assistant", content=result["result"]), + "logprobs": result["logprobs"] if "logprobs" in result else None, + }], + } + + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Error in text chat completions: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/v1/vlm/chat/completions", tags=["Multimodal"]) +async def multimodal_chat_completions(request: VLMChatCompletionRequest): + """Endpoint for multimodal chat completions using VLM models""" + try: + if model_type != "Multimodal" or 'omni' in model_path.lower(): + raise HTTPException( + status_code=400, + detail="The model that is loaded is not a Multimodal model. Please use a Multimodal model (e.g. nanollava) for VLM." + ) + + processed_messages = [] + for msg in request.messages: + if isinstance(msg.content, list): + processed_content = [] + for item in msg.content: + if isinstance(item, TextContent): + processed_content.append({"type": "text", "text": item.text}) + elif isinstance(item, ImageUrlContent): + try: + image_data_uri = process_image_input(item.image_url) + processed_content.append({ + "type": "image_url", + "image_url": {"url": image_data_uri} + }) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + processed_messages.append({"role": msg.role, "content": processed_content}) + else: + processed_messages.append({"role": msg.role, "content": msg.content}) + + response = model.create_chat_completion( + messages=processed_messages, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + stream=request.stream, + stop=request.stop_words, + ) + + if request.stream: + return StreamingResponse(_resp_async_generator(response), media_type="application/x-ndjson") + return response + + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Error in multimodal chat completions: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +async def _resp_omnivlm_async_generator(model, prompt: str, image_path: str): + _id = str(uuid.uuid4()) + try: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + + for token in model.inference_streaming(prompt, image_path): + chunk = { + "id": _id, + "object": "chat.completion.chunk", + "created": time.time(), + "choices": [{ + "delta": {"content": token}, + "index": 0, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + logging.error(f"Error in OmniVLM streaming: {e}") + raise + +@app.post("/v1/omnivlm/chat/completions", tags=["Multimodal"]) +async def omnivlm_chat_completions(request: VLMChatCompletionRequest): + """Endpoint for Multimodal chat completions using OmniVLM models""" + temp_file = None + image_path = None + + try: + if model_type != "Multimodal" or 'omni' not in model_path.lower(): + raise HTTPException( + status_code=400, + detail="Please use an OmniVLM model for this endpoint." + ) + + prompt = "" + last_message = request.messages[-1] + + if isinstance(last_message.content, list): + for item in last_message.content: + if isinstance(item, TextContent): + prompt = item.text + elif isinstance(item, ImageUrlContent): + try: + base64_image = process_image_input(item.image_url) + base64_data = base64_image.split(',')[1] + image_data = base64.b64decode(base64_data) + + temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) + temp_file.write(image_data) + temp_file.flush() + os.fsync(temp_file.fileno()) + temp_file.close() + + image_path = temp_file.name + + if not os.path.exists(image_path): + raise ValueError(f"Failed to create temporary file at {image_path}") + + except Exception as e: + if temp_file and os.path.exists(temp_file.name): + os.unlink(temp_file.name) + raise ValueError(f"Failed to process image: {str(e)}") else: - processed_messages.append({"role": msg.role, "content": msg.content}) - - response = model.create_chat_completion( - messages=processed_messages, - temperature=request.temperature, - max_tokens=request.max_tokens, - top_k=request.top_k, - top_p=request.top_p, - stream=request.stream, + raise ValueError("Either url or path must be provided for image") + else: + prompt = last_message.content + + if not image_path: + raise HTTPException( + status_code=400, + detail="Image is required for OmniVLM inference" + ) + + if request.stream: + async def stream_with_cleanup(): + try: + async for chunk in _resp_omnivlm_async_generator(model, prompt, image_path): + yield chunk + finally: + if image_path and os.path.exists(image_path): + try: + os.unlink(image_path) + except Exception as e: + logging.error(f"Error cleaning up file {image_path}: {e}") + + return StreamingResponse( + stream_with_cleanup(), + media_type="text/event-stream" ) else: - # Process regular chat completion request - generation_kwargs = GenerationRequest( - prompt="" if len(request.messages) == 0 else request.messages[-1].content, - temperature=request.temperature, - max_new_tokens=request.max_tokens, - stop_words=request.stop_words, - logprobs=request.logprobs, - top_logprobs=request.top_logprobs, - stream=request.stream, - top_k=request.top_k, - top_p=request.top_p - ).dict() - - if request.stream: - streamer = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs) - return StreamingResponse(_resp_async_generator(streamer), media_type="application/x-ndjson") - else: - result = nexa_run_text_generation(is_chat_completion=True, **generation_kwargs) + try: + response = model.inference(prompt, image_path) return { "id": str(uuid.uuid4()), "object": "chat.completion", "created": time.time(), "choices": [{ - "message": Message(role="assistant", content=result["result"]), - "logprobs": result["logprobs"] if "logprobs" in result else None, + "message": {"role": "assistant", "content": response}, + "index": 0, + "finish_reason": "stop" }], } - - if request.stream: - return StreamingResponse(_resp_async_generator(response), media_type="application/x-ndjson") - else: - return response - - except HTTPException as e: - raise e + finally: + if image_path and os.path.exists(image_path): + os.unlink(image_path) except Exception as e: - logging.error(f"Error in chat completions: {e}") + if image_path and os.path.exists(image_path): + try: + os.unlink(image_path) + except Exception as cleanup_error: + logging.error(f"Error cleaning up file {image_path}: {cleanup_error}") + + if isinstance(e, HTTPException): + raise e + logging.error(f"Error in OmniVLM chat completions: {e}") raise HTTPException(status_code=500, detail=str(e)) - @app.post("/v1/function-calling", tags=["NLP"]) async def function_call(request: FunctionCallRequest): try: + if model_type != "NLP": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not an NLP model. Please use an NLP model for function calling." + ) messages = function_call_system_prompt + [ {"role": msg.role, "content": msg.content} for msg in request.messages ] @@ -687,8 +1099,12 @@ async def function_call(request: FunctionCallRequest): @app.post("/v1/txt2img", tags=["Computer Vision"]) async def txt2img(request: ImageGenerationRequest): try: + if model_type != "Computer Vision": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not a Computer Vision model. Please use a Computer Vision model for image generation." + ) generation_kwargs = request.dict() - generated_images = await nexa_run_image_generation(**generation_kwargs) resp = {"created": time.time(), "data": []} @@ -711,6 +1127,11 @@ async def txt2img(request: ImageGenerationRequest): @app.post("/v1/img2img", tags=["Computer Vision"]) async def img2img(request: ImageGenerationRequest): try: + if model_type != "Computer Vision": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not a Computer Vision model. Please use a Computer Vision model for image generation." + ) generation_kwargs = request.dict() generated_images = await nexa_run_image_generation(**generation_kwargs) @@ -732,62 +1153,218 @@ async def img2img(request: ImageGenerationRequest): logging.error(f"Error in img2img generation: {e}") raise HTTPException(status_code=500, detail=str(e)) -@app.post("/v1/audio/transcriptions", tags=["Audio"]) -async def transcribe_audio( +@app.post("/v1/audio/processing", tags=["Audio"]) +async def process_audio( file: UploadFile = File(...), - beam_size: Optional[int] = Query(5, description="Beam size for transcription"), - language: Optional[str] = Query(None, description="Language code (e.g., 'en', 'fr')"), - temperature: Optional[float] = Query(0.0, description="Temperature for sampling"), + task: str = Query("transcribe", + description="Task to perform on the audio. Options are: 'transcribe' or 'translate'.", + regex="^(transcribe|translate)$" + ), + beam_size: Optional[int] = Query(5, description="Beam size for decoding."), + language: Optional[str] = Query(None, description="Language code (e.g. 'en', 'fr') for transcription."), + temperature: Optional[float] = Query(0.0, description="Temperature for sampling.") ): - try: + if not whisper_model: + raise HTTPException( + status_code=400, + detail="Whisper model is not loaded. Please load a Whisper model first." + ) + with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio: temp_audio.write(await file.read()) temp_audio_path = temp_audio.name - transcribe_params = { + # Set up parameters for Whisper or similar model + task_params = { "beam_size": beam_size, - "language": language, - "task": "transcribe", "temperature": temperature, - "vad_filter": True + "vad_filter": True, + "task": task } - segments, _ = model.transcribe(temp_audio_path, **transcribe_params) - transcription = "".join(segment.text for segment in segments) - return JSONResponse(content={"text": transcription}) + + # Only include language parameter if task is "transcribe" + # For "translate", the language is always defined as "en" + if task == "transcribe" and language: + task_params["language"] = language + + segments, _ = whisper_model.transcribe(temp_audio_path, **task_params) + result_text = "".join(segment.text for segment in segments) + return JSONResponse(content={"text": result_text}) + except Exception as e: - raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error during {task}: {str(e)}") finally: - os.unlink(temp_audio_path) + if 'temp_audio_path' in locals() and os.path.exists(temp_audio_path): + os.unlink(temp_audio_path) -@app.post("/v1/audio/translations", tags=["Audio"]) -async def translate_audio( +@app.post("/v1/audio/processing_stream", tags=["Audio"]) +async def processing_stream_audio( file: UploadFile = File(...), - beam_size: Optional[int] = Query(5, description="Beam size for translation"), - temperature: Optional[float] = Query(0.0, description="Temperature for sampling"), + task: str = Query("transcribe", + description="Task to perform on the audio. Options are: 'transcribe' or 'translate'.", + regex="^(transcribe|translate)$" + ), + language: Optional[str] = Query("auto", description="Language code (e.g., 'en', 'fr')"), + min_chunk: Optional[float] = Query(1.0, description="Minimum chunk duration for streaming"), ): try: - with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio: - temp_audio.write(await file.read()) - temp_audio_path = temp_audio.name + if not whisper_model: + raise HTTPException( + status_code=400, + detail="Whisper model is not loaded. Please load a Whisper model first." + ) + + # Read the entire file into memory + audio_bytes = await file.read() + a_full = load_audio_from_bytes(audio_bytes) + duration = len(a_full) / SAMPLING_RATE + + # Only include language parameter if task is "transcribe" + # For "translate", the language is always defined as "en" + if task == "transcribe" and language != "auto": + used_language = language + else: + used_language = None + + warmup_audio = a_full[:SAMPLING_RATE] # first second + whisper_model.transcribe(warmup_audio) + + streamer = StreamASRProcessor(whisper_model, task, used_language) + + start = time.time() + beg = 0.0 + + def stream_generator(): + nonlocal beg + while beg < duration: + now = time.time() - start + if now < beg + min_chunk: + time.sleep((beg + min_chunk) - now) + end = time.time() - start + if end > duration: + end = duration + + chunk_samples = int((end - beg)*SAMPLING_RATE) + chunk_audio = a_full[int(beg*SAMPLING_RATE):int(beg*SAMPLING_RATE)+chunk_samples] + beg = end + + streamer.insert_audio_chunk(chunk_audio) + o = streamer.process_iter() + if o[0] is not None: + data = { + "emission_time_ms": (time.time()-start)*1000, + "segment_start_ms": o[0]*1000, + "segment_end_ms": o[1]*1000, + "text": o[2] + } + yield f"data: {json.dumps(data)}\n\n".encode("utf-8") + + # Final flush + o = streamer.finish() + if o[0] is not None: + data = { + "emission_time_ms": (time.time()-start)*1000, + "segment_start_ms": o[0]*1000, + "segment_end_ms": o[1]*1000, + "text": o[2], + "final": True + } + yield f"data: {json.dumps(data)}\n\n".encode("utf-8") + + return StreamingResponse(stream_generator(), media_type="application/x-ndjson") - translate_params = { - "beam_size": beam_size, - "task": "translate", - "temperature": temperature, - "vad_filter": True - } - segments, _ = model.transcribe(temp_audio_path, **translate_params) - translation = "".join(segment.text for segment in segments) - return JSONResponse(content={"text": translation}) except Exception as e: - raise HTTPException(status_code=500, detail=f"Error during translation: {str(e)}") - finally: - os.unlink(temp_audio_path) + logging.error(f"Error in audio processing stream: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/v1/audiolm/chat/completions", tags=["AudioLM"]) +async def audio_chat_completions( + file: UploadFile = File(...), + prompt: Optional[str] = Query(None, description="Prompt for audio chat completions"), + stream: Optional[bool] = Query(False, description="Whether to stream the response"), +): + temp_file = None + + try: + if model_type != "AudioLM": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not an AudioLM model. Please use an AudioLM model for audio chat completions." + ) + + temp_file = tempfile.NamedTemporaryFile(suffix=os.path.splitext(file.filename)[1], delete=False) + temp_file.write(await file.read()) + temp_file.flush() + os.fsync(temp_file.fileno()) + audio_path = temp_file.name + + if stream: + async def stream_with_cleanup(): + try: + for token in model.inference_streaming(audio_path, prompt or ""): + chunk = { + "id": str(uuid.uuid4()), + "object": "chat.completion.chunk", + "created": time.time(), + "choices": [{ + "delta": {"content": token}, + "index": 0, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + finally: + temp_file.close() + if os.path.exists(audio_path): + os.unlink(audio_path) + + return StreamingResponse( + stream_with_cleanup(), + media_type="text/event-stream" + ) + else: + try: + print("audio_path: ", audio_path) + response = model.inference(audio_path, prompt or "") + return { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": time.time(), + "choices": [{ + "message": {"role": "assistant", "content": response}, + "index": 0, + "finish_reason": "stop" + }], + } + finally: + temp_file.close() + if os.path.exists(audio_path): + os.unlink(audio_path) + + except Exception as e: + if temp_file: + temp_file.close() + if os.path.exists(temp_file.name): + try: + os.unlink(temp_file.name) + except Exception as cleanup_error: + logging.error(f"Error cleaning up file {temp_file.name}: {cleanup_error}") + + if isinstance(e, HTTPException): + raise e + logging.error(f"Error in audio chat completions: {e}") + raise HTTPException(status_code=500, detail=str(e)) @app.post("/v1/embeddings", tags=["Embedding"]) async def create_embedding(request: EmbeddingRequest): try: + if model_type != "Text Embedding": + raise HTTPException( + status_code=400, + detail="The model that is loaded is not a Text Embedding model. Please use a Text Embedding model for embedding generation." + ) if isinstance(request.input, list): embeddings_results = [model.embed(text, normalize=request.normalize, truncate=request.truncate) for text in request.input] else: @@ -832,7 +1409,9 @@ async def create_embedding(request: EmbeddingRequest): parser = argparse.ArgumentParser( description="Run the Nexa AI Text Generation Service" ) - parser.add_argument("model_path", type=str, nargs='?', default="gemma", help="Folder Path on Amazon S3") + parser.add_argument( + "--model_path", type=str, help="Path or identifier for the model in Nexa Model Hub" + ) parser.add_argument( "--nctx", type=int, default=2048, help="Length of context window" ) @@ -870,7 +1449,7 @@ async def create_embedding(request: EmbeddingRequest): ) args = parser.parse_args() run_nexa_ai_service( - args.model_path, + model_path_arg=args.model_path, is_local_path_arg=args.local_path, model_type_arg=args.model_type, huggingface=args.huggingface, diff --git a/nexa/siglip/__init__.py b/nexa/siglip/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nexa/siglip/nexa_siglip_server.py b/nexa/siglip/nexa_siglip_server.py new file mode 100644 index 00000000..c2fba134 --- /dev/null +++ b/nexa/siglip/nexa_siglip_server.py @@ -0,0 +1,178 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from fastapi import Request +from fastapi.responses import HTMLResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +import os +import socket +import time +import argparse +from PIL import Image +import torch +from transformers import AutoProcessor, AutoModel + +app = FastAPI(title="Nexa AI SigLIP Image-Text Matching Service") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allows all origins + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers +) + +# Global variables +hostname = socket.gethostname() +siglip_model = None +siglip_processor = None +images_dict = {} + +class ImagePathRequest(BaseModel): + image_dir: str + +class SearchResponse(BaseModel): + image_path: str + similarity_score: float + latency: float + +def init_model(): + """Initialize SigLIP model and processor""" + global siglip_model, siglip_processor + siglip_model = AutoModel.from_pretrained("google/siglip-base-patch16-384") + siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-384") + +def load_images_from_directory(image_dir, valid_extensions=('.jpg', '.jpeg', '.png', '.webp')): + """Load images from directory""" + images_dict = {} + + if not os.path.exists(image_dir): + raise ValueError(f"Directory {image_dir} does not exist") + + for filename in os.listdir(image_dir): + if filename.lower().endswith(valid_extensions): + image_path = os.path.join(image_dir, filename) + try: + image = Image.open(image_path).convert("RGB") + images_dict[image_path] = image + except Exception as e: + print(f"Failed to load image {filename}: {str(e)}") + + if not images_dict: + raise ValueError(f"No valid image files found in {image_dir}") + + return images_dict + +@app.on_event("startup") +async def startup_event(): + """Initialize model and load images when service starts""" + init_model() + # Add image loading if image_dir is provided + if hasattr(app, "image_dir") and app.image_dir: + global images_dict + try: + images_dict = load_images_from_directory(app.image_dir) + print(f"Successfully loaded {len(images_dict)} images from {app.image_dir}") + except Exception as e: + print(f"Failed to load images: {str(e)}") + +@app.get("/", response_class=HTMLResponse, tags=["Root"]) +async def read_root(request: Request): + return HTMLResponse( + content=f"

Welcome to Nexa AI SigLIP Image-Text Matching Service

Hostname: {hostname}

" + ) + +@app.get("/v1/list_images") +async def list_images(): + """Return current image directory path and loaded images""" + current_dir = getattr(app, "image_dir", None) + return { + "image_dir": current_dir, + "images_count": len(images_dict), + "images": list(images_dict.keys()), + "status": "active" if current_dir and images_dict else "no_images_loaded" + } + +@app.post("/v1/load_images") +async def load_images(request: ImagePathRequest): + """Load images from specified directory, replacing any previously loaded images""" + global images_dict + try: + temp_images = load_images_from_directory(request.image_dir) + + if not temp_images: + raise ValueError("No valid images found in the specified directory") + + images_dict.clear() + images_dict.update(temp_images) + app.image_dir = request.image_dir + + return { + "message": f"Successfully loaded {len(images_dict)} images from {request.image_dir}", + "images": list(images_dict.keys()) + } + except Exception as e: + current_count = len(images_dict) + error_message = f"Failed to load images: {str(e)}. Keeping existing {current_count} images." + raise HTTPException(status_code=400, detail=error_message) + +@app.post("/v1/find_similar", response_model=SearchResponse) +async def find_similar(text: str): + """Find image most similar to input text""" + if not images_dict: + raise HTTPException(status_code=400, detail="No images available, please load images first") + + try: + start_time = time.time() + image_paths = list(images_dict.keys()) + images = list(images_dict.values()) + + inputs = siglip_processor(text=[text], images=images, padding="max_length", return_tensors="pt") + + with torch.no_grad(): + outputs = siglip_model(**inputs) + + logits_per_image = outputs.logits_per_image + probs = torch.sigmoid(logits_per_image) + max_prob_index = torch.argmax(probs).item() + max_prob = probs[max_prob_index][0].item() + + return SearchResponse( + image_path=image_paths[max_prob_index], + similarity_score=max_prob, + latency = round(time.time() - start_time, 3) + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") + + +def run_nexa_ai_siglip_service(**kwargs): + host = kwargs.get("host", "localhost") + port = kwargs.get("port", 8100) + reload = kwargs.get("reload", False) + if kwargs.get("image_dir"): + app.image_dir = kwargs.get("image_dir") + uvicorn.run(app, host=host, port=port, reload=reload) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run the Nexa AI SigLIP Service" + ) + parser.add_argument( + "--image_dir", type=str, help="Directory of images to load" + ) + parser.add_argument( + "--host", type=str, default="localhost", help="Host to bind the server to" + ) + parser.add_argument( + "--port", type=int, default=8100, help="Port to bind the server to" + ) + parser.add_argument( + "--reload", type=bool, default=False, help="Reload the server on code changes" + ) + args = parser.parse_args() + run_nexa_ai_siglip_service( + image_dir=args.image_dir, + host=args.host, + port=args.port, + reload=args.reload + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 24b6ee35..c973ee18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,12 @@ convert = [ "nexa-gguf", ] +siglip = [ + "torch", + "transformers", + "sentencepiece", +] + [project.urls] Homepage = "https://github.com/NexaAI/nexa-sdk" Issues = "https://github.com/NexaAI/nexa-sdk/issues" diff --git a/requirements.txt b/requirements.txt index 6e732a0e..ed3e097d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,4 +46,9 @@ pytablewriter sacrebleu langdetect rouge_score -immutabledict \ No newline at end of file +immutabledict + +# For SigLIP +torch +transformers +sentencepiece \ No newline at end of file