diff --git a/nexa/gguf/nexa_inference_audio_lm.py b/nexa/gguf/nexa_inference_audio_lm.py
index f264cd3a..471e813e 100644
--- a/nexa/gguf/nexa_inference_audio_lm.py
+++ b/nexa/gguf/nexa_inference_audio_lm.py
@@ -133,11 +133,11 @@ def _load_model(self):
                 0x7FFFFFFF if self.n_gpu_layers == -1 else self.n_gpu_layers
             )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
-            self.context = audio_lm_cpp.init_context(
-                ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
-            )
-            if not self.context:
-                raise RuntimeError("Failed to load audio language model")
-            logging.debug("Model loaded successfully")
+            # self.context = audio_lm_cpp.init_context(
+            #     ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
+            # )
+            # if not self.context:
+            #     raise RuntimeError("Failed to load audio language model")
+            # logging.debug("Model loaded successfully")
         except Exception as e:
             logging.error(f"Error loading model: {e}")
             raise
@@ -190,6 +190,13 @@ def inference(self, audio_path: str, prompt: str = "") -> str:
         self.ctx_params.file = ctypes.c_char_p(audio_path.encode("utf-8"))
         self.ctx_params.prompt = ctypes.c_char_p(prompt.encode("utf-8"))
 
+        self.context = audio_lm_cpp.init_context(
+            ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
+        )
+        if not self.context:
+            raise RuntimeError("Failed to load audio language model")
+        logging.debug("Model loaded successfully")
+
         response = audio_lm_cpp.process_full(
             self.context, ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
         )