Skip to content

Commit

Permalink
update load_model function to support loading both model and whisper_…
Browse files Browse the repository at this point in the history
…model
  • Loading branch information
Davidqian123 committed Dec 19, 2024
1 parent 5531b0d commit 0f05d20
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 100 deletions.
2 changes: 2 additions & 0 deletions nexa/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ModelType(Enum):
"phi2": "Phi-2:q4_0",
"phi3": "Phi-3-mini-128k-instruct:q4_0",
"phi3.5": "Phi-3.5-mini-instruct:q4_0",
"phi4": "Phi:q4_0",
"llama2-uncensored": "Llama2-7b-chat-uncensored:q4_0",
"llama3-uncensored": "Llama3-8B-Lexi-Uncensored:q4_K_M",
"openelm": "OpenELM-3B:q4_K_M",
Expand Down Expand Up @@ -413,6 +414,7 @@ class ModelType(Enum):
"Phi-3-mini-128k-instruct": ModelType.NLP,
"Phi-3-mini-4k-instruct": ModelType.NLP,
"Phi-3.5-mini-instruct": ModelType.NLP,
"Phi-4": ModelType.NLP,
"CodeQwen1.5-7B-Instruct": ModelType.NLP,
"Qwen2-0.5B-Instruct": ModelType.NLP,
"Qwen2-1.5B-Instruct": ModelType.NLP,
Expand Down
223 changes: 123 additions & 100 deletions nexa/gguf/server/nexa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
NEXA_RUN_PROJECTOR_MAP,
NEXA_RUN_OMNI_VLM_MAP,
NEXA_RUN_OMNI_VLM_PROJECTOR_MAP,
NEXA_RUN_MODEL_MAP_AUDIO_LM,
NEXA_RUN_AUDIO_LM_PROJECTOR_MAP,
NEXA_RUN_COMPLETION_TEMPLATE_MAP,
NEXA_RUN_MODEL_PRECISION_MAP,
NEXA_RUN_MODEL_MAP_FUNCTION_CALLING,
Expand Down Expand Up @@ -80,12 +82,14 @@
)

model = None
whisper_model = None
chat_format = None
completion_template = None
hostname = socket.gethostname()
chat_completion_system_prompt = [{"role": "system", "content": "You are a helpful assistant"}]
function_call_system_prompt = [{"role": "system", "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"}]
model_path = None
whisper_model_path = "faster-whisper-tiny" # by default, use tiny whisper model
n_ctx = None
is_local_path = False
model_type = None
Expand Down Expand Up @@ -215,6 +219,8 @@ class LoadModelRequest(BaseModel):
model_config = {
"protected_namespaces": ()
}
class LoadWhisperModelRequest(BaseModel):
whisper_model_path: str = "faster-whisper-tiny"

class DownloadModelRequest(BaseModel):
model_path: str = "llama3.2"
Expand Down Expand Up @@ -295,16 +301,16 @@ async def load_model():
raise ValueError("Multimodal and Audio models are not supported for Hugging Face")
downloaded_path, _ = pull_model(model_path, hf=is_huggingface, ms=is_modelscope)
else:
if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP:
if model_path in NEXA_RUN_MODEL_MAP_VLM or model_path in NEXA_RUN_OMNI_VLM_MAP or model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
if model_path in NEXA_RUN_OMNI_VLM_MAP:
logging.info(f"Path is OmniVLM model: {model_path}")
downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path])
downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path])
else:
logging.info(f"Path is in NEXA_RUN_MODEL_MAP_VLM: {model_path}")
downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path])
elif model_path in NEXA_RUN_MODEL_MAP_VLM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[model_path])
model_type = "Multimodal"
elif model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[model_path])
else:
downloaded_path, model_type = pull_model(model_path)

Expand Down Expand Up @@ -436,14 +442,6 @@ async def load_model():
n_gpu_layers=0, # hardcode to use CPU
)
logging.info(f"Model loaded as {model}")
elif model_type == "Audio":
with suppress_stdout_stderr():
model = WhisperModel(
downloaded_path,
device="cpu", # only support cpu for now because cuDNN needs to be installed on user's machine
compute_type="default"
)
logging.info(f"model loaded as {model}")
elif model_type == "AudioLM":
with suppress_stdout_stderr():
try:
Expand All @@ -463,7 +461,24 @@ async def load_model():
logging.info(f"model loaded as {model}")
else:
raise ValueError(f"Model {model_path} not found in Model Hub. If you are using local path, be sure to add --local_path and --model_type flags.")


