Skip to content

Commit

Permalink
Merge pull request #224 from NexaAI/david/newfeature
Browse files Browse the repository at this point in the history
add vlm_omni & audio_lm streamlit support
  • Loading branch information
zhiyuan8 authored Nov 8, 2024
2 parents 0485667 + 8016aac commit 4eee461
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 34 deletions.
2 changes: 1 addition & 1 deletion dependency/llama.cpp
6 changes: 3 additions & 3 deletions nexa/cli/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def run_ggml_inference(args):
if is_local_path:
local_path = os.path.abspath(model_path)
model_path = local_path
if run_type == "Multimodal":
if run_type == "Multimodal" or run_type == "AudioLM":
if not os.path.isdir(local_path):
print("Error: For Multimodal models with --local_path, the provided path must be a directory containing both model and projector ggufs.")
return
Expand All @@ -73,7 +73,7 @@ def run_ggml_inference(args):
else: # hf case
# TODO: remove this after adding support for Multimodal model in CLI
if run_type == "Multimodal" or run_type == "Audio" or run_type == "TTS":
print("Running multimodal model or audio model from Hugging Face is currently not supported in CLI mode. Please use SDK to run Multimodal model or Audio model or TTS model.")
print("Running multimodal model or audio model or TTS model from Hugging Face is currently not supported in CLI mode. Please use SDK to run Multimodal model or Audio model or TTS model.")
return
from nexa.general import pull_model
local_path, _ = pull_model(model_path, hf=True)
Expand Down Expand Up @@ -130,7 +130,7 @@ def run_ggml_inference(args):
return

if hasattr(args, 'streamlit') and args.streamlit:
if run_type == "Multimodal":
if run_type == "Multimodal" or run_type == "AudioLM":
inference.run_streamlit(model_path, is_local_path = is_local_path, hf = hf, projector_local_path = projector_local_path)
else:
inference.run_streamlit(model_path, is_local_path = is_local_path, hf = hf)
Expand Down
2 changes: 1 addition & 1 deletion nexa/gguf/llama/omni_vlm_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):


_lib.omnivlm_inference.argtypes = [omni_char_p, omni_char_p]
_lib.omnivlm_inference.restype = None
_lib.omnivlm_inference.restype = omni_char_p


