def process_streaming(ctx: omni_context_p, params: omni_context_params_p, is_qwen: bool = True):
    """Begin a streaming run via the native ``omni_process_streaming`` call.

    Args:
        ctx: Context pointer, forwarded unchanged to the native library.
        params: Context-parameters pointer, forwarded unchanged.
        is_qwen: Dispatch to ``_lib_qwen`` when true, ``_lib_omni`` otherwise.

    Returns:
        The value returned by the native call: an opaque streaming handle
        that ``sample`` and ``get_str`` accept as their first argument.
    """
    if is_qwen:
        backend = _lib_qwen
    else:
        backend = _lib_omni
    return backend.omni_process_streaming(ctx, params)


def sample(omni_streaming: ctypes.c_void_p, is_qwen: bool = True):
    """Advance the stream by one step through the native ``sample`` call.

    Args:
        omni_streaming: Handle previously obtained from ``process_streaming``.
        is_qwen: Dispatch to ``_lib_qwen`` when true, ``_lib_omni`` otherwise.

    Returns:
        The integer returned by the native call.
        NOTE(review): callers appear to treat a negative value as
        end-of-stream -- confirm against the native API.
    """
    if is_qwen:
        backend = _lib_qwen
    else:
        backend = _lib_omni
    return backend.sample(omni_streaming)


def get_str(omni_streaming: ctypes.c_void_p, is_qwen: bool = True):
    """Fetch the current text of the stream through the native ``get_str`` call.

    Args:
        omni_streaming: Handle previously obtained from ``process_streaming``.
        is_qwen: Dispatch to ``_lib_qwen`` when true, ``_lib_omni`` otherwise.

    Returns:
        The value returned by the native call (raw, undecoded).
    """
    if is_qwen:
        backend = _lib_qwen
    else:
        backend = _lib_omni
    return backend.get_str(omni_streaming)
def inference_streaming(self, audio_path: str, prompt: str = "") -> str:
    """
    Stream the audio language model's answer for one audio file.

    This is a generator: despite the ``-> str`` annotation (kept so the
    signature is unchanged), calling it returns an iterator that yields
    text chunks as the native library produces them.

    Args:
        audio_path: Path of the audio file to run inference on.
        prompt: Optional text prompt passed to the model alongside the audio.

    Yields:
        Non-empty decoded text chunks, excluding any chunk that contains the
        ``<|im_start|>`` control token.

    Raises:
        FileNotFoundError: If ``audio_path`` does not exist (raised on the
            first iteration, since the body is lazy).
        RuntimeError: If anything fails while loading the model or sampling;
            the original exception is attached as ``__cause__``.
    """
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    try:
        # Ensure audio is at 16kHz before processing
        audio_path = self._ensure_16khz(audio_path)

        self.ctx_params.file = ctypes.c_char_p(audio_path.encode("utf-8"))
        self.ctx_params.prompt = ctypes.c_char_p(prompt.encode("utf-8"))

        with suppress_stdout_stderr():
            self.context = audio_lm_cpp.init_context(
                ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
            )
            if not self.context:
                raise RuntimeError("Failed to load audio language model")
            logging.debug("Model loaded successfully")

            stream_handle = audio_lm_cpp.process_streaming(
                self.context, ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
            )

        # The sampling loop deliberately runs outside the suppression context:
        # control returns to the caller at every ``yield``.
        # NOTE(review): presumably suppress_stdout_stderr redirects the
        # process-wide streams, which would hide the caller's own output
        # while the generator is suspended -- confirm.
        status = 0
        while status >= 0:
            status = audio_lm_cpp.sample(stream_handle)
            piece = audio_lm_cpp.get_str(stream_handle).decode("utf-8")

            # The previous second test was ``'' in res_str``, which is true
            # for every string and therefore dropped every chunk. Skip only
            # empty chunks and chunks carrying the control token.
            # NOTE(review): if another marker (e.g. an end-of-sequence token)
            # was meant to be filtered as well, add it here.
            if not piece or "<|im_start|>" in piece:
                continue
            yield piece
    except Exception as e:
        raise RuntimeError(f"Error during inference: {str(e)}") from e
@app.post("/v1/audiolm/chat/completions", tags=["AudioLM"])
async def audio_chat_completions(
    file: UploadFile = File(...),
    prompt: Optional[str] = Query(None, description="Prompt for audio chat completions"),
    stream: Optional[bool] = Query(False, description="Whether to stream the response"),
):
    """Answer a prompt about an uploaded audio file with the loaded AudioLM model.

    The upload is spooled to a temporary file, passed by path to the model,
    and removed once the answer (or the whole stream) has been produced.

    Args:
        file: Uploaded audio file.
        prompt: Optional text prompt; ``None`` is treated as an empty prompt.
        stream: When true, reply with a ``text/event-stream`` of
            ``chat.completion.chunk`` objects terminated by ``data: [DONE]``;
            otherwise reply with a single ``chat.completion`` object.

    Raises:
        HTTPException: 400 when the loaded model is not an AudioLM model;
            500 for any other failure before the response is returned.
    """
    if model_type != "AudioLM":
        raise HTTPException(
            status_code=400,
            detail="The model that is loaded is not an AudioLM model. Please use an AudioLM model for audio chat completions."
        )

    def _remove_upload(path):
        """Delete the temporary upload; log instead of raising on failure."""
        try:
            if os.path.exists(path):
                os.unlink(path)
        except OSError as cleanup_error:
            logging.error(f"Error cleaning up file {path}: {cleanup_error}")

    audio_path = None
    try:
        # Uploads may arrive without a filename; fall back to no suffix.
        suffix = os.path.splitext(file.filename or "")[1]

        # Close the handle before inference: the model re-opens the file by
        # name, which is not possible on Windows while it is still open here.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            audio_path = temp_file.name
            temp_file.write(await file.read())
            temp_file.flush()
            os.fsync(temp_file.fileno())

        if stream:
            # All chunks of one completion share the same id and timestamp.
            completion_id = str(uuid.uuid4())
            created = time.time()

            # NOTE(review): the model generator is synchronous, so each step
            # blocks the event loop while it runs. Left unchanged because the
            # thread-safety of the native model is not visible from here.
            async def stream_with_cleanup():
                try:
                    for token in model.inference_streaming(audio_path, prompt or ""):
                        chunk = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": created,
                            "choices": [{
                                "delta": {"content": token},
                                "index": 0,
                                "finish_reason": None
                            }]
                        }
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                finally:
                    # Runs when the stream ends, fails, or is closed early.
                    _remove_upload(audio_path)

            return StreamingResponse(
                stream_with_cleanup(),
                media_type="text/event-stream"
            )

        try:
            logging.debug(f"audio_path: {audio_path}")
            response = model.inference(audio_path, prompt or "")
            return {
                "id": str(uuid.uuid4()),
                "object": "chat.completion",
                "created": time.time(),
                "choices": [{
                    "message": {"role": "assistant", "content": response},
                    "index": 0,
                    "finish_reason": "stop"
                }],
            }
        finally:
            _remove_upload(audio_path)

    except Exception as e:
        # Covers failures before the response object is handed back; the
        # helper is a no-op if the file has already been removed.
        if audio_path is not None:
            _remove_upload(audio_path)

        if isinstance(e, HTTPException):
            raise
        logging.error(f"Error in audio chat completions: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e