async def load_whisper_model(custom_whisper_model_path=None):
global whisper_model, whisper_model_path
try:
if custom_whisper_model_path:
whisper_model_path = custom_whisper_model_path
downloaded_path, _ = pull_model(whisper_model_path)
with suppress_stdout_stderr():
whisper_model = WhisperModel(
downloaded_path,
device="cpu", # only support cpu for now because cuDNN needs to be installed on user's machine
compute_type="default"
)
logging.info(f"whisper model loaded as {whisper_model}")
except Exception as e:
logging.error(f"Error loading Whisper model: {e}")
raise ValueError(f"Failed to load Whisper model: {str(e)}")

def nexa_run_text_generation(
prompt, temperature, stop_words, max_new_tokens, top_k, top_p, logprobs=None, stream=False, is_chat_completion=True
) -> Dict[str, Any]:
Expand Down Expand Up @@ -710,16 +725,23 @@ def _resp_async_generator(streamer):
async def download_model(request: DownloadModelRequest):
"""Download a model from the model hub"""
try:
if request.model_path in NEXA_RUN_MODEL_MAP_VLM: # for Multimodal models
downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path])
if request.model_path in NEXA_RUN_MODEL_MAP_VLM or request.model_path in NEXA_RUN_OMNI_VLM_MAP or request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM: # models and projectors
if request.model_path in NEXA_RUN_MODEL_MAP_VLM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_VLM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_PROJECTOR_MAP[request.model_path])
elif request.model_path in NEXA_RUN_OMNI_VLM_MAP:
downloaded_path, model_type = pull_model(NEXA_RUN_OMNI_VLM_MAP[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[request.model_path])
elif request.model_path in NEXA_RUN_MODEL_MAP_AUDIO_LM:
downloaded_path, model_type = pull_model(NEXA_RUN_MODEL_MAP_AUDIO_LM[request.model_path])
projector_downloaded_path, _ = pull_model(NEXA_RUN_AUDIO_LM_PROJECTOR_MAP[request.model_path])
return {
"status": "success",
"message": "Successfully downloaded multimodal model and projector",
"message": "Successfully downloaded model and projector",
"model_path": request.model_path,
"model_local_path": downloaded_path,
"projector_local_path": projector_downloaded_path,
"model_type": "Multimodal"
"model_type": model_type
}
else:
downloaded_path, model_type = pull_model(request.model_path)
Expand Down Expand Up @@ -768,6 +790,26 @@ async def load_different_model(request: LoadModelRequest):
detail=f"Failed to load model: {str(e)}"
)

@app.post("/v1/load_whisper_model", tags=["Model"])
async def load_different_whisper_model(request: LoadWhisperModelRequest):
"""Load a different Whisper model while maintaining the global model state"""
try:
global whisper_model_path
whisper_model_path = request.whisper_model_path
await load_whisper_model(custom_whisper_model_path=whisper_model_path)

return {
"status": "success",
"message": f"Successfully loaded Whisper model: {whisper_model_path}",
"model_type": "Audio",
}
except Exception as e:
logging.error(f"Error loading Whisper model: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to load Whisper model: {str(e)}"
)

@app.get("/v1/list_models", tags=["Model"])
async def list_models():
"""List all models available in the model hub"""
Expand Down Expand Up @@ -1123,10 +1165,10 @@ async def process_audio(
temperature: Optional[float] = Query(0.0, description="Temperature for sampling.")
):
try:
if model_type != "Audio":
if not whisper_model:
raise HTTPException(
status_code=400,
detail="The model that is loaded is not an Audio model. Please use an Audio model."
detail="Whisper model is not loaded. Please load a Whisper model first."
)