def omnivlm_free():
Expand Down
151 changes: 131 additions & 20 deletions nexa/gguf/nexa_inference_audio_lm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import ctypes
import logging
import os
import sys
import librosa
import tempfile
import soundfile as sf
from pathlib import Path
from streamlit.web import cli as stcli
from nexa.utils import SpinningCursorAnimation, nexa_prompt
from nexa.constants import (
DEFAULT_TEXT_GEN_PARAMS,
Expand All @@ -13,7 +18,6 @@
from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr
from nexa.general import pull_model


def is_qwen(model_name):
if "qwen" in model_name.lower(): # TEMPORARY SOLUTION : this hardcode can be risky
return True
Expand Down Expand Up @@ -60,6 +64,8 @@ def __init__(
self.projector_downloaded_path = projector_local_path
self.device = device
self.context = None
self.temp_file = None

if self.device == "auto" or self.device == "gpu":
self.n_gpu_layers = -1 if is_gpu_available() else 0
else:
Expand Down Expand Up @@ -138,30 +144,84 @@ def _load_model(self):
raise

def run(self):
"""
Run the audio language model inference loop.
"""
try:
while True:
audio_path = self._get_valid_audio_path()
user_input = nexa_prompt("Enter text (leave empty if no prompt): ")

response = self.inference(audio_path, user_input)
print(response)

except KeyboardInterrupt:
print("\nExiting...")
except Exception as e:
logging.error(f"\nError during audio generation: {e}", exc_info=True)
finally:
self.cleanup()

def _get_valid_audio_path(self) -> str:
"""
Helper method to get a valid audio file path from user
"""
while True:
try:
while True:
audio_path = nexa_prompt("Enter the path to your audio file (required): ")
if os.path.exists(audio_path):
break
print(f"'{audio_path}' is not a valid audio path. Please try again.")
audio_path = nexa_prompt("Enter the path to your audio file (required): ")
if os.path.exists(audio_path):
# Check if it's a supported audio format
if any(audio_path.lower().endswith(ext) for ext in ['.wav', '.mp3', '.m4a', '.flac', '.ogg']):
return audio_path
print(f"Unsupported audio format. Please use WAV, MP3, M4A, FLAC, or OGG files.")
else:
print(f"'{audio_path}' is not a valid audio path. Please try again.")

user_input = nexa_prompt("Enter text (leave empty if no prompt): ")
def inference(self, audio_path: str, prompt: str = "") -> str:
"""
Perform a single inference with the audio language model.
"""
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")

self.ctx_params.file = ctypes.c_char_p(audio_path.encode("utf-8"))
self.ctx_params.prompt = ctypes.c_char_p(user_input.encode("utf-8"))
try:
# Ensure audio is at 16kHz before processing
audio_path = self._ensure_16khz(audio_path)

response = audio_lm_cpp.process_full(
self.context, ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
).decode("utf-8")
print(response)
self.ctx_params.file = ctypes.c_char_p(audio_path.encode("utf-8"))
self.ctx_params.prompt = ctypes.c_char_p(prompt.encode("utf-8"))

except KeyboardInterrupt:
print("\nExiting...")
break
response = audio_lm_cpp.process_full(
self.context, ctypes.byref(self.ctx_params), is_qwen=self.is_qwen
)
return response
except Exception as e:
raise RuntimeError(f"Error during inference: {str(e)}")
finally:
if self.temp_file:
try:
self.temp_file.close()
if os.path.exists(self.temp_file.name):
os.unlink(self.temp_file.name)
except:
pass
self.temp_file = None

except Exception as e:
logging.error(f"\nError during audio generation: {e}", exc_info=True)
def cleanup(self):
"""
Explicitly cleanup resources
"""
if self.context:
audio_lm_cpp.free(self.context, is_qwen=self.is_qwen)
self.context = None

if self.temp_file:
try:
self.temp_file.close()
if os.path.exists(self.temp_file.name):
os.unlink(self.temp_file.name)
except:
pass
self.temp_file = None

def __del__(self):
"""
Expand All @@ -170,6 +230,47 @@ def __del__(self):
if self.context:
audio_lm_cpp.free(self.context, is_qwen=self.is_qwen)

def _ensure_16khz(self, audio_path: str) -> str:
"""
Check if audio is 16kHz, resample if necessary.
Supports various audio formats (mp3, wav, m4a, etc.)
"""
try:
y, sr = librosa.load(audio_path, sr=None)

if sr == 16000:
return audio_path

# Resample to 16kHz
print(f"Resampling audio from {sr} to 16000")
y_resampled = librosa.resample(y=y, orig_sr=sr, target_sr=16000)
self.temp_file = tempfile.NamedTemporaryFile(
suffix='.wav',
delete=False
)
sf.write(
self.temp_file.name,
y_resampled,
16000,
subtype='PCM_16'
)
return self.temp_file.name

except Exception as e:
raise RuntimeError(f"Error processing audio file: {str(e)}")

def run_streamlit(self, model_path: str, is_local_path = False, hf = False, projector_local_path = None):
"""
Run the Streamlit UI.
"""
logging.info("Running Streamlit UI...")

streamlit_script_path = (
Path(os.path.abspath(__file__)).parent / "streamlit" / "streamlit_audio_lm.py"
)

sys.argv = ["streamlit", "run", str(streamlit_script_path), model_path, str(is_local_path), str(hf), str(projector_local_path)]
sys.exit(stcli.main())

if __name__ == "__main__":
import argparse
Expand All @@ -190,10 +291,20 @@ def __del__(self):
default="auto",
help="Device to use for inference (auto, cpu, or gpu)",
)
parser.add_argument(
"-st",
"--streamlit",
action="store_true",
help="Run the inference in Streamlit UI",
)

args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
device = kwargs.pop("device", "auto")

inference = NexaAudioLMInference(model_path, device=device, **kwargs)
inference.run()
if args.streamlit:
inference.run_streamlit(model_path)
else:
inference.run()
46 changes: 40 additions & 6 deletions nexa/gguf/nexa_inference_vlm_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import ctypes
import logging
import os
import sys
from pathlib import Path
from streamlit.web import cli as stcli
from nexa.utils import nexa_prompt, SpinningCursorAnimation
from nexa.constants import (
DEFAULT_TEXT_GEN_PARAMS,
Expand Down Expand Up @@ -104,22 +106,44 @@ def run(self):
image_path = nexa_prompt("Image Path (required): ")
if not os.path.exists(image_path):
print(f"Image path: {image_path} not found, running omni VLM without image input.")

user_input = nexa_prompt()
image_path = ctypes.c_char_p(image_path.encode("utf-8"))
user_input = ctypes.c_char_p(user_input.encode("utf-8"))
omni_vlm_cpp.omnivlm_inference(user_input, image_path)

response = self.inference(user_input, image_path)
print(f"\nResponse: {response}")
except KeyboardInterrupt:
print("\nExiting...")
break
except Exception as e:
logging.error(f"\nError during audio generation: {e}", exc_info=True)
print("\n")

def inference(self, prompt: str, image_path: str):
with suppress_stdout_stderr():
prompt = ctypes.c_char_p(prompt.encode("utf-8"))
image_path = ctypes.c_char_p(image_path.encode("utf-8"))
response = omni_vlm_cpp.omnivlm_inference(prompt, image_path)

decoded_response = response.decode('utf-8')
if '<|im_start|>assistant' in decoded_response:
decoded_response = decoded_response.replace('<|im_start|>assistant', '').strip()

return decoded_response

def __del__(self):
omni_vlm_cpp.omnivlm_free()

def run_streamlit(self, model_path: str, is_local_path = False, hf = False, projector_local_path = None):
"""
Run the Streamlit UI.
"""
logging.info("Running Streamlit UI...")

streamlit_script_path = (
Path(os.path.abspath(__file__)).parent / "streamlit" / "streamlit_vlm_omni.py"
)

sys.argv = ["streamlit", "run", str(streamlit_script_path), model_path, str(is_local_path), str(hf), str(projector_local_path)]
sys.exit(stcli.main())


if __name__ == "__main__":
import argparse
Expand All @@ -140,10 +164,20 @@ def __del__(self):
default="auto",
help="Device to use for inference (auto, cpu, or gpu)",
)
parser.add_argument(
"-st",
"--streamlit",
action="store_true",
help="Run the inference in Streamlit UI",
)

args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
device = kwargs.pop("device", "auto")

inference = NexaOmniVlmInference(model_path, device=device, **kwargs)
inference.run()
if args.streamlit:
inference.run_streamlit(model_path)
else:
inference.run()
Loading

0 comments on commit 4eee461

Please sign in to comment.