From 0868715718c8b9d90e085e3ae8f3b9cb65df4098 Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Wed, 21 Aug 2024 19:36:53 +0000 Subject: [PATCH 1/7] update --- nexa/cli/entry.py | 254 +++++++++++++++++----------------------------- 1 file changed, 92 insertions(+), 162 deletions(-) diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index 66c3ca70..93f1dd0d 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -1,71 +1,56 @@ import argparse +from nexa.general import pull_model import uvicorn +def run_inference(args): + local_path, run_type = pull_model(args.model_path) + kwargs = {k: v for k, v in vars(args).items() if v is not None} + + if run_type == "gen-text": + from nexa.gguf.nexa_inference_text import NexaTextInference + inference = NexaTextInference(local_path, **kwargs) + elif run_type == "gen-image": + from nexa.gguf.nexa_inference_image import NexaImageInference + inference = NexaImageInference(local_path, **kwargs) + elif run_type == "vlm": + from nexa.gguf.nexa_inference_vlm import NexaVLMInference + inference = NexaVLMInference(local_path, **kwargs) + elif run_type == "asr": + from nexa.gguf.nexa_inference_voice import NexaVoiceInference + inference = NexaVoiceInference(local_path, **kwargs) + else: + raise ValueError(f"Unknown task: {run_type}") + + if args.streamlit: + inference.run_streamlit(local_path) + else: + inference.run() + def run_onnx_inference(args): + local_path, run_type = pull_model(args.model_path) kwargs = {k: v for k, v in vars(args).items() if v is not None} - model_path = kwargs.pop("model_path") - if args.onnx_command == "gen-text": + + if run_type == "gen-text": from nexa.onnx.nexa_inference_text import \ NexaTextInference as NexaTextOnnxInference - inference = NexaTextOnnxInference(model_path, **kwargs) - elif args.onnx_command == "gen-image": + inference = NexaTextOnnxInference(local_path, **kwargs) + elif run_type == "gen-image": from nexa.onnx.nexa_inference_image import \ NexaImageInference as NexaImageOnnxInference - inference = NexaImageOnnxInference(model_path, **kwargs) - elif args.onnx_command == "asr": + inference = NexaImageOnnxInference(local_path, **kwargs) + elif run_type == "asr": from nexa.onnx.nexa_inference_voice import \ NexaVoiceInference as NexaVoiceOnnxInference - inference = NexaVoiceOnnxInference(model_path, **kwargs) - elif args.onnx_command == "tts": + inference = NexaVoiceOnnxInference(local_path, **kwargs) + elif run_type == "tts": from nexa.onnx.nexa_inference_tts import \ NexaTTSInference as NexaTTSOnnxInference - inference = NexaTTSOnnxInference(model_path, **kwargs) - elif args.onnx_command == "server": - from nexa.onnx.server.nexa_service import app as NexaOnnxServer - uvicorn.run(NexaOnnxServer, host=args.host, port=args.port, reload=args.reload) + inference = NexaTTSOnnxInference(local_path, **kwargs) else: raise ValueError(f"Unknown ONNX command: {args.onnx_command}") if hasattr(args, 'streamlit') and args.streamlit: - inference.run_streamlit(model_path) - else: - inference.run() - -def run_ggml_inference(args): - kwargs = {k: v for k, v in vars(args).items() if v is not None} - model_path = kwargs.pop("model_path") - - if args.command == "server": - from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer - NexaServer(model_path, **kwargs) - return - - stop_words = kwargs.pop("stop_words", []) - - if args.command == "gen-text": - from nexa.gguf.nexa_inference_text import NexaTextInference - inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs) - elif args.command == 
"gen-image": - from nexa.gguf.nexa_inference_image import NexaImageInference - inference = NexaImageInference(model_path, **kwargs) - if hasattr(args, 'streamlit') and args.streamlit: - inference.run_streamlit(model_path) - elif args.img2img: - inference.loop_img2img() - else: - inference.loop_txt2img() - return - elif args.command == "vlm": - from nexa.gguf.nexa_inference_vlm import NexaVLMInference - inference = NexaVLMInference(model_path, stop_words=stop_words, **kwargs) - elif args.command == "asr": - from nexa.gguf.nexa_inference_voice import NexaVoiceInference - inference = NexaVoiceInference(model_path, **kwargs) - else: - raise ValueError(f"Unknown command: {args.command}") - - if hasattr(args, 'streamlit') and args.streamlit: - inference.run_streamlit(model_path) + inference.run_streamlit(local_path) else: inference.run() @@ -75,104 +60,53 @@ def main(): ) subparsers = parser.add_subparsers(dest="command", help="sub-command help") - # ONNX subparsers - onnx_parser = subparsers.add_parser("onnx", help="Run ONNX models for inference.") - onnx_subparsers = onnx_parser.add_subparsers(dest="onnx_command", help="ONNX sub-command help") - - # ONNX Text Generation - onnx_text_parser = onnx_subparsers.add_parser("gen-text", help="Run ONNX model for text generation.") - onnx_text_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - onnx_text_parser.add_argument("-t", "--temperature", type=float, default=0.8, help="Temperature for sampling") - onnx_text_parser.add_argument("-m", "--max_new_tokens", type=int, default=512, help="Maximum number of new tokens to generate") - onnx_text_parser.add_argument("-k", "--top_k", type=int, default=50, help="Top-k sampling parameter") - onnx_text_parser.add_argument("-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter") - onnx_text_parser.add_argument("-sw", "--stop_words", nargs="*", default=[], help="List of stop words for early stopping") - onnx_text_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") - onnx_text_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # ONNX Image Generation - onnx_image_parser = onnx_subparsers.add_parser("gen-image", help="Run ONNX model for image generation.") - onnx_image_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - onnx_image_parser.add_argument("-ns", "--num_inference_steps", type=int, default=20, help="Number of inference steps") - onnx_image_parser.add_argument("-np", "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt") - onnx_image_parser.add_argument("-H", "--height", type=int, default=512, help="Height of the output image") - onnx_image_parser.add_argument("-W", "--width", type=int, default=512, help="Width of the output image") - onnx_image_parser.add_argument("-g", "--guidance_scale", type=float, default=7.5, help="Guidance scale for diffusion") - onnx_image_parser.add_argument("-o", "--output", type=str, default="generated_images/image.png", help="Output path for the generated image") - onnx_image_parser.add_argument("-s", "--random_seed", type=int, default=41, help="Random seed for image generation") - onnx_image_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # ONNX ASR parser - onnx_asr_parser = onnx_subparsers.add_parser('asr', help='Run ONNX model 
for auto-speech-recognition.') - onnx_asr_parser.add_argument("model_path", type=str, help="Path or identifier for the model in S3") - onnx_asr_parser.add_argument("-o", "--output_dir", type=str, default="transcriptions", help="Output directory for transcriptions") - onnx_asr_parser.add_argument("-r", "--sampling_rate", type=int, default=16000, help="Sampling rate for audio processing") - onnx_asr_parser.add_argument("-st", "--streamlit", action='store_true', help="Run the inference in Streamlit UI") - - # ONNX voice-generation parser - onnx_tts_parser = onnx_subparsers.add_parser('tts', help='Run ONNX model for text-to-speech generation.') - onnx_tts_parser.add_argument("model_path", type=str, help="Path or identifier for the model in S3") - onnx_tts_parser.add_argument("-o", "--output_dir", type=str, default="tts", help="Output directory for tts") - onnx_tts_parser.add_argument("-r", "--sampling_rate", type=int, default=16000, help="Sampling rate for audio processing") - onnx_tts_parser.add_argument("-st", "--streamlit", action='store_true', help="Run the inference in Streamlit UI") - - # ONNX server parser - onnx_server_parser = onnx_subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") - onnx_server_parser.add_argument("model_path", type=str, help="Path or identifier for the model in S3") - onnx_server_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to") - onnx_server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") - onnx_server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") - - # GGML Text Generation - gen_text_parser = subparsers.add_parser("gen-text", help="Run a GGUF model locally for text generation.") - gen_text_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - gen_text_parser.add_argument("-t", "--temperature", type=float, default=0.8, help="Temperature for sampling") - gen_text_parser.add_argument("-m", "--max_new_tokens", type=int, default=512, help="Maximum number of new tokens to generate") - gen_text_parser.add_argument("-k", "--top_k", type=int, default=50, help="Top-k sampling parameter") - gen_text_parser.add_argument("-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter") - gen_text_parser.add_argument("-sw", "--stop_words", nargs="*", default=[], help="List of stop words for early stopping") - gen_text_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") - gen_text_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # GGML Image Generation - gen_image_parser = subparsers.add_parser("gen-image", help="Run a GGUF model locally for image generation.") - gen_image_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - gen_image_parser.add_argument("-i2i","--img2img",action="store_true",help="Whether to run image-to-image generation") - gen_image_parser.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") - gen_image_parser.add_argument("-np", "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt") - gen_image_parser.add_argument("-H", "--height", type=int, help="Height of the output image") - gen_image_parser.add_argument("-W", "--width", type=int, help="Width of the output image") - 
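
As a side note on the dispatch functions earlier in this patch: both run_inference and run_onnx_inference rely on argparse leaving unset options as None and then dropping those entries before forwarding the rest as keyword arguments. A minimal, self-contained sketch of that idiom (the option names mirror the parsers in this file; the model name and values are placeholders):

import argparse

# Unset optional arguments come back as None, so filtering them out lets each
# inference class fall back to its own defaults. Note that store_true flags
# default to False (not None), so they always survive the filter.
parser = argparse.ArgumentParser()
parser.add_argument("model_path", type=str)
parser.add_argument("-t", "--temperature", type=float)          # None unless given
parser.add_argument("-st", "--streamlit", action="store_true")  # False, not None
args = parser.parse_args(["placeholder-model", "-t", "0.7"])

kwargs = {k: v for k, v in vars(args).items() if v is not None}
kwargs.pop("model_path")  # the positional argument is passed separately
# kwargs == {'temperature': 0.7, 'streamlit': False}
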
gen_image_parser.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") - gen_image_parser.add_argument("-o", "--output", type=str, default="generated_images/image.png", help="Output path for the generated image") - gen_image_parser.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") - gen_image_parser.add_argument("--lora_dir", type=str, help="Path to directory containing LoRA files") - gen_image_parser.add_argument("--wtype", type=str, help="weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0). If not specified, the default is the type of the weight file.") - gen_image_parser.add_argument("--control_net_path", type=str, help="Path to control net model") - gen_image_parser.add_argument("--control_image_path", type=str, help="Path to image condition for Control Net") - gen_image_parser.add_argument("--control_strength", type=str, help="Strength to apply Control Net") - gen_image_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # GGML VLM Inference - vlm_parser = subparsers.add_parser("vlm", help="Run a GGUF model locally for VLM inference.") - vlm_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - vlm_parser.add_argument("-t", "--temperature", type=float, default=0.8, help="Temperature for sampling") - vlm_parser.add_argument("-m", "--max_new_tokens", type=int, default=2048, help="Maximum number of new tokens to generate") - vlm_parser.add_argument("-k", "--top_k", type=int, default=50, help="Top-k sampling parameter") - vlm_parser.add_argument("-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter") - vlm_parser.add_argument("-sw", "--stop_words", nargs="*", default=[], help="List of stop words for early stopping") - vlm_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") - vlm_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # Voice Inference - asr_parser = subparsers.add_parser("asr", help="Run a GGUF model locally for voice inference.") - asr_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - asr_parser.add_argument("-o", "--output_dir", type=str, default="transcriptions", help="Output directory for transcriptions") - asr_parser.add_argument("-b", "--beam_size", type=int, default=5, help="Beam size to use for transcription") - asr_parser.add_argument("-l", "--language", type=str, default=None, help="The language spoken in the audio. 
It should be a language code such as 'en' or 'fr'.") - asr_parser.add_argument("--task", type=str, default="transcribe", help="Task to execute (transcribe or translate)") - asr_parser.add_argument("-t", "--temperature", type=float, default=0.0, help="Temperature for sampling") - asr_parser.add_argument("-c", "--compute_type", type=str, default="default", help="Type to use for computation (e.g., float16, int8, int8_float16)") - asr_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + # Run command + run_parser = subparsers.add_parser("run", help="Run inference for various tasks.") + run_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + + # Common arguments + common_group = run_parser.add_argument_group('Common options') + common_group.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + common_group.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") + + # Text generation/vlm arguments + text_group = run_parser.add_argument_group('Text generation/VLM options') + text_group.add_argument("-t", "--temperature", type=float, help="Temperature for sampling") + text_group.add_argument("-m", "--max_new_tokens", type=int, help="Maximum number of new tokens to generate") + text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter") + text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter") + text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping") + + # Image generation arguments + image_group = run_parser.add_argument_group('Image generation options') + image_group.add_argument("-i2i", "--img2img", action="store_true", help="Whether to run image-to-image generation") + image_group.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") + image_group.add_argument("-np", "--num_images_per_prompt", type=int, help="Number of images to generate per prompt") + image_group.add_argument("-H", "--height", type=int, help="Height of the output image") + image_group.add_argument("-W", "--width", type=int, help="Width of the output image") + image_group.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") + image_group.add_argument("-o", "--output", type=str, help="Output path for the generated image") + image_group.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") + image_group.add_argument("--lora_dir", type=str, help="Path to directory containing LoRA files") + image_group.add_argument("--wtype", type=str, help="Weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)") + image_group.add_argument("--control_net_path", type=str, help="Path to control net model") + image_group.add_argument("--control_image_path", type=str, help="Path to image condition for Control Net") + image_group.add_argument("--control_strength", type=str, help="Strength to apply Control Net") + + # ASR arguments + asr_group = run_parser.add_argument_group('Automatic Speech Recognition options') + asr_group.add_argument("-b", "--beam_size", type=int, help="Beam size to use for transcription") + asr_group.add_argument("-l", "--language", type=str, help="Language code for audio (e.g., 'en' or 'fr')") + asr_group.add_argument("--task", type=str, help="Task to execute (transcribe or translate)") + asr_group.add_argument("-c", 
"--compute_type", type=str, help="Type to use for computation") + + # Other commands (unchanged) + subparsers.add_parser("pull", help="Pull a model from official or hub.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + subparsers.add_parser("remove", help="Remove a model from local machine.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + subparsers.add_parser("list", help="List all models in the local machine.") + subparsers.add_parser("login", help="Login to Nexa API.") + subparsers.add_parser("whoami", help="Show current user information.") + subparsers.add_parser("logout", help="Logout from Nexa API.") # GGML server parser server_parser = subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") @@ -181,20 +115,17 @@ def main(): server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") - # GGML general - subparsers.add_parser("pull", help="Pull a model from official or hub.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - subparsers.add_parser("remove", help="Remove a model from local machine.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - subparsers.add_parser("list", help="List all models in the local machine.") - subparsers.add_parser("login", help="Login to Nexa API.") - subparsers.add_parser("whoami", help="Show current user information.") - subparsers.add_parser("logout", help="Logout from Nexa API.") - args = parser.parse_args() - if args.command == "onnx": + if args.command == "run": + run_inference(args) + elif args.command == "onnx": run_onnx_inference(args) - elif args.command in ["gen-text", "gen-image", "vlm", "asr", "server"]: - run_ggml_inference(args) + elif args.command == "serve": + from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer + kwargs = {k: v for k, v in vars(args).items() if v is not None} + NexaServer(args.model_path, **kwargs) + return elif args.command == "pull": from nexa.general import pull_model pull_model(args.model_path) @@ -216,6 +147,5 @@ def main(): else: parser.print_help() - if __name__ == "__main__": - main() + main() \ No newline at end of file From 220197d67405a94a31460c507277350649e93afa Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 22 Aug 2024 04:44:53 +0000 Subject: [PATCH 2/7] change to nexa run --- nexa/cli/entry.py | 148 +++++++++++++------- nexa/constants.py | 79 ++++++++++- nexa/general.py | 137 +++++++++++++----- nexa/gguf/nexa_inference_image.py | 38 ++--- nexa/gguf/nexa_inference_text.py | 36 ++--- nexa/gguf/nexa_inference_vlm.py | 29 ++-- nexa/gguf/nexa_inference_voice.py | 27 ++-- nexa/gguf/server/nexa_service.py | 8 +- nexa/gguf/streamlit/streamlit_image_chat.py | 5 +- nexa/gguf/streamlit/streamlit_text_chat.py | 5 +- nexa/gguf/streamlit/streamlit_vlm.py | 5 +- nexa/gguf/streamlit/streamlit_voice_chat.py | 3 +- nexa/onnx/nexa_inference_image.py | 14 +- nexa/onnx/nexa_inference_text.py | 12 +- nexa/onnx/nexa_inference_tts.py | 10 +- nexa/onnx/nexa_inference_voice.py | 15 +- nexa/onnx/server/nexa_service.py | 2 +- nexa/onnx/streamlit/streamlit_image_chat.py | 4 +- nexa/onnx/streamlit/streamlit_text_chat.py | 4 +- nexa/onnx/streamlit/streamlit_tts.py | 5 +- nexa/onnx/streamlit/streamlit_voice_chat.py | 4 +- pyproject.toml | 2 - 
requirements.txt | 5 +- tomls/pyproject_cuda.toml | 2 - tomls/pyproject_metal.toml | 2 - 25 files changed, 357 insertions(+), 244 deletions(-) diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index 93f1dd0d..d61def4c 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -1,56 +1,70 @@ import argparse -from nexa.general import pull_model -import uvicorn -def run_inference(args): - local_path, run_type = pull_model(args.model_path) +def run_ggml_inference(args): kwargs = {k: v for k, v in vars(args).items() if v is not None} + model_path = kwargs.pop("model_path") - if run_type == "gen-text": + if args.command == "server": + from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer + NexaServer(model_path, **kwargs) + return + + from nexa.general import pull_model + local_path, run_type = pull_model(model_path) + + stop_words = kwargs.pop("stop_words", []) + + if run_type == "NLP": from nexa.gguf.nexa_inference_text import NexaTextInference - inference = NexaTextInference(local_path, **kwargs) - elif run_type == "gen-image": + inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs) + elif run_type == "Computer Vision": from nexa.gguf.nexa_inference_image import NexaImageInference - inference = NexaImageInference(local_path, **kwargs) - elif run_type == "vlm": + inference = NexaImageInference(model_path, local_path, **kwargs) + if hasattr(args, 'streamlit') and args.streamlit: + inference.run_streamlit(model_path) + elif args.img2img: + inference.loop_img2img() + else: + inference.loop_txt2img() + return + elif run_type == "Multimodal": from nexa.gguf.nexa_inference_vlm import NexaVLMInference - inference = NexaVLMInference(local_path, **kwargs) - elif run_type == "asr": + inference = NexaVLMInference(model_path, local_path, stop_words=stop_words, **kwargs) + elif run_type == "Audio": from nexa.gguf.nexa_inference_voice import NexaVoiceInference - inference = NexaVoiceInference(local_path, **kwargs) + inference = NexaVoiceInference(model_path, local_path, **kwargs) else: raise ValueError(f"Unknown task: {run_type}") - if args.streamlit: - inference.run_streamlit(local_path) + if hasattr(args, 'streamlit') and args.streamlit: + inference.run_streamlit(model_path) else: inference.run() def run_onnx_inference(args): - local_path, run_type = pull_model(args.model_path) kwargs = {k: v for k, v in vars(args).items() if v is not None} - - if run_type == "gen-text": - from nexa.onnx.nexa_inference_text import \ - NexaTextInference as NexaTextOnnxInference - inference = NexaTextOnnxInference(local_path, **kwargs) - elif run_type == "gen-image": - from nexa.onnx.nexa_inference_image import \ - NexaImageInference as NexaImageOnnxInference - inference = NexaImageOnnxInference(local_path, **kwargs) - elif run_type == "asr": - from nexa.onnx.nexa_inference_voice import \ - NexaVoiceInference as NexaVoiceOnnxInference - inference = NexaVoiceOnnxInference(local_path, **kwargs) - elif run_type == "tts": - from nexa.onnx.nexa_inference_tts import \ - NexaTTSInference as NexaTTSOnnxInference - inference = NexaTTSOnnxInference(local_path, **kwargs) + model_path = kwargs.pop("model_path") + + from nexa.general import pull_model + local_path, run_type = pull_model(model_path) + + if run_type == "NLP": + from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference + inference = NexaTextOnnxInference(model_path, local_path, **kwargs) + elif run_type == "Computer Vision": + from nexa.onnx.nexa_inference_image import 
NexaImageInference as NexaImageOnnxInference + inference = NexaImageOnnxInference(model_path, local_path, **kwargs) + elif run_type == "Audio": + from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference + inference = NexaVoiceOnnxInference(model_path, local_path, **kwargs) + elif run_type == "TTS": + from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference + inference = NexaTTSOnnxInference(model_path, local_path, **kwargs) else: - raise ValueError(f"Unknown ONNX command: {args.onnx_command}") + raise ValueError(f"Unknown task: {run_type}") if hasattr(args, 'streamlit') and args.streamlit: - inference.run_streamlit(local_path) + inference.run_streamlit(model_path) else: inference.run() @@ -61,13 +75,10 @@ def main(): subparsers = parser.add_subparsers(dest="command", help="sub-command help") # Run command - run_parser = subparsers.add_parser("run", help="Run inference for various tasks.") + run_parser = subparsers.add_parser("run", help="Run inference for various tasks using GGUF models.") run_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - - # Common arguments - common_group = run_parser.add_argument_group('Common options') - common_group.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - common_group.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") + run_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + run_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") # Text generation/vlm arguments text_group = run_parser.add_argument_group('Text generation/VLM options') @@ -100,13 +111,34 @@ def main(): asr_group.add_argument("--task", type=str, help="Task to execute (transcribe or translate)") asr_group.add_argument("-c", "--compute_type", type=str, help="Type to use for computation") - # Other commands (unchanged) - subparsers.add_parser("pull", help="Pull a model from official or hub.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - subparsers.add_parser("remove", help="Remove a model from local machine.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - subparsers.add_parser("list", help="List all models in the local machine.") - subparsers.add_parser("login", help="Login to Nexa API.") - subparsers.add_parser("whoami", help="Show current user information.") - subparsers.add_parser("logout", help="Logout from Nexa API.") + # ONNX command + onnx_parser = subparsers.add_parser("onnx", help="Run inference for various tasks using ONNX models.") + onnx_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + onnx_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + + # ONNX Text generation arguments + onnx_text_group = onnx_parser.add_argument_group('Text generation options') + onnx_text_group.add_argument("-t", "--temperature", type=float, help="Temperature for sampling") + onnx_text_group.add_argument("-m", "--max_new_tokens", type=int, help="Maximum number of new tokens to generate") + onnx_text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter") + onnx_text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling 
parameter") + onnx_text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping") + onnx_text_group.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") + + # ONNX Image generation arguments + onnx_image_group = onnx_parser.add_argument_group('Image generation options') + onnx_image_group.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") + onnx_image_group.add_argument("-np", "--num_images_per_prompt", type=int, help="Number of images to generate per prompt") + onnx_image_group.add_argument("-H", "--height", type=int, help="Height of the output image") + onnx_image_group.add_argument("-W", "--width", type=int, help="Width of the output image") + onnx_image_group.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") + onnx_image_group.add_argument("-O", "--output", type=str, help="Output path for the generated image") + onnx_image_group.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") + + # ONNX Voice arguments + onnx_voice_group = onnx_parser.add_argument_group('Voice generation options') + onnx_voice_group.add_argument("-o", "--output_dir", type=str, default="voice_output", help="Output directory for audio processing") + onnx_voice_group.add_argument("-r", "--sampling_rate", type=int, default=16000, help="Sampling rate for audio processing") # GGML server parser server_parser = subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") @@ -115,23 +147,31 @@ def main(): server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") + # Other commands + subparsers.add_parser("pull", help="Pull a model from official or hub.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + subparsers.add_parser("remove", help="Remove a model from local machine.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + subparsers.add_parser("clean", help="Clean up all model files.") + subparsers.add_parser("list", help="List all models in the local machine.") + subparsers.add_parser("login", help="Login to Nexa API.") + subparsers.add_parser("whoami", help="Show current user information.") + subparsers.add_parser("logout", help="Logout from Nexa API.") + + args = parser.parse_args() - if args.command == "run": - run_inference(args) + if args.command in ["run", "server"]: + run_ggml_inference(args) elif args.command == "onnx": run_onnx_inference(args) - elif args.command == "serve": - from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer - kwargs = {k: v for k, v in vars(args).items() if v is not None} - NexaServer(args.model_path, **kwargs) - return elif args.command == "pull": from nexa.general import pull_model pull_model(args.model_path) elif args.command == "remove": from nexa.general import remove_model remove_model(args.model_path) + elif args.command == "clean": + from nexa.general import clean + clean() elif args.command == "list": from nexa.general import list_models list_models() diff --git a/nexa/constants.py b/nexa/constants.py index 693965e1..3835f134 100644 --- a/nexa/constants.py +++ b/nexa/constants.py @@ -129,6 +129,8 @@ **NEXA_RUN_MODEL_MAP_TEXT, **NEXA_RUN_MODEL_MAP_VLM, **NEXA_RUN_MODEL_MAP_IMAGE, + 
**NEXA_RUN_MODEL_MAP_VOICE, + **NEXA_RUN_MODEL_MAP_FUNCTION_CALLING, } NEXA_RUN_CHAT_TEMPLATE_MAP = { @@ -205,4 +207,79 @@ "task": "transcribe", "temperature": 0.0, "compute_type": "default" -} \ No newline at end of file +} + +NEXA_OFFICIAL_MODELS_TYPE = { + 'gemma-2b': 'NLP', + 'Llama-2-7b-chat': 'NLP', + 'Llama-2-7b': 'NLP', + 'Meta-Llama-3-8B-Instruct': 'NLP', + 'Meta-Llama-3.1-8B-Instruct': 'NLP', + 'Mistral-7B-Instruct-v0.3': 'NLP', + 'Mistral-7B-Instruct-v0.2': 'NLP', + 'Phi-3-mini-128k-instruct': 'NLP', + 'Phi-3-mini-4k-instruct': 'NLP', + 'CodeQwen1.5-7B-Instruct': 'NLP', + 'Qwen2-0.5B-Instruct': 'NLP', + 'Qwen2-1.5B-Instruct': 'NLP', + 'Qwen2-7B-Instruct': 'NLP', + 'codegemma-2b': 'NLP', + 'gemma-1.1-2b-instruct': 'NLP', + 'gemma-2b-instruct': 'NLP', + 'gemma-2-9b-instruct': 'NLP', + 'gemma-1.1-7b-instruct': 'NLP', + 'gemma-7b-instruct': 'NLP', + 'gemma-7b': 'NLP', + 'Qwen2-1.5B': 'NLP', + 'codegemma-7b': 'NLP', + 'TinyLlama-1.1B-Chat-v1.0': 'NLP', + 'CodeLlama-7b-Instruct': 'NLP', + 'gpt2': 'NLP', + 'CodeLlama-7b': 'NLP', + 'CodeLlama-7b-Python': 'NLP', + 'Qwen1.5-7B-Instruct': 'NLP', + 'Qwen1.5-7B': 'NLP', + 'Phi-2': 'NLP', + 'deepseek-coder-1.3b-instruct': 'NLP', + 'deepseek-coder-1.3b-base': 'NLP', + 'deepseek-coder-6.7b-instruct': 'NLP', + 'dolphin-2.8-mistral-7b': 'NLP', + 'gemma-2-2b-instruct': 'NLP', + 'Octopus-v2': 'NLP', + 'Octopus-v4': 'NLP', + 'Octo-planner': 'NLP', + 'deepseek-coder-6.7b-base': 'NLP', + 'Llama2-7b-chat-uncensored': 'NLP', + 'Llama3-8B-Lexi-Uncensored': 'NLP', + 'Llama2-7b-function-calling': 'NLP', + 'lcm-dreamshaper-v7': 'Computer Vision', + 'stable-diffusion-v1-5': 'Computer Vision', + 'stable-diffusion-v1-4': 'Computer Vision', + 'stable-diffusion-v2-1': 'Computer Vision', + 'sdxl-turbo': 'Computer Vision', + 'hassaku-hentai-model-v13-LCM': 'Computer Vision', + 'anything-v30-LCM': 'Computer Vision', + 'Phi-3-vision-128k-instruct': 'Multimodal', + 'nanoLLaVA': 'Multimodal', + 'llava-v1.6-mistral-7b': 'Multimodal', + 'llava-v1.6-vicuna-7b': 'Multimodal', + 'llava-phi-3-mini': 'Multimodal', + 'llava-llama-3-8b-v1.1': 'Multimodal', + 'faster-whisper-tiny.en': 'Audio', + 'faster-whisper-tiny': 'Audio', + 'faster-whisper-small.en': 'Audio', + 'faster-whisper-small': 'Audio', + 'faster-whisper-medium.en': 'Audio', + 'faster-whisper-medium': 'Audio', + 'faster-whisper-base.en': 'Audio', + 'faster-whisper-base': 'Audio', + 'faster-whisper-large-v3': 'Audio', + 'whisper-tiny.en': 'Audio', + 'whisper-tiny': 'Audio', + 'whisper-small.en': 'Audio', + 'whisper-small': 'Audio', + 'whisper-base.en': 'Audio', + 'whisper-base': 'Audio', +} + + diff --git a/nexa/general.py b/nexa/general.py index 0143fa2f..b36f0688 100644 --- a/nexa/general.py +++ b/nexa/general.py @@ -1,7 +1,7 @@ import json import logging from pathlib import Path - +import shutil import requests from nexa.constants import ( @@ -13,6 +13,7 @@ NEXA_OFFICIAL_BUCKET, NEXA_RUN_MODEL_MAP, NEXA_TOKEN_PATH, + NEXA_OFFICIAL_MODELS_TYPE, ) @@ -103,25 +104,25 @@ def pull_model(model_path): try: if is_model_exists(model_path): - location = get_model_location(model_path) + location, run_type = get_model_info(model_path) logging.debug(f"Model {model_path} already exists at {location}") - return location + return location, run_type if "/" in model_path: - success, location, model_type = pull_model_from_hub(model_path) + result = pull_model_from_hub(model_path) else: - success, location, model_type = pull_model_from_official(model_path) + result = pull_model_from_official(model_path) - if success: - 
add_model_to_list(model_path, location, model_type) - logging.debug(f"Successfully pulled model {model_path} to {location}") - return location + if result["success"]: + add_model_to_list(model_path, result["local_path"], result["model_type"], result["run_type"]) + logging.debug(f"Successfully pulled model {model_path} to {result['local_path']}, run_type: {result['run_type']}") + return result["local_path"], result["run_type"] else: logging.debug(f"Failed to pull model {model_path}") - return None + return None, "UNKNOWN" except Exception as e: logging.error(f"An error occurred while pulling the model: {e}") - return None + return None, "UNKNOWN" def pull_model_from_hub(model_path): @@ -133,13 +134,20 @@ def pull_model_from_hub(model_path): token = file.read().strip() try: - presigned_links = get_model_presigned_link(model_path, token) + result = get_model_presigned_link(model_path, token) + run_type = result['type'] + presigned_links = result['presigned_urls'] except Exception as e: print(f"Failed to get download models: {e}") - return False, None + return { + "success": False, + "local_path": None, + "model_type": None, + "run_type": None + } success = True - model_location = None + local_path = None model_type = "undefined" # Determine model_type @@ -159,19 +167,24 @@ def pull_model_from_hub(model_path): download_path = NEXA_MODELS_HUB_DIR / file_path download_file_with_progress(presigned_link, download_path) - if model_location is None: + if local_path is None: if model_type == "onnx" or model_type == "bin": - model_location = str(download_path.parent) + local_path = str(download_path.parent) elif model_type == "gguf": - model_location = str(download_path) + local_path = str(download_path) else: # undefined - model_location = str(download_path.parent) + local_path = str(download_path.parent) except Exception as e: print(f"Failed to download {file_path}: {e}") success = False - return success, model_location, model_type + return { + "success": success, + "local_path": local_path, + "model_type": model_type, + "run_type": run_type + } def pull_model_from_official(model_path): @@ -183,8 +196,21 @@ def pull_model_from_official(model_path): model_type = "bin" else: model_type = "gguf" + + run_type = get_run_type_from_model_path(model_path) success, location = download_model_from_official(model_path, model_type) - return success, location, model_type + + return { + "success": success, + "local_path": location, + "model_type": model_type, + "run_type": run_type + } + + +def get_run_type_from_model_path(model_path): + model_name, model_version = model_path.split(":") + return NEXA_OFFICIAL_MODELS_TYPE.get(model_name, "UNKNOWN") def get_model_presigned_link(full_path, token): @@ -192,11 +218,11 @@ def get_model_presigned_link(full_path, token): Get the presigned links for downloading the contents of a model folder. Args: - full_path (str): The full path of the folder to download (e.g., "gpt2-onnx/"). - token (str): The authentication token. Can be empty. + full_path (str): The full path of the folder to download (e.g., "openai/gpt2:gguf-q2_K"). + token (str, optional): The authentication token. Defaults to None. Returns: - dict: A dictionary containing the contents of the folder, where keys are file paths and values are presigned links. + dict: A dictionary containing the model type and presigned URLs. 
""" url = f"{NEXA_API_URL}/model/download-tag-folder" @@ -205,13 +231,20 @@ def get_model_presigned_link(full_path, token): if token: headers["Authorization"] = f"Bearer {token}" - body = {"full_path": full_path} + body = {"full_path": full_path, "need_type": True} try: response = requests.post(url, headers=headers, json=body) response.raise_for_status() - folder_contents = response.json() - return folder_contents + result = response.json() + + run_type = result.get("type", [])[0] if result.get("type") else None + presigned_urls = result.get("presigned_urls", {}) + + return { + "run_type": run_type, + "presigned_urls": presigned_urls + } except requests.exceptions.RequestException as e: print(f"API request failed: {e}") @@ -286,7 +319,7 @@ def is_model_exists(model_name): return model_name in model_list -def add_model_to_list(model_name, model_location, model_type): +def add_model_to_list(model_name, model_location, model_type, run_type): NEXA_MODEL_LIST_PATH.parent.mkdir(parents=True, exist_ok=True) if NEXA_MODEL_LIST_PATH.exists(): @@ -295,20 +328,28 @@ def add_model_to_list(model_name, model_location, model_type): else: model_list = {} - model_list[model_name] = {"type": model_type, "location": model_location} + model_list[model_name] = { + "type": model_type, + "location": model_location, + "run_type": run_type + } with open(NEXA_MODEL_LIST_PATH, "w") as f: json.dump(model_list, f, indent=2) -def get_model_location(model_name): +def get_model_info(model_name): if not NEXA_MODEL_LIST_PATH.exists(): - return None + return None, None with open(NEXA_MODEL_LIST_PATH, "r") as f: model_list = json.load(f) - return model_list.get(model_name).get("location") + model_data = model_list.get(model_name, {}) + location = model_data.get("location") + run_type = model_data.get("run_type") + + return location, run_type def list_models(): @@ -320,15 +361,15 @@ def list_models(): model_list = json.load(f) table = [ - (model_name, model_info["type"], model_info["location"]) + (model_name, model_info["type"], model_info["run_type"], model_info["location"]) for model_name, model_info in model_list.items() ] - headers = ["Model Name", "Type", "Location"] + headers = ["Model Name", "Type", "Run Type", "Location"] from tabulate import tabulate print( tabulate( - table, headers, tablefmt="pretty", colalign=("left", "left", "left") + table, headers, tablefmt="pretty", colalign=("left", "left", "left", "left") ) ) except Exception as e: @@ -359,8 +400,6 @@ def remove_model(model_path): model_path.unlink() print(f"Deleted model file: {model_path}") elif model_path.is_dir(): - import shutil - shutil.rmtree(model_path) print(f"Deleted model directory: {model_path}") else: @@ -375,6 +414,32 @@ def remove_model(model_path): except Exception as e: print(f"An error occurred while removing the model: {e}") return None + +def clean(): + if not NEXA_MODELS_HUB_DIR.exists(): + print(f"Nothing to clean.") + return + + # Ask for user confirmation + confirmation = input(f"This will remove all downloaded models and the model list. Are you sure? 
(y/N): ").lower().strip() + + if confirmation != 'y': + print("Operation cancelled.") + return + + try: + # Remove all contents of the directory + for item in NEXA_MODELS_HUB_DIR.iterdir(): + if item.is_file(): + item.unlink() + elif item.is_dir(): + shutil.rmtree(item) + + print(f"Successfully removed all contents from {NEXA_MODELS_HUB_DIR}") + + except Exception as e: + print(f"An error occurred while cleaning the directory: {e}") + if __name__ == "__main__": diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 21714fad..363c6015 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -7,7 +7,6 @@ from pathlib import Path from nexa.gguf.sd.stable_diffusion import StableDiffusion -from nexa.general import pull_model from nexa.constants import ( DEFAULT_IMG_GEN_PARAMS, EXIT_REMINDER, @@ -35,6 +34,7 @@ class NexaImageInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. num_inference_steps (int): Number of inference steps. width (int): Width of the output image. height (int): Height of the output image. @@ -45,51 +45,33 @@ class NexaImageInference: """ - def __init__(self, model_path, **kwargs): - self.model_path = None - self.downloaded_path = None - if model_path in NEXA_RUN_MODEL_MAP: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP.get(model_path) - self.downloaded_path = pull_model(self.model_path) - elif os.path.exists(model_path): - logging.debug(f"Using local model at {model_path}") - self.downloaded_path = model_path - else: - logging.debug(f"Trying to use model from hub at {model_path}") - self.downloaded_path = pull_model(model_path) - - if self.downloaded_path is None: - logging.error( - f"Model ({model_path}) is not applicable. 
Please refer to our docs for proper usage.", - exc_info=True, - ) - exit(1) + def __init__(self, model_path, local_path, **kwargs): + self.model_path = model_path + self.local_path = local_path - if self.model_path == "lcm-dreamshaper-v7:fp16": + if self.model_path == "lcm-dreamshaper": self.params = DEFAULT_IMG_GEN_PARAMS_LCM - elif self.model_path == "sdxl-turbo:q8_0": + elif self.model_path == "sdxl-turbo": self.params = DEFAULT_IMG_GEN_PARAMS_TURBO else: self.params = DEFAULT_IMG_GEN_PARAMS self.params.update(kwargs) - if not kwargs.get("streamlit", False): - self._load_model(model_path) + self._load_model() if self.model is None: logging.error("Failed to load the model or pipeline.") exit(1) @SpinningCursorAnimation() - def _load_model(self, model_path: str): + def _load_model(self): with suppress_stdout_stderr(): self.model = StableDiffusion( - model_path=self.downloaded_path, + model_path=self.local_path, lora_model_dir=self.params.get("lora_dir", ""), n_threads=self.params.get("n_threads", multiprocessing.cpu_count()), wtype=self.params.get( - "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default") + "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(self.model_path, "default") ), # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) control_net_path=self.params.get("control_net_path", ""), verbose=False, diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index df99fbe0..840d6dfb 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -12,7 +12,6 @@ NEXA_RUN_MODEL_MAP, NEXA_STOP_WORDS_MAP, ) -from nexa.general import pull_model from nexa.gguf.lib_utils import is_gpu_available from nexa.gguf.llama.llama import Llama from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr @@ -32,6 +31,7 @@ class NexaTextInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. stop_words (list): List of stop words for early stopping. profiling (bool): Enable timing measurements for the generation process. streamlit (bool): Run the inference in Streamlit UI. @@ -40,35 +40,16 @@ class NexaTextInference: top_k (int): Top-k sampling parameter. top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, stop_words=None, **kwargs): + def __init__(self, model_path, local_path, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) self.model = None - - self.model_path = None - self.downloaded_path = None - if model_path in NEXA_RUN_MODEL_MAP: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP.get(model_path) - self.downloaded_path = pull_model(self.model_path) - elif os.path.exists(model_path): - logging.debug(f"Using local model at {model_path}") - self.downloaded_path = model_path - else: - logging.debug(f"Trying to use model from hub at {model_path}") - self.downloaded_path = pull_model(model_path) - - if self.downloaded_path is None: - logging.error( - f"Model ({model_path}) is not appicable. 
Please refer to our docs for proper usage.", - exc_info=True, - ) - exit(1) - + self.model_path = model_path + self.local_path = local_path + self.stop_words = ( stop_words if stop_words else NEXA_STOP_WORDS_MAP.get(model_path, []) ) - self.profiling = kwargs.get("profiling", False) self.chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) @@ -83,6 +64,7 @@ def __init__(self, model_path, stop_words=None, **kwargs): "Failed to load model or tokenizer. Exiting.", exc_info=True ) exit(1) + def embed( self, input: Union[str, List[str]], @@ -107,11 +89,11 @@ def embed( @SpinningCursorAnimation() def _load_model(self): - logging.debug(f"Loading model from {self.downloaded_path}") + logging.debug(f"Loading model from {self.local_path}") start_time = time.time() with suppress_stdout_stderr(): self.model = Llama( - model_path=self.downloaded_path, + model_path=self.local_path, verbose=self.profiling, chat_format=self.chat_format, n_ctx=2048, @@ -280,6 +262,6 @@ def run_streamlit(self, model_path: str): stop_words = kwargs.pop("stop_words", []) inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs) if args.streamlit: - inference.run_streamlit(model_path) + inference.run_streamlit() else: inference.run() diff --git a/nexa/gguf/nexa_inference_vlm.py b/nexa/gguf/nexa_inference_vlm.py index a79781eb..7243899b 100644 --- a/nexa/gguf/nexa_inference_vlm.py +++ b/nexa/gguf/nexa_inference_vlm.py @@ -79,6 +79,7 @@ class NexaVLMInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. stop_words (list): List of stop words for early stopping. profiling (bool): Enable timing measurements for the generation process. streamlit (bool): Run the inference in Streamlit UI. @@ -87,31 +88,29 @@ class NexaVLMInference: top_k (int): Top-k sampling parameter. top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, stop_words=None, **kwargs): + def __init__(self, model_path, local_path, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) self.model = None self.projector = None self.projector_path = NEXA_RUN_PROJECTOR_MAP.get(model_path, None) - self.downloaded_path = None - self.projector_downloaded_path = None + self.local_path = local_path + self.projector_local_path = None if model_path in NEXA_RUN_MODEL_MAP_VLM: logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP_VLM.get(model_path) self.projector_path = NEXA_RUN_PROJECTOR_MAP.get(model_path) - self.downloaded_path = pull_model(self.model_path) - self.projector_downloaded_path = pull_model(self.projector_path) + self.projector_local_path, run_type = pull_model(self.projector_path) elif (local_dir := Path(model_path).parent).exists(): logging.debug(f"Using local model at {local_dir}") model_name = Path(model_path).name tag_and_ext = model_name.split(":")[-1] - self.downloaded_path = local_dir / f"model-{tag_and_ext}" - self.projector_downloaded_path = local_dir / f"projector-{tag_and_ext}" + self.local_path = local_dir / f"model-{tag_and_ext}" + self.projector_local_path = local_dir / f"projector-{tag_and_ext}" if not ( - self.downloaded_path.exists() - and self.projector_downloaded_path.exists() + self.local_path.exists() + and self.projector_local_path.exists() ): logging.error( f"Model or projector not found in {local_dir}. 
" @@ -122,7 +121,7 @@ def __init__(self, model_path, stop_words=None, **kwargs): logging.error("VLM user model from hub is not supported yet.") exit(1) - if self.downloaded_path is None: + if self.local_path is None: logging.error( f"Model ({model_path}) is not appicable. Please refer to our docs for proper usage.", exc_info=True, @@ -147,18 +146,18 @@ def __init__(self, model_path, stop_words=None, **kwargs): @SpinningCursorAnimation() def _load_model(self): - logging.debug(f"Loading model from {self.downloaded_path}") + logging.debug(f"Loading model from {self.local_path}") start_time = time.time() with suppress_stdout_stderr(): self.projector = ( self.projector_handler( - clip_model_path=self.projector_downloaded_path, verbose=False + clip_model_path=self.projector_local_path, verbose=False ) - if self.projector_downloaded_path + if self.projector_local_path else None ) self.model = Llama( - model_path=self.downloaded_path, + model_path=self.local_path, chat_handler=self.projector, verbose=False, chat_format=self.chat_format, diff --git a/nexa/gguf/nexa_inference_voice.py b/nexa/gguf/nexa_inference_voice.py index 5835f53f..da731eaf 100644 --- a/nexa/gguf/nexa_inference_voice.py +++ b/nexa/gguf/nexa_inference_voice.py @@ -10,9 +10,7 @@ EXIT_REMINDER, NEXA_RUN_MODEL_MAP_VOICE, ) -from nexa.general import pull_model from nexa.utils import nexa_prompt -from faster_whisper import WhisperModel from nexa.utils import nexa_prompt, SpinningCursorAnimation, suppress_stdout_stderr logging.basicConfig(level=logging.INFO) @@ -28,6 +26,7 @@ class NexaVoiceInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. output_dir (str): Output directory for transcriptions. beam_size (int): Beam size to use for transcription. language (str): The language spoken in the audio. @@ -37,22 +36,12 @@ class NexaVoiceInference: output_dir (str): Output directory for transcriptions. """ - def __init__(self, model_path, **kwargs): - self.model_path = None - self.downloaded_path = None + def __init__(self, model_path, local_path, **kwargs): + self.model_path = model_path + self.local_path = local_path self.params = DEFAULT_VOICE_GEN_PARAMS - if model_path in NEXA_RUN_MODEL_MAP_VOICE: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP_VOICE.get(model_path) - self.downloaded_path = pull_model(self.model_path) - elif os.path.exists(model_path): - logging.debug(f"Using local model at {model_path}") - self.downloaded_path = model_path - else: - logging.error("Using voice model from hub is not supported yet.") - exit(1) - if self.downloaded_path is None: + if self.local_path is None: logging.error( f"Model ({model_path}) is not applicable. 
Please refer to our docs for proper usage.", exc_info=True, @@ -74,10 +63,10 @@ def __init__(self, model_path, **kwargs): def _load_model(self): from faster_whisper import WhisperModel - logging.debug(f"Loading model from: {self.downloaded_path}") + logging.debug(f"Loading model from: {self.local_path}") with suppress_stdout_stderr(): self.model = WhisperModel( - self.downloaded_path, + self.local_path, device="cpu", compute_type=self.params["compute_type"], ) @@ -106,7 +95,7 @@ def _transcribe_audio(self, audio_path): ) transcription = "".join(segment.text for segment in segments) self._save_transcription(transcription) - logging.info(f"Transcription: {transcription}") + print(f"Transcription: {transcription}") except Exception as e: logging.error(f"Error during transcription: {e}", exc_info=True) diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 5b24e7e7..36ad0575 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -67,7 +67,7 @@ async def load_model(): chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) completion_template = NEXA_RUN_COMPLETION_TEMPLATE_MAP.get(model_path, None) model_path = NEXA_RUN_MODEL_MAP_TEXT.get(model_path) - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): model = Llama( model_path=downloaded_path, @@ -79,7 +79,7 @@ async def load_model(): elif model_path in NEXA_RUN_MODEL_MAP_FUNCTION_CALLING: chat_format = "chatml-function-calling" model_path = NEXA_RUN_MODEL_MAP_FUNCTION_CALLING.get(model_path) - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): model = Llama( model_path=downloaded_path, @@ -89,7 +89,7 @@ async def load_model(): ) logging.info(f"model loaded as {model}") elif model_path in NEXA_RUN_MODEL_MAP_IMAGE: - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): model = StableDiffusion( model_path=downloaded_path, @@ -101,7 +101,7 @@ async def load_model(): logging.info(f"model loaded as {model}") elif model_path in NEXA_RUN_MODEL_MAP_VOICE: model_path = NEXA_RUN_MODEL_MAP_VOICE.get(model_path) - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): model = WhisperModel( downloaded_path, diff --git a/nexa/gguf/streamlit/streamlit_image_chat.py b/nexa/gguf/streamlit/streamlit_image_chat.py index 3117fea1..3dbe47ea 100644 --- a/nexa/gguf/streamlit/streamlit_image_chat.py +++ b/nexa/gguf/streamlit/streamlit_image_chat.py @@ -2,7 +2,7 @@ import sys from PIL import Image - +from nexa.general import pull_model import streamlit as st from nexa.gguf.nexa_inference_image import NexaImageInference @@ -11,7 +11,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaImageInference(model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaImageInference(model_path, local_path) return nexa_model diff --git a/nexa/gguf/streamlit/streamlit_text_chat.py b/nexa/gguf/streamlit/streamlit_text_chat.py index 2a9d7fed..cffa517f 100644 --- a/nexa/gguf/streamlit/streamlit_text_chat.py +++ b/nexa/gguf/streamlit/streamlit_text_chat.py @@ -2,7 +2,7 @@ from typing import Iterator import streamlit as st - +from nexa.general import pull_model from nexa.gguf.nexa_inference_text import NexaTextInference default_model = sys.argv[1] @@ -11,7 +11,8 @@ @st.cache_resource def 
load_model(model_path): st.session_state.messages = [] - nexa_model = NexaTextInference(model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaTextInference(model_path, local_path) return nexa_model diff --git a/nexa/gguf/streamlit/streamlit_vlm.py b/nexa/gguf/streamlit/streamlit_vlm.py index 7c8fee91..eb502ccd 100644 --- a/nexa/gguf/streamlit/streamlit_vlm.py +++ b/nexa/gguf/streamlit/streamlit_vlm.py @@ -4,7 +4,7 @@ import streamlit as st from PIL import Image - +from nexa.general import pull_model from nexa.gguf.nexa_inference_vlm import NexaVLMInference default_model = sys.argv[1] @@ -12,7 +12,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaVLMInference(model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaVLMInference(model_path, local_path) return nexa_model diff --git a/nexa/gguf/streamlit/streamlit_voice_chat.py b/nexa/gguf/streamlit/streamlit_voice_chat.py index 461dcc95..f98fecc6 100644 --- a/nexa/gguf/streamlit/streamlit_voice_chat.py +++ b/nexa/gguf/streamlit/streamlit_voice_chat.py @@ -15,7 +15,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaVoiceInference(model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaVoiceInference(model_path, local_path) return nexa_model diff --git a/nexa/onnx/nexa_inference_image.py b/nexa/onnx/nexa_inference_image.py index b2dc3afb..c8922c25 100644 --- a/nexa/onnx/nexa_inference_image.py +++ b/nexa/onnx/nexa_inference_image.py @@ -14,7 +14,6 @@ ) from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX -from nexa.general import pull_model from nexa.utils import nexa_prompt logging.basicConfig(level=logging.INFO) @@ -36,6 +35,7 @@ class NexaImageInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. num_inference_steps (int): Number of inference steps. num_images_per_prompt (int): Number of images to generate per prompt. width (int): Width of the output image. @@ -45,8 +45,9 @@ class NexaImageInference: random_seed (int): Random seed for image generation. streamlit (bool): Run the inference in Streamlit UI. """ - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) + self.download_onnx_folder = local_path self.params = { "num_inference_steps": 20, "num_images_per_prompt": 1, @@ -60,17 +61,12 @@ def __init__(self, model_path, **kwargs): self.pipeline = None def run(self): - # Step 1: Download the ONNX folder from S3 - downloaded_onnx_folder = pull_model(self.model_path) - if downloaded_onnx_folder is None: + if self.download_onnx_folder is None: logging.error("Failed to download the model. 
Exiting.") return - # Step 2: Load the model - self._load_model(downloaded_onnx_folder) - - # Step 3: Enter dialogue mode + self._load_model(self.download_onnx_folder) self._dialogue_mode() def _load_model(self, model_path): diff --git a/nexa/onnx/nexa_inference_text.py b/nexa/onnx/nexa_inference_text.py index f2f94a3c..044ecd7d 100644 --- a/nexa/onnx/nexa_inference_text.py +++ b/nexa/onnx/nexa_inference_text.py @@ -9,7 +9,6 @@ from transformers import AutoTokenizer, TextStreamer from nexa.constants import NEXA_RUN_MODEL_MAP_ONNX -from nexa.general import pull_model from nexa.utils import nexa_prompt logging.basicConfig(level=logging.INFO) @@ -25,6 +24,7 @@ class NexaTextInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. profiling (bool): Enable timing measurements for the generation process. streamlit (bool): Run the inference in Streamlit UI. temperature (float): Temperature for sampling. @@ -34,7 +34,7 @@ class NexaTextInference: top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.params = { "temperature": 0.5, @@ -47,7 +47,7 @@ def __init__(self, model_path, **kwargs): self.model = None self.tokenizer = None self.streamer = None - self.downloaded_onnx_folder = None + self.downloaded_onnx_folder = local_path self.timings = kwargs.get("timings", False) self.device = "cpu" @@ -144,12 +144,6 @@ def run(self): if self.params.get("streamlit"): self.run_streamlit() else: - self.downloaded_onnx_folder = pull_model( - self.model_path - ) # David TODO : move download logic to __init__(), otherwise for streamlit you will duplicate such logic again - if self.downloaded_onnx_folder is None: - logging.error("Failed to download the model. Exiting.", exc_info=True) - return self._load_model_and_tokenizer() diff --git a/nexa/onnx/nexa_inference_tts.py b/nexa/onnx/nexa_inference_tts.py index e7167ee6..81bf5b42 100644 --- a/nexa/onnx/nexa_inference_tts.py +++ b/nexa/onnx/nexa_inference_tts.py @@ -11,7 +11,6 @@ from ttstokenizer import TTSTokenizer from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX -from nexa.general import pull_model from nexa.utils import nexa_prompt logging.basicConfig(level=logging.INFO) @@ -28,12 +27,13 @@ class NexaTTSInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. output_dir (str): Output directory for tts. sampling_rate (int): Sampling rate for audio processing. streamlit (bool): Run the inference in Streamlit UI. 
""" - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.yaml_file_name = None self.params = { @@ -44,9 +44,7 @@ def __init__(self, model_path, **kwargs): self.model = None self.processor = None self.config = None - self.downloaded_onnx_folder = None - - self.downloaded_onnx_folder = pull_model(self.model_path) + self.downloaded_onnx_folder = local_path self.yaml_file_name = os.path.join(self.downloaded_onnx_folder, "config.yaml") with open(self.yaml_file_name, "r", encoding="utf-8") as f: self.config = yaml.safe_load(f) @@ -83,7 +81,7 @@ def _audio_generation(self, user_input): self._save_audio( outputs[0], self.params["sampling_rate"], self.params["output_path"] ) - logging.info(f"Audio saved to {self.params['output_path']}") + print(f"Audio saved to {self.params['output_path']}") def _save_audio(self, audio_data, sampling_rate, output_path): os.makedirs(output_path, exist_ok=True) diff --git a/nexa/onnx/nexa_inference_voice.py b/nexa/onnx/nexa_inference_voice.py index 638a723a..fa3d4a8f 100644 --- a/nexa/onnx/nexa_inference_voice.py +++ b/nexa/onnx/nexa_inference_voice.py @@ -10,7 +10,6 @@ from transformers import AutoProcessor from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX -from nexa.general import pull_model from nexa.utils import nexa_prompt logging.basicConfig(level=logging.INFO) @@ -26,26 +25,22 @@ class NexaVoiceInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. output_dir (str): Output directory for transcriptions. sampling_rate (int): Sampling rate for audio processing. streamlit (bool): Run the inference in Streamlit UI. """ - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) + self.downloaded_onnx_folder = local_path self.params = {"output_dir": "transcriptions", "sampling_rate": 16000} self.params.update(kwargs) self.model = None self.processor = None def run(self): - downloaded_onnx_folder = pull_model(self.model_path) - - if downloaded_onnx_folder is None: - logging.error("Failed to download the model. 
Exiting.") - return - - self._load_model(downloaded_onnx_folder) + self._load_model(self.downloaded_onnx_folder) self._dialogue_mode() def _load_model(self, model_path): @@ -99,7 +94,7 @@ def _transcribe_audio(self, audio_path): )[0] self._save_transcription(transcription) - logging.info(f"Transcription: {transcription}") + print(f"Transcription: {transcription}") except Exception as e: logging.error(f"Error during audio transcription: {e}") diff --git a/nexa/onnx/server/nexa_service.py b/nexa/onnx/server/nexa_service.py index 648d0368..387589d5 100644 --- a/nexa/onnx/server/nexa_service.py +++ b/nexa/onnx/server/nexa_service.py @@ -60,7 +60,7 @@ async def nexa_run_text_generation_preperation(model_path): global model, tokenizer, streamer, chat_template, is_chat_mode # Step 1: Check if the model_path is a key in NEXA_RUN_MODEL_MAP_ONNX, if so, get the full path full_model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) - downloaded_onnx_folder = pull_model(full_model_path) + downloaded_onnx_folder, run_type = pull_model(full_model_path) logging.info(f"Downloaded ONNX folder: {downloaded_onnx_folder}") # Step 2: Load the model and tokenizer diff --git a/nexa/onnx/streamlit/streamlit_image_chat.py b/nexa/onnx/streamlit/streamlit_image_chat.py index 4e632c69..f6803817 100644 --- a/nexa/onnx/streamlit/streamlit_image_chat.py +++ b/nexa/onnx/streamlit/streamlit_image_chat.py @@ -12,8 +12,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaImageInference(model_path) - nexa_model.downloaded_onnx_folder = pull_model(nexa_model.model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaImageInference(model_path, local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. Please check the model path.") diff --git a/nexa/onnx/streamlit/streamlit_text_chat.py b/nexa/onnx/streamlit/streamlit_text_chat.py index 35d78913..35216720 100644 --- a/nexa/onnx/streamlit/streamlit_text_chat.py +++ b/nexa/onnx/streamlit/streamlit_text_chat.py @@ -12,8 +12,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaTextInference(model_path) - nexa_model.downloaded_onnx_folder = pull_model(nexa_model.model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaTextInference(model_path, local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. 
Please check the model path.") diff --git a/nexa/onnx/streamlit/streamlit_tts.py b/nexa/onnx/streamlit/streamlit_tts.py index 5f30d5a7..c08922e1 100644 --- a/nexa/onnx/streamlit/streamlit_tts.py +++ b/nexa/onnx/streamlit/streamlit_tts.py @@ -4,7 +4,7 @@ import soundfile as sf import streamlit as st - +from nexa.general import pull_model from nexa.onnx.nexa_inference_tts import NexaTTSInference default_model = sys.argv[1] @@ -12,7 +12,8 @@ @st.cache_resource def load_model(model_path: str): - return NexaTTSInference(model_path) + local_path, run_type = pull_model(model_path) + return NexaTTSInference(model_path, local_path) def generate_audio(nexa_model: NexaTTSInference, text): diff --git a/nexa/onnx/streamlit/streamlit_voice_chat.py b/nexa/onnx/streamlit/streamlit_voice_chat.py index 991ec61b..400c95d0 100644 --- a/nexa/onnx/streamlit/streamlit_voice_chat.py +++ b/nexa/onnx/streamlit/streamlit_voice_chat.py @@ -15,8 +15,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaVoiceInference(model_path) - nexa_model.downloaded_onnx_folder = pull_model(nexa_model.model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaVoiceInference(model_path, local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. Please check the model path.") diff --git a/pyproject.toml b/pyproject.toml index 49f0c0a2..499465ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,6 @@ dependencies = [ "diskcache>=5.6.1", "jinja2>=2.11.3", "librosa>=0.8.0", - "boto3>=1.34.148", - "botocore>=1.34.148", "fastapi", "uvicorn", "pydantic", diff --git a/requirements.txt b/requirements.txt index c487b432..a05f5892 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,8 +12,6 @@ numpy>=1.20.0 diskcache>=5.6.1 jinja2>=2.11.3 librosa>=0.8.0 -boto3>=1.34.148 -botocore>=1.34.148 fastapi uvicorn pydantic @@ -21,7 +19,7 @@ pillow prompt_toolkit # For onnx -optimum[onnxruntime]>=1.7.3 # for CPU version +optimum[onnxruntime] # for CPU version diffusers # required for image generation optuna PyYAML @@ -34,6 +32,5 @@ ttstokenizer # Shared dependencies tqdm -importlib-metadata>=1.4.0,<7.0 tabulate streamlit diff --git a/tomls/pyproject_cuda.toml b/tomls/pyproject_cuda.toml index 75e9378a..e1c46c84 100644 --- a/tomls/pyproject_cuda.toml +++ b/tomls/pyproject_cuda.toml @@ -16,8 +16,6 @@ dependencies = [ "diskcache>=5.6.1", "jinja2>=2.11.3", "librosa>=0.8.0", - "boto3>=1.34.148", - "botocore>=1.34.148", "fastapi", "uvicorn", "pydantic", diff --git a/tomls/pyproject_metal.toml b/tomls/pyproject_metal.toml index e8a3006e..3aeeac47 100644 --- a/tomls/pyproject_metal.toml +++ b/tomls/pyproject_metal.toml @@ -16,8 +16,6 @@ dependencies = [ "diskcache>=5.6.1", "jinja2>=2.11.3", "librosa>=0.8.0", - "boto3>=1.34.148", - "botocore>=1.34.148", "fastapi", "uvicorn", "pydantic", From 73ecdfc6819adb42ff887d8116efc5203db01fa7 Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 22 Aug 2024 04:58:38 +0000 Subject: [PATCH 3/7] update --- nexa/gguf/nexa_inference_image.py | 6 +++--- nexa/gguf/nexa_inference_text.py | 4 ++-- nexa/gguf/nexa_inference_voice.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 7898b76b..e298b910 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -26,9 +26,9 @@ class NexaImageInference: A class used for loading image models and running image generation. 
Methods: - txt2img: (Used for SDK) Run the text-to-image generation loop. - img2img: (Used for SDK) Run the image-to-image generation loop. - run_streamlit: Run the Streamlit UI. + txt2img: (Used for SDK) Run the text-to-image generation loop. + img2img: (Used for SDK) Run the image-to-image generation loop. + run_streamlit: Run the Streamlit UI. Args: model_path (str): Path or identifier for the model in Nexa Model Hub. diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index 44aea0c6..a59597e3 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -25,8 +25,8 @@ class NexaTextInference: A class used for load text models and run text generation. Methods: - run: Run the text generation loop. - run_streamlit: Run the Streamlit UI. + run: Run the text generation loop. + run_streamlit: Run the Streamlit UI. Args: model_path (str): Path or identifier for the model in Nexa Model Hub. diff --git a/nexa/gguf/nexa_inference_voice.py b/nexa/gguf/nexa_inference_voice.py index 494e2c66..d88661af 100644 --- a/nexa/gguf/nexa_inference_voice.py +++ b/nexa/gguf/nexa_inference_voice.py @@ -21,8 +21,8 @@ class NexaVoiceInference: A class used for loading voice models and running voice transcription. Methods: - run: Run the voice transcription loop. - run_streamlit: Run the Streamlit UI. + run: Run the voice transcription loop. + run_streamlit: Run the Streamlit UI. Args: model_path (str): Path or identifier for the model in Nexa Model Hub. From 65432e645c723b78778c1b6441d9695b791b847a Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 22 Aug 2024 05:08:22 +0000 Subject: [PATCH 4/7] update --- nexa/cli/entry.py | 18 +++++++++--------- nexa/gguf/streamlit/streamlit_image_chat.py | 2 +- nexa/gguf/streamlit/streamlit_text_chat.py | 2 +- nexa/gguf/streamlit/streamlit_vlm.py | 2 +- nexa/gguf/streamlit/streamlit_voice_chat.py | 2 +- nexa/onnx/streamlit/streamlit_image_chat.py | 2 +- nexa/onnx/streamlit/streamlit_text_chat.py | 2 +- nexa/onnx/streamlit/streamlit_tts.py | 2 +- nexa/onnx/streamlit/streamlit_voice_chat.py | 2 +- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index d61def4c..9523db3d 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -19,20 +19,20 @@ def run_ggml_inference(args): inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs) elif run_type == "Computer Vision": from nexa.gguf.nexa_inference_image import NexaImageInference - inference = NexaImageInference(model_path, local_path, **kwargs) + inference = NexaImageInference(model_path=model_path, local_path=local_path, **kwargs) if hasattr(args, 'streamlit') and args.streamlit: inference.run_streamlit(model_path) elif args.img2img: - inference.loop_img2img() + inference.run_img2img() else: - inference.loop_txt2img() + inference.run_txt2img() return elif run_type == "Multimodal": from nexa.gguf.nexa_inference_vlm import NexaVLMInference - inference = NexaVLMInference(model_path, local_path, stop_words=stop_words, **kwargs) + inference = NexaVLMInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs) elif run_type == "Audio": from nexa.gguf.nexa_inference_voice import NexaVoiceInference - inference = NexaVoiceInference(model_path, local_path, **kwargs) + inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs) else: raise ValueError(f"Unknown task: {run_type}") @@ -50,16 +50,16 @@ def 
run_onnx_inference(args): if run_type == "NLP": from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference - inference = NexaTextOnnxInference(model_path, local_path, **kwargs) + inference = NexaTextOnnxInference(model_path=model_path, local_path=local_path, **kwargs) elif run_type == "Computer Vision": from nexa.onnx.nexa_inference_image import NexaImageInference as NexaImageOnnxInference - inference = NexaImageOnnxInference(model_path, local_path, **kwargs) + inference = NexaImageOnnxInference(model_path=model_path, local_path=local_path, **kwargs) elif run_type == "Audio": from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference - inference = NexaVoiceOnnxInference(model_path, local_path, **kwargs) + inference = NexaVoiceOnnxInference(model_path=model_path, local_path=local_path, **kwargs) elif run_type == "TTS": from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference - inference = NexaTTSOnnxInference(model_path, local_path, **kwargs) + inference = NexaTTSOnnxInference(model_path=model_path, local_path=local_path, **kwargs) else: raise ValueError(f"Unknown task: {run_type}") diff --git a/nexa/gguf/streamlit/streamlit_image_chat.py b/nexa/gguf/streamlit/streamlit_image_chat.py index 3dbe47ea..b9d8fb97 100644 --- a/nexa/gguf/streamlit/streamlit_image_chat.py +++ b/nexa/gguf/streamlit/streamlit_image_chat.py @@ -12,7 +12,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaImageInference(model_path, local_path) + nexa_model = NexaImageInference(model_path=model_path, local_path=local_path) return nexa_model diff --git a/nexa/gguf/streamlit/streamlit_text_chat.py b/nexa/gguf/streamlit/streamlit_text_chat.py index cffa517f..61848a0e 100644 --- a/nexa/gguf/streamlit/streamlit_text_chat.py +++ b/nexa/gguf/streamlit/streamlit_text_chat.py @@ -12,7 +12,7 @@ def load_model(model_path): st.session_state.messages = [] local_path, run_type = pull_model(model_path) - nexa_model = NexaTextInference(model_path, local_path) + nexa_model = NexaTextInference(model_path=model_path, local_path=local_path) return nexa_model diff --git a/nexa/gguf/streamlit/streamlit_vlm.py b/nexa/gguf/streamlit/streamlit_vlm.py index eb502ccd..9f334d13 100644 --- a/nexa/gguf/streamlit/streamlit_vlm.py +++ b/nexa/gguf/streamlit/streamlit_vlm.py @@ -13,7 +13,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaVLMInference(model_path, local_path) + nexa_model = NexaVLMInference(model_pat=model_path, local_path=local_path) return nexa_model diff --git a/nexa/gguf/streamlit/streamlit_voice_chat.py b/nexa/gguf/streamlit/streamlit_voice_chat.py index f98fecc6..949be707 100644 --- a/nexa/gguf/streamlit/streamlit_voice_chat.py +++ b/nexa/gguf/streamlit/streamlit_voice_chat.py @@ -16,7 +16,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaVoiceInference(model_path, local_path) + nexa_model = NexaVoiceInference(model_path=model_path, local_path=local_path) return nexa_model diff --git a/nexa/onnx/streamlit/streamlit_image_chat.py b/nexa/onnx/streamlit/streamlit_image_chat.py index f6803817..724c8647 100644 --- a/nexa/onnx/streamlit/streamlit_image_chat.py +++ b/nexa/onnx/streamlit/streamlit_image_chat.py @@ -13,7 +13,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaImageInference(model_path, 
local_path) + nexa_model = NexaImageInference(model_path=model_path, local_path=local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. Please check the model path.") diff --git a/nexa/onnx/streamlit/streamlit_text_chat.py b/nexa/onnx/streamlit/streamlit_text_chat.py index 35216720..872f2774 100644 --- a/nexa/onnx/streamlit/streamlit_text_chat.py +++ b/nexa/onnx/streamlit/streamlit_text_chat.py @@ -13,7 +13,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaTextInference(model_path, local_path) + nexa_model = NexaTextInference(model_path=model_path, local_path=local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. Please check the model path.") diff --git a/nexa/onnx/streamlit/streamlit_tts.py b/nexa/onnx/streamlit/streamlit_tts.py index c08922e1..1e60cc7a 100644 --- a/nexa/onnx/streamlit/streamlit_tts.py +++ b/nexa/onnx/streamlit/streamlit_tts.py @@ -13,7 +13,7 @@ @st.cache_resource def load_model(model_path: str): local_path, run_type = pull_model(model_path) - return NexaTTSInference(model_path, local_path) + return NexaTTSInference(model_path=model_path, local_path=local_path) def generate_audio(nexa_model: NexaTTSInference, text): diff --git a/nexa/onnx/streamlit/streamlit_voice_chat.py b/nexa/onnx/streamlit/streamlit_voice_chat.py index 400c95d0..b3316edb 100644 --- a/nexa/onnx/streamlit/streamlit_voice_chat.py +++ b/nexa/onnx/streamlit/streamlit_voice_chat.py @@ -16,7 +16,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaVoiceInference(model_path, local_path) + nexa_model = NexaVoiceInference(model_path=model_path, local_path=local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. 
Please check the model path.") From 1b848424b6242017ae24f28266b3a999f91eb940 Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 22 Aug 2024 19:29:48 +0000 Subject: [PATCH 5/7] update --- nexa/gguf/nexa_inference_image.py | 27 ++++++++++++----- nexa/gguf/nexa_inference_text.py | 29 +++++++++++++----- nexa/gguf/nexa_inference_vlm.py | 49 ++++++++++++++++--------------- nexa/gguf/nexa_inference_voice.py | 19 +++++++----- nexa/onnx/nexa_inference_image.py | 14 ++++++--- nexa/onnx/nexa_inference_text.py | 13 ++++++-- nexa/onnx/nexa_inference_tts.py | 15 ++++++++-- nexa/onnx/nexa_inference_voice.py | 13 +++++++- tests/test_image_generation.py | 7 +++-- tests/test_text_generation.py | 19 ++++++------ 10 files changed, 136 insertions(+), 69 deletions(-) diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 1aec4ab4..819e3dbf 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -15,6 +15,7 @@ ) from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr from streamlit.web import cli as stcli +from nexa.general import pull_model logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" @@ -44,34 +45,44 @@ class NexaImageInference: """ - def __init__(self, model_path, local_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = model_path - self.local_path = local_path + self.downloaded_path = local_path - if self.model_path == "lcm-dreamshaper": + if self.downloaded_path is None: + self.downloaded_path, run_type = pull_model(self.model_path) + + if self.downloaded_path is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + + if "lcm-dreamshaper" in self.model_path: self.params = DEFAULT_IMG_GEN_PARAMS_LCM - elif self.model_path == "sdxl-turbo": + elif "sdxl-turbo" in self.model_path: self.params = DEFAULT_IMG_GEN_PARAMS_TURBO else: self.params = DEFAULT_IMG_GEN_PARAMS self.params.update(kwargs) if not kwargs.get("streamlit", False): - self._load_model() + self._load_model(model_path) if self.model is None: logging.error("Failed to load the model or pipeline.") exit(1) @SpinningCursorAnimation() - def _load_model(self): + def _load_model(self, model_path: str): with suppress_stdout_stderr(): from nexa.gguf.sd.stable_diffusion import StableDiffusion self.model = StableDiffusion( - model_path=self.local_path, + model_path=self.downloaded_path, lora_model_dir=self.params.get("lora_dir", ""), n_threads=self.params.get("n_threads", multiprocessing.cpu_count()), wtype=self.params.get( - "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(self.model_path, "default") + "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default") ), # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) control_net_path=self.params.get("control_net_path", ""), verbose=False, diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index ff084752..9027e50a 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -14,6 +14,7 @@ ) from nexa.gguf.lib_utils import is_gpu_available from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr +from nexa.general import pull_model logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" @@ -25,12 +26,13 @@ class NexaTextInference: A class used for load text models and run text generation. 
Methods: - run: Run the text generation loop. - run_streamlit: Run the Streamlit UI. + run: Run the text generation loop. + run_streamlit: Run the Streamlit UI. Args: model_path (str): Path or identifier for the model in Nexa Model Hub. local_path (str): Local path of the model. + embedding (bool): Enable embedding generation. stop_words (list): List of stop words for early stopping. profiling (bool): Enable timing measurements for the generation process. streamlit (bool): Run the inference in Streamlit UI. @@ -39,13 +41,24 @@ class NexaTextInference: top_k (int): Top-k sampling parameter. top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, local_path, stop_words=None, **kwargs): + def __init__(self, model_path, local_path=None, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) self.model = None + self.model_path = model_path - self.local_path = local_path + self.downloaded_path = local_path + if self.downloaded_path is None: + self.downloaded_path, run_type = pull_model(self.model_path) + + if self.downloaded_path is None: + logging.error( + f"Model ({model_path}) is not appicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + self.stop_words = ( stop_words if stop_words else NEXA_STOP_WORDS_MAP.get(model_path, []) ) @@ -79,14 +92,14 @@ def create_embedding( @SpinningCursorAnimation() def _load_model(self): - logging.debug(f"Loading model from {self.local_path}") + logging.debug(f"Loading model from {self.downloaded_path}") start_time = time.time() with suppress_stdout_stderr(): try: from nexa.gguf.llama.llama import Llama self.model = Llama( embedding=self.params.get("embedding", False), - model_path=self.local_path, + model_path=self.downloaded_path, verbose=self.profiling, chat_format=self.chat_format, n_ctx=2048, @@ -95,7 +108,7 @@ def _load_model(self): except Exception as e: logging.error(f"Failed to load model: {e}. Falling back to CPU.", exc_info=True) self.model = Llama( - model_path=self.local_path, + model_path=self.downloaded_path, verbose=self.profiling, chat_format=self.chat_format, n_ctx=2048, @@ -306,6 +319,6 @@ def run_streamlit(self, model_path: str): stop_words = kwargs.pop("stop_words", []) inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs) if args.streamlit: - inference.run_streamlit() + inference.run_streamlit(model_path) else: inference.run() diff --git a/nexa/gguf/nexa_inference_vlm.py b/nexa/gguf/nexa_inference_vlm.py index c60ab8a1..ac886ccf 100644 --- a/nexa/gguf/nexa_inference_vlm.py +++ b/nexa/gguf/nexa_inference_vlm.py @@ -87,30 +87,31 @@ class NexaVLMInference: top_k (int): Top-k sampling parameter. 
top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, local_path, stop_words=None, **kwargs): + def __init__(self, model_path, local_path=None, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) self.model = None self.projector = None - self.projector_path = NEXA_RUN_PROJECTOR_MAP.get(model_path, None) - self.local_path = local_path - self.projector_local_path = None - - if model_path in NEXA_RUN_MODEL_MAP_VLM: - logging.debug(f"Found model {model_path} in public hub") - self.projector_path = NEXA_RUN_PROJECTOR_MAP.get(model_path) - self.projector_local_path, run_type = pull_model(self.projector_path) - elif (local_dir := Path(model_path).parent).exists(): - logging.debug(f"Using local model at {local_dir}") + self.downloaded_path = local_path + self.projector_downloaded_path = None + + if self.downloaded_path is not None: + if model_path in NEXA_RUN_MODEL_MAP_VLM: + self.projector_path = NEXA_RUN_PROJECTOR_MAP[model_path] + self.projector_downloaded_path, _ = pull_model(self.projector_path) + elif model_path in NEXA_RUN_MODEL_MAP_VLM: + self.model_path = NEXA_RUN_MODEL_MAP_VLM[model_path] + self.projector_path = NEXA_RUN_PROJECTOR_MAP[model_path] + self.downloaded_path, _ = pull_model(self.model_path) + self.projector_downloaded_path, _ = pull_model(self.projector_path) + elif Path(model_path).parent.exists(): + local_dir = Path(model_path).parent model_name = Path(model_path).name tag_and_ext = model_name.split(":")[-1] - self.local_path = local_dir / f"model-{tag_and_ext}" - self.projector_local_path = local_dir / f"projector-{tag_and_ext}" - if not ( - self.local_path.exists() - and self.projector_local_path.exists() - ): + self.downloaded_path = local_dir / f"model-{tag_and_ext}" + self.projector_downloaded_path = local_dir / f"projector-{tag_and_ext}" + if not (self.downloaded_path.exists() and self.projector_downloaded_path.exists()): logging.error( f"Model or projector not found in {local_dir}. " "Make sure to name them as 'model-.gguf' and 'projector-.gguf'." @@ -120,9 +121,9 @@ def __init__(self, model_path, local_path, stop_words=None, **kwargs): logging.error("VLM user model from hub is not supported yet.") exit(1) - if self.local_path is None: + if self.downloaded_path is None: logging.error( - f"Model ({model_path}) is not appicable. Please refer to our docs for proper usage.", + f"Model ({model_path}) is not applicable. 
Please refer to our docs for proper usage.", exc_info=True, ) exit(1) @@ -145,20 +146,20 @@ def __init__(self, model_path, local_path, stop_words=None, **kwargs): @SpinningCursorAnimation() def _load_model(self): - logging.debug(f"Loading model from {self.local_path}") + logging.debug(f"Loading model from {self.downloaded_path}") start_time = time.time() with suppress_stdout_stderr(): self.projector = ( self.projector_handler( - clip_model_path=self.projector_local_path, verbose=False + clip_model_path=self.projector_downloaded_path, verbose=False ) - if self.projector_local_path + if self.projector_downloaded_path else None ) try: from nexa.gguf.llama.llama import Llama self.model = Llama( - model_path=self.local_path, + model_path=self.downloaded_path, chat_handler=self.projector, verbose=False, chat_format=self.chat_format, @@ -171,7 +172,7 @@ def _load_model(self): exc_info=True, ) self.model = Llama( - model_path=self.local_path, + model_path=self.downloaded_path, chat_handler=self.projector, verbose=False, chat_format=self.chat_format, diff --git a/nexa/gguf/nexa_inference_voice.py b/nexa/gguf/nexa_inference_voice.py index c20b37cc..0725f704 100644 --- a/nexa/gguf/nexa_inference_voice.py +++ b/nexa/gguf/nexa_inference_voice.py @@ -12,7 +12,7 @@ ) from nexa.utils import nexa_prompt from nexa.utils import nexa_prompt, SpinningCursorAnimation, suppress_stdout_stderr - +from nexa.general import pull_model logging.basicConfig(level=logging.INFO) @@ -21,8 +21,8 @@ class NexaVoiceInference: A class used for loading voice models and running voice transcription. Methods: - run: Run the voice transcription loop. - run_streamlit: Run the Streamlit UI. + run: Run the voice transcription loop. + run_streamlit: Run the Streamlit UI. Args: model_path (str): Path or identifier for the model in Nexa Model Hub. @@ -36,12 +36,15 @@ class NexaVoiceInference: output_dir (str): Output directory for transcriptions. """ - def __init__(self, model_path, local_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = model_path - self.local_path = local_path + self.downloaded_path = local_path self.params = DEFAULT_VOICE_GEN_PARAMS - if self.local_path is None: + if self.downloaded_path is None: + self.downloaded_path, run_type = pull_model(self.model_path) + + if self.downloaded_path is None: logging.error( f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", exc_info=True, @@ -64,10 +67,10 @@ def __init__(self, model_path, local_path, **kwargs): def _load_model(self): from faster_whisper import WhisperModel - logging.debug(f"Loading model from: {self.local_path}") + logging.debug(f"Loading model from: {self.downloaded_path}") with suppress_stdout_stderr(): self.model = WhisperModel( - self.local_path, + self.downloaded_path, device="cpu", compute_type=self.params["compute_type"], ) diff --git a/nexa/onnx/nexa_inference_image.py b/nexa/onnx/nexa_inference_image.py index d2c197af..9e10f7bb 100644 --- a/nexa/onnx/nexa_inference_image.py +++ b/nexa/onnx/nexa_inference_image.py @@ -12,7 +12,7 @@ ORTStableDiffusionPipeline, ORTStableDiffusionXLPipeline, ) - +from nexa.general import pull_model from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX from nexa.utils import nexa_prompt @@ -45,7 +45,7 @@ class NexaImageInference: random_seed (int): Random seed for image generation. streamlit (bool): Run the inference in Streamlit UI. 
""" - def __init__(self, model_path, local_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.download_onnx_folder = local_path self.params = { @@ -63,8 +63,14 @@ def __init__(self, model_path, local_path, **kwargs): def run(self): if self.download_onnx_folder is None: - logging.error("Failed to download the model. Exiting.") - return + self.download_onnx_folder, run_type = pull_model(self.model_path) + + if self.download_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) self._load_model(self.download_onnx_folder) self._dialogue_mode() diff --git a/nexa/onnx/nexa_inference_text.py b/nexa/onnx/nexa_inference_text.py index bb17c6e8..794efca9 100644 --- a/nexa/onnx/nexa_inference_text.py +++ b/nexa/onnx/nexa_inference_text.py @@ -7,7 +7,7 @@ from optimum.onnxruntime import ORTModelForCausalLM from transformers import AutoTokenizer, TextStreamer - +from nexa.general import pull_model from nexa.constants import NEXA_RUN_MODEL_MAP_ONNX from nexa.utils import nexa_prompt @@ -34,7 +34,7 @@ class NexaTextInference: top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, local_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.params = { "temperature": 0.5, @@ -144,7 +144,16 @@ def run(self): if self.params.get("streamlit"): self.run_streamlit() else: + if self.downloaded_onnx_folder is None: + self.downloaded_onnx_folder, run_type = pull_model(self.model_path) + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + self._load_model_and_tokenizer() if self.model is None or self.tokenizer is None or self.streamer is None: diff --git a/nexa/onnx/nexa_inference_tts.py b/nexa/onnx/nexa_inference_tts.py index 75238c06..36318d6b 100644 --- a/nexa/onnx/nexa_inference_tts.py +++ b/nexa/onnx/nexa_inference_tts.py @@ -12,7 +12,7 @@ from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX from nexa.utils import nexa_prompt - +from nexa.general import pull_model logging.basicConfig(level=logging.INFO) @@ -33,7 +33,7 @@ class NexaTTSInference: streamlit (bool): Run the inference in Streamlit UI. """ - def __init__(self, model_path, local_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.yaml_file_name = None self.params = { @@ -45,6 +45,17 @@ def __init__(self, model_path, local_path, **kwargs): self.processor = None self.config = None self.downloaded_onnx_folder = local_path + + if self.downloaded_onnx_folder is None: + self.downloaded_onnx_folder, run_type = pull_model(self.model_path) + + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. 
Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + self.yaml_file_name = os.path.join(self.downloaded_onnx_folder, "config.yaml") with open(self.yaml_file_name, "r", encoding="utf-8") as f: self.config = yaml.safe_load(f) diff --git a/nexa/onnx/nexa_inference_voice.py b/nexa/onnx/nexa_inference_voice.py index fa3d4a8f..40c9094a 100644 --- a/nexa/onnx/nexa_inference_voice.py +++ b/nexa/onnx/nexa_inference_voice.py @@ -11,6 +11,7 @@ from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX from nexa.utils import nexa_prompt +from nexa.general import pull_model logging.basicConfig(level=logging.INFO) @@ -31,7 +32,7 @@ class NexaVoiceInference: streamlit (bool): Run the inference in Streamlit UI. """ - def __init__(self, model_path, local_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.downloaded_onnx_folder = local_path self.params = {"output_dir": "transcriptions", "sampling_rate": 16000} @@ -40,6 +41,16 @@ def __init__(self, model_path, local_path, **kwargs): self.processor = None def run(self): + if self.downloaded_onnx_folder is None: + self.downloaded_onnx_folder, run_type = pull_model(self.model_path) + + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + self._load_model(self.downloaded_onnx_folder) self._dialogue_mode() diff --git a/tests/test_image_generation.py b/tests/test_image_generation.py index 7e749dc6..292089f3 100644 --- a/tests/test_image_generation.py +++ b/tests/test_image_generation.py @@ -4,6 +4,7 @@ sd = NexaImageInference( model_path="sd1-4", + local_path=None, wtype="q4_0", ) @@ -31,6 +32,6 @@ def test_img_to_img(): ) # Main execution -# if __name__ == "__main__": - # test_txt_to_img() - # test_img_to_img() +if __name__ == "__main__": + test_txt_to_img() + test_img_to_img() diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index e3ceed30..e060adc2 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -3,6 +3,7 @@ model = NexaTextInference( model_path="gemma", + local_path=None, verbose=False, n_gpu_layers=-1 if is_gpu_available() else 0, chat_format="llama-2", @@ -63,12 +64,12 @@ def test_create_embedding(): print("Embeddings:\n", embeddings) # Main execution -# if __name__ == "__main__": -# print("=== Testing 1 ===") -# test1() -# print("=== Testing 2 ===") -# test2() -# print("=== Testing 3 ===") -# test3() -# print("=== Testing 4 ===") -# test4() \ No newline at end of file +if __name__ == "__main__": + print("=== Testing 1 ===") + test_text_generation() + print("=== Testing 2 ===") + test_streaming() + print("=== Testing 3 ===") + test_create_chat_completion() + print("=== Testing 4 ===") + test_create_embedding() \ No newline at end of file From ac59357d315102e29c08f0589b18491b78cbeb37 Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 22 Aug 2024 20:37:14 +0000 Subject: [PATCH 6/7] add ci test --- nexa/constants.py | 13 +++++-- tests/test_vlm_generation.py | 63 ++++++++++++++++++++++++++++++++++ tests/test_voice_generation.py | 57 ++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 tests/test_vlm_generation.py create mode 100644 tests/test_voice_generation.py diff --git a/nexa/constants.py b/nexa/constants.py index 3835f134..17812f88 100644 --- a/nexa/constants.py +++ b/nexa/constants.py @@ 
-61,7 +61,12 @@ "phi3v": "Phi-3-vision-128k-instruct:onnx-cpu-int4", "sd1-5": "stable-diffusion-v1-5:onnx-cpu-fp32", "lcm-dreamshaper": "lcm-dreamshaper-v7:onnx-cpu-fp32", - "whisper": "whisper-tiny:onnx-cpu-fp32", + "whisper-tiny.en": "whisper-tiny.en:onnx-cpu-fp32", + "whisper-tiny": "whisper-tiny:onnx-cpu-fp32", + "whisper-small.en": "whisper-small.en:onnx-cpu-fp32", + "whisper-small": "whisper-small:onnx-cpu-fp32", + "whisper-base.en": "whisper-base.en:onnx-cpu-fp32", + "whisper-base": "whisper-base:onnx-cpu-fp32", "ljspeech": "ljspeech-jets:onnx-cpu-fp32", } @@ -86,10 +91,14 @@ "whisper-large": "whisper-large:bin-large-v3", "whisper-tiny": "whisper-tiny:bin-tiny", "faster-whisper-tiny": "faster-whisper-tiny:bin-cpu-fp16", + "faster-whisper-tiny.en": "faster-whisper-tiny.en:bin-cpu-fp16", "faster-whisper-small": "faster-whisper-small:bin-cpu-fp16", + "faster-whisper-small.en": "faster-whisper-small.en:bin-cpu-fp16", "faster-whisper-medium": "faster-whisper-medium:bin-cpu-fp16", + "faster-whisper-medium.en": "faster-whisper-medium.en:bin-cpu-fp16", "faster-whisper-base": "faster-whisper-base:bin-cpu-fp16", - "faster-whisper-large": "faster-whisper-large:bin-cpu-fp16", + "faster-whisper-base.en": "faster-whisper-base.en:bin-cpu-fp16", + "faster-whisper-large": "faster-whisper-large-v3:bin-cpu-fp16", } NEXA_RUN_MODEL_MAP_FUNCTION_CALLING = { diff --git a/tests/test_vlm_generation.py b/tests/test_vlm_generation.py new file mode 100644 index 00000000..1e00fc6e --- /dev/null +++ b/tests/test_vlm_generation.py @@ -0,0 +1,63 @@ +from nexa.gguf import NexaVLMInference +from tempfile import TemporaryDirectory +from .utils import download_model +import os + +vlm = NexaVLMInference( + model_path="llava-phi-3-mini:q4_0", + local_path=None, +) + +# Test create_chat_completion +def test_create_chat_completion(): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ] + completion = vlm.create_chat_completion( + messages=messages, + max_tokens=50, + temperature=0.7, + top_p=0.95, + top_k=40, + stream=False + ) + + assert isinstance(completion, dict) + assert "choices" in completion + assert len(completion["choices"]) > 0 + assert "message" in completion["choices"][0] + assert "content" in completion["choices"][0]["message"] + print("create_chat_completion test passed") + +# Test _chat method +def test_chat(): + with TemporaryDirectory() as temp_dir: + # Download a sample image + img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + img_path = download_model(img_url, temp_dir) + + # Test _chat with image + chat_output = vlm._chat("Describe this image", image_path=img_path) + + # Check if the output is an iterator + assert hasattr(chat_output, '__iter__') + + # Collect the output + output_text = "" + for chunk in chat_output: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + assert "delta" in chunk["choices"][0] + delta = chunk["choices"][0]["delta"] + if "content" in delta: + output_text += delta["content"] + + assert len(output_text) > 0 + print("_chat test with image passed") + +if __name__ == "__main__": + print("=== Testing 1 ===") + test_create_chat_completion() + print("=== Testing 2 ===") + test_chat() \ No newline at end of file diff --git a/tests/test_voice_generation.py b/tests/test_voice_generation.py new file mode 100644 index 00000000..5bc1ce3b --- /dev/null +++ 
b/tests/test_voice_generation.py @@ -0,0 +1,57 @@ +import os +import tempfile +from pathlib import Path +from nexa.gguf import NexaVoiceInference +from unittest.mock import patch, MagicMock + +def test_voice_generation(): + # Initialize the NexaVoiceInference + voice_inference = NexaVoiceInference( + model_path="faster-whisper-tiny", + local_path=None, + output_dir=tempfile.mkdtemp(), # Use a temporary directory for output + beam_size=5, + language="en", + task="transcribe", + temperature=0.0, + compute_type="float32" + ) + + # Create a mock audio file + mock_audio_path = Path(tempfile.mkdtemp()) / "test_audio.wav" + mock_audio_path.touch() # Create an empty file + + # Mock the WhisperModel's transcribe method + mock_segments = [ + MagicMock(text="This is a test transcription."), + MagicMock(text=" It works perfectly.") + ] + mock_transcribe = MagicMock(return_value=(mock_segments, None)) + + # Test _transcribe_audio method + with patch.object(voice_inference.model, 'transcribe', mock_transcribe): + voice_inference._transcribe_audio(str(mock_audio_path)) + + # Assertions + mock_transcribe.assert_called_once_with( + str(mock_audio_path), + beam_size=5, + language="en", + task="transcribe", + temperature=0.0, + vad_filter=True + ) + + # Check if the transcription was saved + transcription_files = list(Path(voice_inference.params["output_dir"]).glob("transcription_*.txt")) + assert len(transcription_files) == 1, "Transcription file was not created" + + # Check the content of the transcription file + with open(transcription_files[0], 'r') as f: + content = f.read() + assert content == "This is a test transcription. It works perfectly.", "Transcription content is incorrect" + + print("Voice generation test passed successfully!") + +if __name__ == "__main__": + test_voice_generation() \ No newline at end of file From bfed39cf9bcf6696e3f6e75562b3b28063564f7b Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 22 Aug 2024 23:05:42 +0000 Subject: [PATCH 7/7] update --- nexa/gguf/nexa_inference_text.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index 9f0afd60..f12054f1 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -25,15 +25,17 @@ class NexaTextInference: """ - A class used for load text models and run text generation. + A class used for loading text models and running text generation. Methods: run: Run the text generation loop. run_streamlit: Run the Streamlit UI. - + create_embedding: Embed a string. + create_chat_completion: Generate completion for a chat conversation. + create_completion: Generate completion for a given prompt. Args: model_path (str): Path or identifier for the model in Nexa Model Hub. - local_path (str): Local path of the model. + local_path (str, optional): Local path of the model. embedding (bool): Enable embedding generation. stop_words (list): List of stop words for early stopping. profiling (bool): Enable timing measurements for the generation process.
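
Note (not part of the patches): a minimal usage sketch of the calling convention these commits converge on — pull_model returning a (local_path, run_type) tuple and the inference classes accepting an optional local_path. The "gemma" identifier and the assumption that the Nexa model hub is reachable are illustrative only.

    # Sketch of the post-refactor API; requires the nexa package and hub access.
    from nexa.general import pull_model
    from nexa.gguf.nexa_inference_text import NexaTextInference

    # pull_model now returns a (local_path, run_type) pair instead of a bare path.
    local_path, run_type = pull_model("gemma")

    # Constructors take the hub identifier plus the resolved local path;
    # passing local_path=None makes the class call pull_model() itself.
    inference = NexaTextInference(model_path="gemma", local_path=local_path)
    inference.run()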