with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio:
Expand All @@ -1146,7 +1188,7 @@ async def process_audio(
if task == "transcribe" and language:
task_params["language"] = language

segments, _ = model.transcribe(temp_audio_path, **task_params)
segments, _ = whisper_model.transcribe(temp_audio_path, **task_params)
result_text = "".join(segment.text for segment in segments)
return JSONResponse(content={"text": result_text})

Expand All @@ -1166,94 +1208,75 @@ async def processing_stream_audio(
language: Optional[str] = Query("auto", description="Language code (e.g., 'en', 'fr')"),
min_chunk: Optional[float] = Query(1.0, description="Minimum chunk duration for streaming"),
):
# Read the entire file into memory
audio_bytes = await file.read()
a_full = load_audio_from_bytes(audio_bytes)
duration = len(a_full) / SAMPLING_RATE

# Only include language parameter if task is "transcribe"
# For "translate", the language is always defined as "en"
if task == "transcribe" and language != "auto":
used_language = language
else:
used_language = None

warmup_audio = a_full[:SAMPLING_RATE] # first second
model.transcribe(warmup_audio)

streamer = StreamASRProcessor(model, task, used_language)

start = time.time()
beg = 0.0

def stream_generator():
nonlocal beg
while beg < duration:
now = time.time() - start
if now < beg + min_chunk:
time.sleep((beg + min_chunk) - now)
end = time.time() - start
if end > duration:
end = duration
try:
if not whisper_model:
raise HTTPException(
status_code=400,
detail="Whisper model is not loaded. Please load a Whisper model first."
)

chunk_samples = int((end - beg)*SAMPLING_RATE)
chunk_audio = a_full[int(beg*SAMPLING_RATE):int(beg*SAMPLING_RATE)+chunk_samples]
beg = end
# Read the entire file into memory
audio_bytes = await file.read()
a_full = load_audio_from_bytes(audio_bytes)
duration = len(a_full) / SAMPLING_RATE

streamer.insert_audio_chunk(chunk_audio)
o = streamer.process_iter()
# Only include language parameter if task is "transcribe"
# For "translate", the language is always defined as "en"
if task == "transcribe" and language != "auto":
used_language = language
else:
used_language = None

warmup_audio = a_full[:SAMPLING_RATE] # first second
whisper_model.transcribe(warmup_audio)

streamer = StreamASRProcessor(whisper_model, task, used_language)

start = time.time()
beg = 0.0

def stream_generator():
nonlocal beg
while beg < duration:
now = time.time() - start
if now < beg + min_chunk:
time.sleep((beg + min_chunk) - now)
end = time.time() - start
if end > duration:
end = duration

chunk_samples = int((end - beg)*SAMPLING_RATE)
chunk_audio = a_full[int(beg*SAMPLING_RATE):int(beg*SAMPLING_RATE)+chunk_samples]
beg = end

streamer.insert_audio_chunk(chunk_audio)
o = streamer.process_iter()
if o[0] is not None:
data = {
"emission_time_ms": (time.time()-start)*1000,
"segment_start_ms": o[0]*1000,
"segment_end_ms": o[1]*1000,
"text": o[2]
}
yield f"data: {json.dumps(data)}\n\n".encode("utf-8")

# Final flush
o = streamer.finish()
if o[0] is not None:
data = {
"emission_time_ms": (time.time()-start)*1000,
"segment_start_ms": o[0]*1000,
"segment_end_ms": o[1]*1000,
"text": o[2]
"text": o[2],
"final": True
}
yield f"data: {json.dumps(data)}\n\n".encode("utf-8")

# Final flush
o = streamer.finish()
if o[0] is not None:
data = {
"emission_time_ms": (time.time()-start)*1000,
"segment_start_ms": o[0]*1000,
"segment_end_ms": o[1]*1000,
"text": o[2],
"final": True
}
yield f"data: {json.dumps(data)}\n\n".encode("utf-8")

return StreamingResponse(stream_generator(), media_type="application/x-ndjson")

@app.post("/v1/audio/translations", tags=["Audio"])
async def translate_audio(
file: UploadFile = File(...),
beam_size: Optional[int] = Query(5, description="Beam size for translation"),
temperature: Optional[float] = Query(0.0, description="Temperature for sampling"),
):
try:
if model_type != "Audio":
raise HTTPException(
status_code=400,
detail="The model that is loaded is not an Audio model. Please use an Audio model for audio translation."
)
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_audio:
temp_audio.write(await file.read())
temp_audio_path = temp_audio.name
return StreamingResponse(stream_generator(), media_type="application/x-ndjson")

translate_params = {
"beam_size": beam_size,
"task": "translate",
"temperature": temperature,
"vad_filter": True
}
segments, _ = model.transcribe(temp_audio_path, **translate_params)
translation = "".join(segment.text for segment in segments)
return JSONResponse(content={"text": translation})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during translation: {str(e)}")
finally:
os.unlink(temp_audio_path)
logging.error(f"Error in audio processing stream: {e}")
raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/audiolm/chat/completions", tags=["AudioLM"])
async def audio_chat_completions(
Expand Down

0 comments on commit 0f05d20

Please sign in to comment.