diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index 955ef6d0..e01a16b7 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -1,38 +1,6 @@ import argparse -import uvicorn - from nexa import __version__ -def run_onnx_inference(args): - kwargs = {k: v for k, v in vars(args).items() if v is not None} - model_path = kwargs.pop("model_path") - if args.onnx_command == "gen-text": - from nexa.onnx.nexa_inference_text import \ - NexaTextInference as NexaTextOnnxInference - inference = NexaTextOnnxInference(model_path, **kwargs) - elif args.onnx_command == "gen-image": - from nexa.onnx.nexa_inference_image import \ - NexaImageInference as NexaImageOnnxInference - inference = NexaImageOnnxInference(model_path, **kwargs) - elif args.onnx_command == "asr": - from nexa.onnx.nexa_inference_voice import \ - NexaVoiceInference as NexaVoiceOnnxInference - inference = NexaVoiceOnnxInference(model_path, **kwargs) - elif args.onnx_command == "tts": - from nexa.onnx.nexa_inference_tts import \ - NexaTTSInference as NexaTTSOnnxInference - inference = NexaTTSOnnxInference(model_path, **kwargs) - elif args.onnx_command == "server": - from nexa.onnx.server.nexa_service import app as NexaOnnxServer - uvicorn.run(NexaOnnxServer, host=args.host, port=args.port, reload=args.reload) - else: - raise ValueError(f"Unknown ONNX command: {args.onnx_command}") - - if hasattr(args, 'streamlit') and args.streamlit: - inference.run_streamlit(model_path) - else: - inference.run() - def run_ggml_inference(args): kwargs = {k: v for k, v in vars(args).items() if v is not None} model_path = kwargs.pop("model_path") @@ -41,15 +9,18 @@ def run_ggml_inference(args): from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer NexaServer(model_path, **kwargs) return - + + from nexa.general import pull_model + local_path, run_type = pull_model(model_path) + stop_words = kwargs.pop("stop_words", []) - if args.command == "gen-text": + if run_type == "NLP": from nexa.gguf.nexa_inference_text import NexaTextInference - inference = NexaTextInference(model_path, stop_words=stop_words, **kwargs) - elif args.command == "gen-image": + inference = NexaTextInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs) + elif run_type == "Computer Vision": from nexa.gguf.nexa_inference_image import NexaImageInference - inference = NexaImageInference(model_path, **kwargs) + inference = NexaImageInference(model_path=model_path, local_path=local_path, **kwargs) if hasattr(args, 'streamlit') and args.streamlit: inference.run_streamlit(model_path) elif args.img2img: @@ -57,14 +28,41 @@ def run_ggml_inference(args): else: inference.run_txt2img() return - elif args.command == "vlm": + elif run_type == "Multimodal": from nexa.gguf.nexa_inference_vlm import NexaVLMInference - inference = NexaVLMInference(model_path, stop_words=stop_words, **kwargs) - elif args.command == "asr": + inference = NexaVLMInference(model_path=model_path, local_path=local_path, stop_words=stop_words, **kwargs) + elif run_type == "Audio": from nexa.gguf.nexa_inference_voice import NexaVoiceInference - inference = NexaVoiceInference(model_path, **kwargs) + inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs) else: - raise ValueError(f"Unknown command: {args.command}") + raise ValueError(f"Unknown task: {run_type}") + + if hasattr(args, 'streamlit') and args.streamlit: + inference.run_streamlit(model_path) + else: + inference.run() + +def run_onnx_inference(args): + kwargs = {k: v for k, v in 
vars(args).items() if v is not None} + model_path = kwargs.pop("model_path") + + from nexa.general import pull_model + local_path, run_type = pull_model(model_path) + + if run_type == "NLP": + from nexa.onnx.nexa_inference_text import NexaTextInference as NexaTextOnnxInference + inference = NexaTextOnnxInference(model_path=model_path, local_path=local_path, **kwargs) + elif run_type == "Computer Vision": + from nexa.onnx.nexa_inference_image import NexaImageInference as NexaImageOnnxInference + inference = NexaImageOnnxInference(model_path=model_path, local_path=local_path, **kwargs) + elif run_type == "Audio": + from nexa.onnx.nexa_inference_voice import NexaVoiceInference as NexaVoiceOnnxInference + inference = NexaVoiceOnnxInference(model_path=model_path, local_path=local_path, **kwargs) + elif run_type == "TTS": + from nexa.onnx.nexa_inference_tts import NexaTTSInference as NexaTTSOnnxInference + inference = NexaTTSOnnxInference(model_path=model_path, local_path=local_path, **kwargs) + else: + raise ValueError(f"Unknown task: {run_type}") if hasattr(args, 'streamlit') and args.streamlit: inference.run_streamlit(model_path) @@ -86,104 +84,71 @@ def main(): subparsers = parser.add_subparsers(dest="command", help="sub-command help") - # ONNX subparsers - onnx_parser = subparsers.add_parser("onnx", help="Run ONNX models for inference.") - onnx_subparsers = onnx_parser.add_subparsers(dest="onnx_command", help="ONNX sub-command help") - - # ONNX Text Generation - onnx_text_parser = onnx_subparsers.add_parser("gen-text", help="Run ONNX model for text generation.") - onnx_text_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - onnx_text_parser.add_argument("-t", "--temperature", type=float, default=0.8, help="Temperature for sampling") - onnx_text_parser.add_argument("-m", "--max_new_tokens", type=int, default=512, help="Maximum number of new tokens to generate") - onnx_text_parser.add_argument("-k", "--top_k", type=int, default=50, help="Top-k sampling parameter") - onnx_text_parser.add_argument("-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter") - onnx_text_parser.add_argument("-sw", "--stop_words", nargs="*", default=[], help="List of stop words for early stopping") - onnx_text_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") - onnx_text_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # ONNX Image Generation - onnx_image_parser = onnx_subparsers.add_parser("gen-image", help="Run ONNX model for image generation.") - onnx_image_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - onnx_image_parser.add_argument("-ns", "--num_inference_steps", type=int, default=20, help="Number of inference steps") - onnx_image_parser.add_argument("-np", "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt") - onnx_image_parser.add_argument("-H", "--height", type=int, default=512, help="Height of the output image") - onnx_image_parser.add_argument("-W", "--width", type=int, default=512, help="Width of the output image") - onnx_image_parser.add_argument("-g", "--guidance_scale", type=float, default=7.5, help="Guidance scale for diffusion") - onnx_image_parser.add_argument("-o", "--output", type=str, default="generated_images/image.png", help="Output path for the generated image") - 
onnx_image_parser.add_argument("-s", "--random_seed", type=int, default=41, help="Random seed for image generation") - onnx_image_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # ONNX ASR parser - onnx_asr_parser = onnx_subparsers.add_parser('asr', help='Run ONNX model for auto-speech-recognition.') - onnx_asr_parser.add_argument("model_path", type=str, help="Path or identifier for the model in S3") - onnx_asr_parser.add_argument("-o", "--output_dir", type=str, default="transcriptions", help="Output directory for transcriptions") - onnx_asr_parser.add_argument("-r", "--sampling_rate", type=int, default=16000, help="Sampling rate for audio processing") - onnx_asr_parser.add_argument("-st", "--streamlit", action='store_true', help="Run the inference in Streamlit UI") - - # ONNX voice-generation parser - onnx_tts_parser = onnx_subparsers.add_parser('tts', help='Run ONNX model for text-to-speech generation.') - onnx_tts_parser.add_argument("model_path", type=str, help="Path or identifier for the model in S3") - onnx_tts_parser.add_argument("-o", "--output_dir", type=str, default="tts", help="Output directory for tts") - onnx_tts_parser.add_argument("-r", "--sampling_rate", type=int, default=16000, help="Sampling rate for audio processing") - onnx_tts_parser.add_argument("-st", "--streamlit", action='store_true', help="Run the inference in Streamlit UI") - - # ONNX server parser - onnx_server_parser = onnx_subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") - onnx_server_parser.add_argument("model_path", type=str, help="Path or identifier for the model in S3") - onnx_server_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to") - onnx_server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") - onnx_server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") - - # GGML Text Generation - gen_text_parser = subparsers.add_parser("gen-text", help="Run a GGUF model locally for text generation.") - gen_text_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - gen_text_parser.add_argument("-t", "--temperature", type=float, default=0.8, help="Temperature for sampling") - gen_text_parser.add_argument("-m", "--max_new_tokens", type=int, default=512, help="Maximum number of new tokens to generate") - gen_text_parser.add_argument("-k", "--top_k", type=int, default=50, help="Top-k sampling parameter") - gen_text_parser.add_argument("-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter") - gen_text_parser.add_argument("-sw", "--stop_words", nargs="*", default=[], help="List of stop words for early stopping") - gen_text_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") - gen_text_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # GGML Image Generation - gen_image_parser = subparsers.add_parser("gen-image", help="Run a GGUF model locally for image generation.") - gen_image_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - gen_image_parser.add_argument("-i2i","--img2img",action="store_true",help="Whether to run image-to-image generation") - gen_image_parser.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") - 
gen_image_parser.add_argument("-np", "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt") - gen_image_parser.add_argument("-H", "--height", type=int, help="Height of the output image") - gen_image_parser.add_argument("-W", "--width", type=int, help="Width of the output image") - gen_image_parser.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") - gen_image_parser.add_argument("-o", "--output", type=str, default="generated_images/image.png", help="Output path for the generated image") - gen_image_parser.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") - gen_image_parser.add_argument("--lora_dir", type=str, help="Path to directory containing LoRA files") - gen_image_parser.add_argument("--wtype", type=str, help="weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0). If not specified, the default is the type of the weight file.") - gen_image_parser.add_argument("--control_net_path", type=str, help="Path to control net model") - gen_image_parser.add_argument("--control_image_path", type=str, help="Path to image condition for Control Net") - gen_image_parser.add_argument("--control_strength", type=str, help="Strength to apply Control Net") - gen_image_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # GGML VLM Inference - vlm_parser = subparsers.add_parser("vlm", help="Run a GGUF model locally for VLM inference.") - vlm_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - vlm_parser.add_argument("-t", "--temperature", type=float, default=0.8, help="Temperature for sampling") - vlm_parser.add_argument("-m", "--max_new_tokens", type=int, default=2048, help="Maximum number of new tokens to generate") - vlm_parser.add_argument("-k", "--top_k", type=int, default=50, help="Top-k sampling parameter") - vlm_parser.add_argument("-p", "--top_p", type=float, default=1.0, help="Top-p sampling parameter") - vlm_parser.add_argument("-sw", "--stop_words", nargs="*", default=[], help="List of stop words for early stopping") - vlm_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") - vlm_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") - - # Voice Inference - asr_parser = subparsers.add_parser("asr", help="Run a GGUF model locally for voice inference.") - asr_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") - asr_parser.add_argument("-o", "--output_dir", type=str, default="transcriptions", help="Output directory for transcriptions") - asr_parser.add_argument("-b", "--beam_size", type=int, default=5, help="Beam size to use for transcription") - asr_parser.add_argument("-l", "--language", type=str, default=None, help="The language spoken in the audio. 
It should be a language code such as 'en' or 'fr'.") - asr_parser.add_argument("--task", type=str, default="transcribe", help="Task to execute (transcribe or translate)") - asr_parser.add_argument("-t", "--temperature", type=float, default=0.0, help="Temperature for sampling") - asr_parser.add_argument("-c", "--compute_type", type=str, default="default", help="Type to use for computation (e.g., float16, int8, int8_float16)") - asr_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + # Run command + run_parser = subparsers.add_parser("run", help="Run inference for various tasks using GGUF models.") + run_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + run_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + run_parser.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") + + # Text generation/vlm arguments + text_group = run_parser.add_argument_group('Text generation/VLM options') + text_group.add_argument("-t", "--temperature", type=float, help="Temperature for sampling") + text_group.add_argument("-m", "--max_new_tokens", type=int, help="Maximum number of new tokens to generate") + text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter") + text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter") + text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping") + + # Image generation arguments + image_group = run_parser.add_argument_group('Image generation options') + image_group.add_argument("-i2i", "--img2img", action="store_true", help="Whether to run image-to-image generation") + image_group.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") + image_group.add_argument("-np", "--num_images_per_prompt", type=int, help="Number of images to generate per prompt") + image_group.add_argument("-H", "--height", type=int, help="Height of the output image") + image_group.add_argument("-W", "--width", type=int, help="Width of the output image") + image_group.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") + image_group.add_argument("-o", "--output", type=str, help="Output path for the generated image") + image_group.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") + image_group.add_argument("--lora_dir", type=str, help="Path to directory containing LoRA files") + image_group.add_argument("--wtype", type=str, help="Weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)") + image_group.add_argument("--control_net_path", type=str, help="Path to control net model") + image_group.add_argument("--control_image_path", type=str, help="Path to image condition for Control Net") + image_group.add_argument("--control_strength", type=str, help="Strength to apply Control Net") + + # ASR arguments + asr_group = run_parser.add_argument_group('Automatic Speech Recognition options') + asr_group.add_argument("-b", "--beam_size", type=int, help="Beam size to use for transcription") + asr_group.add_argument("-l", "--language", type=str, help="Language code for audio (e.g., 'en' or 'fr')") + asr_group.add_argument("--task", type=str, help="Task to execute (transcribe or translate)") + asr_group.add_argument("-c", "--compute_type", type=str, help="Type to use for computation") + + # ONNX command 
+ onnx_parser = subparsers.add_parser("onnx", help="Run inference for various tasks using ONNX models.") + onnx_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + onnx_parser.add_argument("-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI") + + # ONNX Text generation arguments + onnx_text_group = onnx_parser.add_argument_group('Text generation options') + onnx_text_group.add_argument("-t", "--temperature", type=float, help="Temperature for sampling") + onnx_text_group.add_argument("-m", "--max_new_tokens", type=int, help="Maximum number of new tokens to generate") + onnx_text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter") + onnx_text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter") + onnx_text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping") + onnx_text_group.add_argument("-pf", "--profiling", action="store_true", help="Enable profiling logs for the inference process") + + # ONNX Image generation arguments + onnx_image_group = onnx_parser.add_argument_group('Image generation options') + onnx_image_group.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") + onnx_image_group.add_argument("-np", "--num_images_per_prompt", type=int, help="Number of images to generate per prompt") + onnx_image_group.add_argument("-H", "--height", type=int, help="Height of the output image") + onnx_image_group.add_argument("-W", "--width", type=int, help="Width of the output image") + onnx_image_group.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") + onnx_image_group.add_argument("-O", "--output", type=str, help="Output path for the generated image") + onnx_image_group.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") + + # ONNX Voice arguments + onnx_voice_group = onnx_parser.add_argument_group('Voice generation options') + onnx_voice_group.add_argument("-o", "--output_dir", type=str, default="voice_output", help="Output directory for audio processing") + onnx_voice_group.add_argument("-r", "--sampling_rate", type=int, default=16000, help="Sampling rate for audio processing") # GGML server parser server_parser = subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") @@ -192,26 +157,31 @@ def main(): server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") - # GGML general + # Other commands subparsers.add_parser("pull", help="Pull a model from official or hub.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") subparsers.add_parser("remove", help="Remove a model from local machine.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") + subparsers.add_parser("clean", help="Clean up all model files.") subparsers.add_parser("list", help="List all models in the local machine.") subparsers.add_parser("login", help="Login to Nexa API.") subparsers.add_parser("whoami", help="Show current user information.") subparsers.add_parser("logout", help="Logout from Nexa API.") + args = parser.parse_args() - if args.command == "onnx": - run_onnx_inference(args) - elif args.command in ["gen-text", "gen-image", "vlm", "asr", "server"]: + if args.command in 
["run", "server"]: run_ggml_inference(args) + elif args.command == "onnx": + run_onnx_inference(args) elif args.command == "pull": from nexa.general import pull_model pull_model(args.model_path) elif args.command == "remove": from nexa.general import remove_model remove_model(args.model_path) + elif args.command == "clean": + from nexa.general import clean + clean() elif args.command == "list": from nexa.general import list_models list_models() @@ -227,6 +197,5 @@ def main(): else: parser.print_help() - if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/nexa/constants.py b/nexa/constants.py index 693965e1..17812f88 100644 --- a/nexa/constants.py +++ b/nexa/constants.py @@ -61,7 +61,12 @@ "phi3v": "Phi-3-vision-128k-instruct:onnx-cpu-int4", "sd1-5": "stable-diffusion-v1-5:onnx-cpu-fp32", "lcm-dreamshaper": "lcm-dreamshaper-v7:onnx-cpu-fp32", - "whisper": "whisper-tiny:onnx-cpu-fp32", + "whisper-tiny.en": "whisper-tiny.en:onnx-cpu-fp32", + "whisper-tiny": "whisper-tiny:onnx-cpu-fp32", + "whisper-small.en": "whisper-small.en:onnx-cpu-fp32", + "whisper-small": "whisper-small:onnx-cpu-fp32", + "whisper-base.en": "whisper-base.en:onnx-cpu-fp32", + "whisper-base": "whisper-base:onnx-cpu-fp32", "ljspeech": "ljspeech-jets:onnx-cpu-fp32", } @@ -86,10 +91,14 @@ "whisper-large": "whisper-large:bin-large-v3", "whisper-tiny": "whisper-tiny:bin-tiny", "faster-whisper-tiny": "faster-whisper-tiny:bin-cpu-fp16", + "faster-whisper-tiny.en": "faster-whisper-tiny.en:bin-cpu-fp16", "faster-whisper-small": "faster-whisper-small:bin-cpu-fp16", + "faster-whisper-small.en": "faster-whisper-small.en:bin-cpu-fp16", "faster-whisper-medium": "faster-whisper-medium:bin-cpu-fp16", + "faster-whisper-medium.en": "faster-whisper-medium.en:bin-cpu-fp16", "faster-whisper-base": "faster-whisper-base:bin-cpu-fp16", - "faster-whisper-large": "faster-whisper-large:bin-cpu-fp16", + "faster-whisper-base.en": "faster-whisper-base.en:bin-cpu-fp16", + "faster-whisper-large": "faster-whisper-large-v3:bin-cpu-fp16", } NEXA_RUN_MODEL_MAP_FUNCTION_CALLING = { @@ -129,6 +138,8 @@ **NEXA_RUN_MODEL_MAP_TEXT, **NEXA_RUN_MODEL_MAP_VLM, **NEXA_RUN_MODEL_MAP_IMAGE, + **NEXA_RUN_MODEL_MAP_VOICE, + **NEXA_RUN_MODEL_MAP_FUNCTION_CALLING, } NEXA_RUN_CHAT_TEMPLATE_MAP = { @@ -205,4 +216,79 @@ "task": "transcribe", "temperature": 0.0, "compute_type": "default" -} \ No newline at end of file +} + +NEXA_OFFICIAL_MODELS_TYPE = { + 'gemma-2b': 'NLP', + 'Llama-2-7b-chat': 'NLP', + 'Llama-2-7b': 'NLP', + 'Meta-Llama-3-8B-Instruct': 'NLP', + 'Meta-Llama-3.1-8B-Instruct': 'NLP', + 'Mistral-7B-Instruct-v0.3': 'NLP', + 'Mistral-7B-Instruct-v0.2': 'NLP', + 'Phi-3-mini-128k-instruct': 'NLP', + 'Phi-3-mini-4k-instruct': 'NLP', + 'CodeQwen1.5-7B-Instruct': 'NLP', + 'Qwen2-0.5B-Instruct': 'NLP', + 'Qwen2-1.5B-Instruct': 'NLP', + 'Qwen2-7B-Instruct': 'NLP', + 'codegemma-2b': 'NLP', + 'gemma-1.1-2b-instruct': 'NLP', + 'gemma-2b-instruct': 'NLP', + 'gemma-2-9b-instruct': 'NLP', + 'gemma-1.1-7b-instruct': 'NLP', + 'gemma-7b-instruct': 'NLP', + 'gemma-7b': 'NLP', + 'Qwen2-1.5B': 'NLP', + 'codegemma-7b': 'NLP', + 'TinyLlama-1.1B-Chat-v1.0': 'NLP', + 'CodeLlama-7b-Instruct': 'NLP', + 'gpt2': 'NLP', + 'CodeLlama-7b': 'NLP', + 'CodeLlama-7b-Python': 'NLP', + 'Qwen1.5-7B-Instruct': 'NLP', + 'Qwen1.5-7B': 'NLP', + 'Phi-2': 'NLP', + 'deepseek-coder-1.3b-instruct': 'NLP', + 'deepseek-coder-1.3b-base': 'NLP', + 'deepseek-coder-6.7b-instruct': 'NLP', + 'dolphin-2.8-mistral-7b': 'NLP', + 'gemma-2-2b-instruct': 'NLP', + 'Octopus-v2': 'NLP', + 
'Octopus-v4': 'NLP', + 'Octo-planner': 'NLP', + 'deepseek-coder-6.7b-base': 'NLP', + 'Llama2-7b-chat-uncensored': 'NLP', + 'Llama3-8B-Lexi-Uncensored': 'NLP', + 'Llama2-7b-function-calling': 'NLP', + 'lcm-dreamshaper-v7': 'Computer Vision', + 'stable-diffusion-v1-5': 'Computer Vision', + 'stable-diffusion-v1-4': 'Computer Vision', + 'stable-diffusion-v2-1': 'Computer Vision', + 'sdxl-turbo': 'Computer Vision', + 'hassaku-hentai-model-v13-LCM': 'Computer Vision', + 'anything-v30-LCM': 'Computer Vision', + 'Phi-3-vision-128k-instruct': 'Multimodal', + 'nanoLLaVA': 'Multimodal', + 'llava-v1.6-mistral-7b': 'Multimodal', + 'llava-v1.6-vicuna-7b': 'Multimodal', + 'llava-phi-3-mini': 'Multimodal', + 'llava-llama-3-8b-v1.1': 'Multimodal', + 'faster-whisper-tiny.en': 'Audio', + 'faster-whisper-tiny': 'Audio', + 'faster-whisper-small.en': 'Audio', + 'faster-whisper-small': 'Audio', + 'faster-whisper-medium.en': 'Audio', + 'faster-whisper-medium': 'Audio', + 'faster-whisper-base.en': 'Audio', + 'faster-whisper-base': 'Audio', + 'faster-whisper-large-v3': 'Audio', + 'whisper-tiny.en': 'Audio', + 'whisper-tiny': 'Audio', + 'whisper-small.en': 'Audio', + 'whisper-small': 'Audio', + 'whisper-base.en': 'Audio', + 'whisper-base': 'Audio', +} + + diff --git a/nexa/general.py b/nexa/general.py index 0143fa2f..b36f0688 100644 --- a/nexa/general.py +++ b/nexa/general.py @@ -1,7 +1,7 @@ import json import logging from pathlib import Path - +import shutil import requests from nexa.constants import ( @@ -13,6 +13,7 @@ NEXA_OFFICIAL_BUCKET, NEXA_RUN_MODEL_MAP, NEXA_TOKEN_PATH, + NEXA_OFFICIAL_MODELS_TYPE, ) @@ -103,25 +104,25 @@ def pull_model(model_path): try: if is_model_exists(model_path): - location = get_model_location(model_path) + location, run_type = get_model_info(model_path) logging.debug(f"Model {model_path} already exists at {location}") - return location + return location, run_type if "/" in model_path: - success, location, model_type = pull_model_from_hub(model_path) + result = pull_model_from_hub(model_path) else: - success, location, model_type = pull_model_from_official(model_path) + result = pull_model_from_official(model_path) - if success: - add_model_to_list(model_path, location, model_type) - logging.debug(f"Successfully pulled model {model_path} to {location}") - return location + if result["success"]: + add_model_to_list(model_path, result["local_path"], result["model_type"], result["run_type"]) + logging.debug(f"Successfully pulled model {model_path} to {result['local_path']}, run_type: {result['run_type']}") + return result["local_path"], result["run_type"] else: logging.debug(f"Failed to pull model {model_path}") - return None + return None, "UNKNOWN" except Exception as e: logging.error(f"An error occurred while pulling the model: {e}") - return None + return None, "UNKNOWN" def pull_model_from_hub(model_path): @@ -133,13 +134,20 @@ def pull_model_from_hub(model_path): token = file.read().strip() try: - presigned_links = get_model_presigned_link(model_path, token) + result = get_model_presigned_link(model_path, token) + run_type = result['type'] + presigned_links = result['presigned_urls'] except Exception as e: print(f"Failed to get download models: {e}") - return False, None + return { + "success": False, + "local_path": None, + "model_type": None, + "run_type": None + } success = True - model_location = None + local_path = None model_type = "undefined" # Determine model_type @@ -159,19 +167,24 @@ def pull_model_from_hub(model_path): download_path = NEXA_MODELS_HUB_DIR / file_path 
download_file_with_progress(presigned_link, download_path) - if model_location is None: + if local_path is None: if model_type == "onnx" or model_type == "bin": - model_location = str(download_path.parent) + local_path = str(download_path.parent) elif model_type == "gguf": - model_location = str(download_path) + local_path = str(download_path) else: # undefined - model_location = str(download_path.parent) + local_path = str(download_path.parent) except Exception as e: print(f"Failed to download {file_path}: {e}") success = False - return success, model_location, model_type + return { + "success": success, + "local_path": local_path, + "model_type": model_type, + "run_type": run_type + } def pull_model_from_official(model_path): @@ -183,8 +196,21 @@ def pull_model_from_official(model_path): model_type = "bin" else: model_type = "gguf" + + run_type = get_run_type_from_model_path(model_path) success, location = download_model_from_official(model_path, model_type) - return success, location, model_type + + return { + "success": success, + "local_path": location, + "model_type": model_type, + "run_type": run_type + } + + +def get_run_type_from_model_path(model_path): + model_name, model_version = model_path.split(":") + return NEXA_OFFICIAL_MODELS_TYPE.get(model_name, "UNKNOWN") def get_model_presigned_link(full_path, token): @@ -192,11 +218,11 @@ def get_model_presigned_link(full_path, token): Get the presigned links for downloading the contents of a model folder. Args: - full_path (str): The full path of the folder to download (e.g., "gpt2-onnx/"). - token (str): The authentication token. Can be empty. + full_path (str): The full path of the folder to download (e.g., "openai/gpt2:gguf-q2_K"). + token (str, optional): The authentication token. Defaults to None. Returns: - dict: A dictionary containing the contents of the folder, where keys are file paths and values are presigned links. + dict: A dictionary containing the model type and presigned URLs. 
""" url = f"{NEXA_API_URL}/model/download-tag-folder" @@ -205,13 +231,20 @@ def get_model_presigned_link(full_path, token): if token: headers["Authorization"] = f"Bearer {token}" - body = {"full_path": full_path} + body = {"full_path": full_path, "need_type": True} try: response = requests.post(url, headers=headers, json=body) response.raise_for_status() - folder_contents = response.json() - return folder_contents + result = response.json() + + run_type = result.get("type", [])[0] if result.get("type") else None + presigned_urls = result.get("presigned_urls", {}) + + return { + "run_type": run_type, + "presigned_urls": presigned_urls + } except requests.exceptions.RequestException as e: print(f"API request failed: {e}") @@ -286,7 +319,7 @@ def is_model_exists(model_name): return model_name in model_list -def add_model_to_list(model_name, model_location, model_type): +def add_model_to_list(model_name, model_location, model_type, run_type): NEXA_MODEL_LIST_PATH.parent.mkdir(parents=True, exist_ok=True) if NEXA_MODEL_LIST_PATH.exists(): @@ -295,20 +328,28 @@ def add_model_to_list(model_name, model_location, model_type): else: model_list = {} - model_list[model_name] = {"type": model_type, "location": model_location} + model_list[model_name] = { + "type": model_type, + "location": model_location, + "run_type": run_type + } with open(NEXA_MODEL_LIST_PATH, "w") as f: json.dump(model_list, f, indent=2) -def get_model_location(model_name): +def get_model_info(model_name): if not NEXA_MODEL_LIST_PATH.exists(): - return None + return None, None with open(NEXA_MODEL_LIST_PATH, "r") as f: model_list = json.load(f) - return model_list.get(model_name).get("location") + model_data = model_list.get(model_name, {}) + location = model_data.get("location") + run_type = model_data.get("run_type") + + return location, run_type def list_models(): @@ -320,15 +361,15 @@ def list_models(): model_list = json.load(f) table = [ - (model_name, model_info["type"], model_info["location"]) + (model_name, model_info["type"], model_info["run_type"], model_info["location"]) for model_name, model_info in model_list.items() ] - headers = ["Model Name", "Type", "Location"] + headers = ["Model Name", "Type", "Run Type", "Location"] from tabulate import tabulate print( tabulate( - table, headers, tablefmt="pretty", colalign=("left", "left", "left") + table, headers, tablefmt="pretty", colalign=("left", "left", "left", "left") ) ) except Exception as e: @@ -359,8 +400,6 @@ def remove_model(model_path): model_path.unlink() print(f"Deleted model file: {model_path}") elif model_path.is_dir(): - import shutil - shutil.rmtree(model_path) print(f"Deleted model directory: {model_path}") else: @@ -375,6 +414,32 @@ def remove_model(model_path): except Exception as e: print(f"An error occurred while removing the model: {e}") return None + +def clean(): + if not NEXA_MODELS_HUB_DIR.exists(): + print(f"Nothing to clean.") + return + + # Ask for user confirmation + confirmation = input(f"This will remove all downloaded models and the model list. Are you sure? 
(y/N): ").lower().strip() + + if confirmation != 'y': + print("Operation cancelled.") + return + + try: + # Remove all contents of the directory + for item in NEXA_MODELS_HUB_DIR.iterdir(): + if item.is_file(): + item.unlink() + elif item.is_dir(): + shutil.rmtree(item) + + print(f"Successfully removed all contents from {NEXA_MODELS_HUB_DIR}") + + except Exception as e: + print(f"An error occurred while cleaning the directory: {e}") + if __name__ == "__main__": diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 4295797e..b0086d61 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -5,8 +5,6 @@ import sys import time from pathlib import Path - -from nexa.general import pull_model from nexa.constants import ( DEFAULT_IMG_GEN_PARAMS, EXIT_REMINDER, @@ -19,6 +17,7 @@ from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr from streamlit.web import cli as stcli +from nexa.general import pull_model logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" @@ -30,36 +29,30 @@ class NexaImageInference: A class used for loading image models and running image generation. Methods: - txt2img: (Used for SDK) Run the text-to-image generation loop. - img2img: (Used for SDK) Run the image-to-image generation loop. - run_streamlit: Run the Streamlit UI. + txt2img: (Used for SDK) Run the text-to-image generation loop. + img2img: (Used for SDK) Run the image-to-image generation loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - num_inference_steps (int): Number of inference steps. - width (int): Width of the output image. - height (int): Height of the output image. - guidance_scale (float): Guidance scale for diffusion. - output_path (str): Output path for the generated image. - random_seed (int): Random seed for image generation. - streamlit (bool): Run the inference in Streamlit UI. + model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. + num_inference_steps (int): Number of inference steps. + width (int): Width of the output image. + height (int): Height of the output image. + guidance_scale (float): Guidance scale for diffusion. + output_path (str): Output path for the generated image. + random_seed (int): Random seed for image generation. + streamlit (bool): Run the inference in Streamlit UI. 
""" - def __init__(self, model_path, **kwargs): - self.model_path = None - self.downloaded_path = None - if model_path in NEXA_RUN_MODEL_MAP: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP.get(model_path) - self.downloaded_path = pull_model(self.model_path) - elif os.path.exists(model_path): - logging.debug(f"Using local model at {model_path}") - self.downloaded_path = model_path - else: - logging.debug(f"Trying to use model from hub at {model_path}") - self.downloaded_path = pull_model(model_path) + def __init__(self, model_path, local_path=None, **kwargs): + self.model_path = model_path + self.downloaded_path = local_path + + if self.downloaded_path is None: + self.downloaded_path, run_type = pull_model(self.model_path) if self.downloaded_path is None: logging.error( @@ -68,15 +61,14 @@ def __init__(self, model_path, **kwargs): ) exit(1) - if self.model_path == "lcm-dreamshaper-v7:fp16": + if "lcm-dreamshaper" in self.model_path: self.params = DEFAULT_IMG_GEN_PARAMS_LCM - elif self.model_path == "sdxl-turbo:q8_0": + elif "sdxl-turbo" in self.model_path: self.params = DEFAULT_IMG_GEN_PARAMS_TURBO else: self.params = DEFAULT_IMG_GEN_PARAMS self.params.update(kwargs) - if not kwargs.get("streamlit", False): self._load_model(model_path) if self.model is None: diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index 043d58e2..f12054f1 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -12,8 +12,8 @@ NEXA_RUN_MODEL_MAP, NEXA_STOP_WORDS_MAP, ) -from nexa.general import pull_model from nexa.gguf.lib_utils import is_gpu_available +from nexa.general import pull_model from nexa.utils import SpinningCursorAnimation, nexa_prompt from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr @@ -25,41 +25,36 @@ class NexaTextInference: """ - A class used for load text models and run text generation. + A class used for loading text models and running text generation. Methods: run: Run the text generation loop. run_streamlit: Run the Streamlit UI. - + create_embedding: Embed a string. + create_chat_completion: Generate completion for a chat conversation. + create_completion: Generate completion for a given prompt. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - embedding (bool): Enable embedding generation. - stop_words (list): List of stop words for early stopping. - profiling (bool): Enable timing measurements for the generation process. - streamlit (bool): Run the inference in Streamlit UI. - temperature (float): Temperature for sampling. - max_new_tokens (int): Maximum number of new tokens to generate. - top_k (int): Top-k sampling parameter. - top_p (float): Top-p sampling parameter + model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str, optional): Local path of the model. + embedding (bool): Enable embedding generation. + stop_words (list): List of stop words for early stopping. + profiling (bool): Enable timing measurements for the generation process. + streamlit (bool): Run the inference in Streamlit UI. + temperature (float): Temperature for sampling. + max_new_tokens (int): Maximum number of new tokens to generate. + top_k (int): Top-k sampling parameter. 
+ top_p (float): Top-p sampling parameter """ - - def __init__(self, model_path, stop_words=None, **kwargs): + def __init__(self, model_path, local_path=None, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) self.model = None - - self.model_path = None - self.downloaded_path = None - if model_path in NEXA_RUN_MODEL_MAP: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP.get(model_path) - self.downloaded_path = pull_model(self.model_path) - elif os.path.exists(model_path): - logging.debug(f"Using local model at {model_path}") - self.downloaded_path = model_path - else: - logging.debug(f"Trying to use model from hub at {model_path}") - self.downloaded_path = pull_model(model_path) + + self.model_path = model_path + self.downloaded_path = local_path + + if self.downloaded_path is None: + self.downloaded_path, run_type = pull_model(self.model_path) if self.downloaded_path is None: logging.error( @@ -71,7 +66,6 @@ def __init__(self, model_path, stop_words=None, **kwargs): self.stop_words = ( stop_words if stop_words else NEXA_STOP_WORDS_MAP.get(model_path, []) ) - self.profiling = kwargs.get("profiling", False) self.chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) diff --git a/nexa/gguf/nexa_inference_vlm.py b/nexa/gguf/nexa_inference_vlm.py index 6b611808..da8103f5 100644 --- a/nexa/gguf/nexa_inference_vlm.py +++ b/nexa/gguf/nexa_inference_vlm.py @@ -79,6 +79,7 @@ class NexaVLMInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. stop_words (list): List of stop words for early stopping. profiling (bool): Enable timing measurements for the generation process. streamlit (bool): Run the inference in Streamlit UI. @@ -87,34 +88,31 @@ class NexaVLMInference: top_k (int): Top-k sampling parameter. 
top_p (float): Top-p sampling parameter """ - - - def __init__(self, model_path, stop_words=None, **kwargs): + def __init__(self, model_path, local_path=None, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) self.model = None self.projector = None - self.projector_path = NEXA_RUN_PROJECTOR_MAP.get(model_path, None) - self.downloaded_path = None + self.downloaded_path = local_path self.projector_downloaded_path = None - if model_path in NEXA_RUN_MODEL_MAP_VLM: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP_VLM.get(model_path) - self.projector_path = NEXA_RUN_PROJECTOR_MAP.get(model_path) - self.downloaded_path = pull_model(self.model_path) - self.projector_downloaded_path = pull_model(self.projector_path) - elif (local_dir := Path(model_path).parent).exists(): - logging.debug(f"Using local model at {local_dir}") + if self.downloaded_path is not None: + if model_path in NEXA_RUN_MODEL_MAP_VLM: + self.projector_path = NEXA_RUN_PROJECTOR_MAP[model_path] + self.projector_downloaded_path, _ = pull_model(self.projector_path) + elif model_path in NEXA_RUN_MODEL_MAP_VLM: + self.model_path = NEXA_RUN_MODEL_MAP_VLM[model_path] + self.projector_path = NEXA_RUN_PROJECTOR_MAP[model_path] + self.downloaded_path, _ = pull_model(self.model_path) + self.projector_downloaded_path, _ = pull_model(self.projector_path) + elif Path(model_path).parent.exists(): + local_dir = Path(model_path).parent model_name = Path(model_path).name tag_and_ext = model_name.split(":")[-1] self.downloaded_path = local_dir / f"model-{tag_and_ext}" self.projector_downloaded_path = local_dir / f"projector-{tag_and_ext}" - if not ( - self.downloaded_path.exists() - and self.projector_downloaded_path.exists() - ): + if not (self.downloaded_path.exists() and self.projector_downloaded_path.exists()): logging.error( f"Model or projector not found in {local_dir}. " "Make sure to name them as 'model-.gguf' and 'projector-.gguf'." @@ -126,7 +124,7 @@ def __init__(self, model_path, stop_words=None, **kwargs): if self.downloaded_path is None: logging.error( - f"Model ({model_path}) is not appicable. Please refer to our docs for proper usage.", + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", exc_info=True, ) exit(1) diff --git a/nexa/gguf/nexa_inference_voice.py b/nexa/gguf/nexa_inference_voice.py index a469f151..ef0a653b 100644 --- a/nexa/gguf/nexa_inference_voice.py +++ b/nexa/gguf/nexa_inference_voice.py @@ -23,34 +23,28 @@ class NexaVoiceInference: A class used for loading voice models and running voice transcription. Methods: - run: Run the voice transcription loop. - run_streamlit: Run the Streamlit UI. + run: Run the voice transcription loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - output_dir (str): Output directory for transcriptions. - beam_size (int): Beam size to use for transcription. - language (str): The language spoken in the audio. - task (str): Task to execute (transcribe or translate). - temperature (float): Temperature for sampling. - compute_type (str): Type to use for computation (e.g., float16, int8, int8_float16). - output_dir (str): Output directory for transcriptions. + model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. + output_dir (str): Output directory for transcriptions. + beam_size (int): Beam size to use for transcription. 
+ language (str): The language spoken in the audio. + task (str): Task to execute (transcribe or translate). + temperature (float): Temperature for sampling. + compute_type (str): Type to use for computation (e.g., float16, int8, int8_float16). + output_dir (str): Output directory for transcriptions. """ - def __init__(self, model_path, **kwargs): - self.model_path = None - self.downloaded_path = None + def __init__(self, model_path, local_path=None, **kwargs): + self.model_path = model_path + self.downloaded_path = local_path self.params = DEFAULT_VOICE_GEN_PARAMS - if model_path in NEXA_RUN_MODEL_MAP_VOICE: - logging.debug(f"Found model {model_path} in public hub") - self.model_path = NEXA_RUN_MODEL_MAP_VOICE.get(model_path) - self.downloaded_path = pull_model(self.model_path) - elif os.path.exists(model_path): - logging.debug(f"Using local model at {model_path}") - self.downloaded_path = model_path - else: - logging.error("Using voice model from hub is not supported yet.") - exit(1) + + if self.downloaded_path is None: + self.downloaded_path, run_type = pull_model(self.model_path) if self.downloaded_path is None: logging.error( @@ -189,7 +183,7 @@ def _transcribe_audio(self, audio_path): ) transcription = "".join(segment.text for segment in segments) self._save_transcription(transcription) - logging.info(f"Transcription: {transcription}") + print(f"Transcription: {transcription}") except Exception as e: logging.error(f"Error during transcription: {e}", exc_info=True) diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index 56c19636..bc7424c5 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -66,7 +66,7 @@ async def load_model(): chat_format = NEXA_RUN_CHAT_TEMPLATE_MAP.get(model_path, None) completion_template = NEXA_RUN_COMPLETION_TEMPLATE_MAP.get(model_path, None) model_path = NEXA_RUN_MODEL_MAP_TEXT.get(model_path) - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): try: model = Llama( @@ -89,7 +89,7 @@ async def load_model(): elif model_path in NEXA_RUN_MODEL_MAP_FUNCTION_CALLING: chat_format = "chatml-function-calling" model_path = NEXA_RUN_MODEL_MAP_FUNCTION_CALLING.get(model_path) - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): try: model = Llama( @@ -111,7 +111,7 @@ async def load_model(): logging.info(f"model loaded as {model}") elif model_path in NEXA_RUN_MODEL_MAP_IMAGE: - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): model = StableDiffusion( model_path=downloaded_path, @@ -123,7 +123,7 @@ async def load_model(): logging.info(f"model loaded as {model}") elif model_path in NEXA_RUN_MODEL_MAP_VOICE: model_path = NEXA_RUN_MODEL_MAP_VOICE.get(model_path) - downloaded_path = pull_model(model_path) + downloaded_path, run_type = pull_model(model_path) with suppress_stdout_stderr(): model = WhisperModel( downloaded_path, diff --git a/nexa/gguf/streamlit/streamlit_image_chat.py b/nexa/gguf/streamlit/streamlit_image_chat.py index 3117fea1..b9d8fb97 100644 --- a/nexa/gguf/streamlit/streamlit_image_chat.py +++ b/nexa/gguf/streamlit/streamlit_image_chat.py @@ -2,7 +2,7 @@ import sys from PIL import Image - +from nexa.general import pull_model import streamlit as st from nexa.gguf.nexa_inference_image import NexaImageInference @@ -11,7 +11,8 @@ @st.cache_resource def load_model(model_path): - 
nexa_model = NexaImageInference(model_path)
+ local_path, run_type = pull_model(model_path)
+ nexa_model = NexaImageInference(model_path=model_path, local_path=local_path)
return nexa_model
diff --git a/nexa/gguf/streamlit/streamlit_text_chat.py b/nexa/gguf/streamlit/streamlit_text_chat.py
index 2a9d7fed..61848a0e 100644
--- a/nexa/gguf/streamlit/streamlit_text_chat.py
+++ b/nexa/gguf/streamlit/streamlit_text_chat.py
@@ -2,7 +2,7 @@
from typing import Iterator
import streamlit as st
-
+from nexa.general import pull_model
from nexa.gguf.nexa_inference_text import NexaTextInference
default_model = sys.argv[1]
@@ -11,7 +11,8 @@
@st.cache_resource
def load_model(model_path):
st.session_state.messages = []
- nexa_model = NexaTextInference(model_path)
+ local_path, run_type = pull_model(model_path)
+ nexa_model = NexaTextInference(model_path=model_path, local_path=local_path)
return nexa_model
diff --git a/nexa/gguf/streamlit/streamlit_vlm.py b/nexa/gguf/streamlit/streamlit_vlm.py
index 7c8fee91..9f334d13 100644
--- a/nexa/gguf/streamlit/streamlit_vlm.py
+++ b/nexa/gguf/streamlit/streamlit_vlm.py
@@ -4,7 +4,7 @@
import streamlit as st
from PIL import Image
-
+from nexa.general import pull_model
from nexa.gguf.nexa_inference_vlm import NexaVLMInference
default_model = sys.argv[1]
@@ -12,7 +12,8 @@
@st.cache_resource
def load_model(model_path):
- nexa_model = NexaVLMInference(model_path)
+ local_path, run_type = pull_model(model_path)
+ nexa_model = NexaVLMInference(model_path=model_path, local_path=local_path)
return nexa_model
diff --git a/nexa/gguf/streamlit/streamlit_voice_chat.py b/nexa/gguf/streamlit/streamlit_voice_chat.py
index 461dcc95..949be707 100644
--- a/nexa/gguf/streamlit/streamlit_voice_chat.py
+++ b/nexa/gguf/streamlit/streamlit_voice_chat.py
@@ -15,7 +15,8 @@
@st.cache_resource
def load_model(model_path):
- nexa_model = NexaVoiceInference(model_path)
+ local_path, run_type = pull_model(model_path)
+ nexa_model = NexaVoiceInference(model_path=model_path, local_path=local_path)
return nexa_model
diff --git a/nexa/onnx/nexa_inference_image.py b/nexa/onnx/nexa_inference_image.py
index 0a4a6db1..6566cbb2 100644
--- a/nexa/onnx/nexa_inference_image.py
+++ b/nexa/onnx/nexa_inference_image.py
@@ -12,9 +12,8 @@
ORTStableDiffusionPipeline,
ORTStableDiffusionXLPipeline,
)
-
-from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX
from nexa.general import pull_model
+from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX
from nexa.utils import nexa_prompt, SpinningCursorAnimation
logging.basicConfig(level=logging.INFO)
@@ -36,6 +35,7 @@ class NexaImageInference:
Args:
model_path (str): Path or identifier for the model in Nexa Model Hub.
+ local_path (str): Local path of the model.
num_inference_steps (int): Number of inference steps.
num_images_per_prompt (int): Number of images to generate per prompt.
width (int): Width of the output image.
@@ -45,8 +45,9 @@ class NexaImageInference:
random_seed (int): Random seed for image generation.
streamlit (bool): Run the inference in Streamlit UI.
""" - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) + self.download_onnx_folder = local_path self.params = { "num_inference_steps": 20, "num_images_per_prompt": 1, @@ -60,17 +61,18 @@ def __init__(self, model_path, **kwargs): self.pipeline = None def run(self): - # Step 1: Download the ONNX folder from S3 - downloaded_onnx_folder = pull_model(self.model_path) - if downloaded_onnx_folder is None: - logging.error("Failed to download the model. Exiting.") - return + if self.download_onnx_folder is None: + self.download_onnx_folder, run_type = pull_model(self.model_path) - # Step 2: Load the model - self._load_model(downloaded_onnx_folder) + if self.download_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) - # Step 3: Enter dialogue mode + self._load_model(self.download_onnx_folder) self._dialogue_mode() @SpinningCursorAnimation() diff --git a/nexa/onnx/nexa_inference_text.py b/nexa/onnx/nexa_inference_text.py index d688aff5..765d26b7 100644 --- a/nexa/onnx/nexa_inference_text.py +++ b/nexa/onnx/nexa_inference_text.py @@ -7,9 +7,8 @@ from optimum.onnxruntime import ORTModelForCausalLM from transformers import AutoTokenizer, TextStreamer - -from nexa.constants import NEXA_RUN_MODEL_MAP_ONNX from nexa.general import pull_model +from nexa.constants import NEXA_RUN_MODEL_MAP_ONNX from nexa.utils import nexa_prompt, SpinningCursorAnimation logging.basicConfig(level=logging.INFO) @@ -24,17 +23,18 @@ class NexaTextInference: run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - profiling (bool): Enable timing measurements for the generation process. - streamlit (bool): Run the inference in Streamlit UI. - temperature (float): Temperature for sampling. - min_new_tokens (int): Minimum number of new tokens to generate. - max_new_tokens (int): Maximum number of new tokens to generate. - top_k (int): Top-k sampling parameter. - top_p (float): Top-p sampling parameter + model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. + profiling (bool): Enable timing measurements for the generation process. + streamlit (bool): Run the inference in Streamlit UI. + temperature (float): Temperature for sampling. + min_new_tokens (int): Minimum number of new tokens to generate. + max_new_tokens (int): Maximum number of new tokens to generate. + top_k (int): Top-k sampling parameter. + top_p (float): Top-p sampling parameter """ - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.params = { "temperature": 0.5, @@ -47,7 +47,7 @@ def __init__(self, model_path, **kwargs): self.model = None self.tokenizer = None self.streamer = None - self.downloaded_onnx_folder = None + self.downloaded_onnx_folder = local_path self.timings = kwargs.get("timings", False) self.device = "cpu" @@ -145,13 +145,16 @@ def run(self): if self.params.get("streamlit"): self.run_streamlit() else: - self.downloaded_onnx_folder = pull_model( - self.model_path - ) # David TODO : move download logic to __init__(), otherwise for streamlit you will duplicate such logic again if self.downloaded_onnx_folder is None: - logging.error("Failed to download the model. 
Exiting.", exc_info=True) - return + self.downloaded_onnx_folder, run_type = pull_model(self.model_path) + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) + self._load_model_and_tokenizer() if self.model is None or self.tokenizer is None or self.streamer is None: diff --git a/nexa/onnx/nexa_inference_tts.py b/nexa/onnx/nexa_inference_tts.py index ff7093d6..36318d6b 100644 --- a/nexa/onnx/nexa_inference_tts.py +++ b/nexa/onnx/nexa_inference_tts.py @@ -11,9 +11,8 @@ from ttstokenizer import TTSTokenizer from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX -from nexa.general import pull_model from nexa.utils import nexa_prompt - +from nexa.general import pull_model logging.basicConfig(level=logging.INFO) @@ -27,13 +26,14 @@ class NexaTTSInference: run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - output_dir (str): Output directory for tts. - sampling_rate (int): Sampling rate for audio processing. - streamlit (bool): Run the inference in Streamlit UI. + model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. + output_dir (str): Output directory for tts. + sampling_rate (int): Sampling rate for audio processing. + streamlit (bool): Run the inference in Streamlit UI. """ - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) self.yaml_file_name = None self.params = { @@ -44,9 +44,18 @@ def __init__(self, model_path, **kwargs): self.model = None self.processor = None self.config = None - self.downloaded_onnx_folder = None + self.downloaded_onnx_folder = local_path + + if self.downloaded_onnx_folder is None: + self.downloaded_onnx_folder, run_type = pull_model(self.model_path) + + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) - self.downloaded_onnx_folder = pull_model(self.model_path) self.yaml_file_name = os.path.join(self.downloaded_onnx_folder, "config.yaml") with open(self.yaml_file_name, "r", encoding="utf-8") as f: self.config = yaml.safe_load(f) diff --git a/nexa/onnx/nexa_inference_voice.py b/nexa/onnx/nexa_inference_voice.py index 638a723a..40c9094a 100644 --- a/nexa/onnx/nexa_inference_voice.py +++ b/nexa/onnx/nexa_inference_voice.py @@ -10,8 +10,8 @@ from transformers import AutoProcessor from nexa.constants import EXIT_REMINDER, NEXA_RUN_MODEL_MAP_ONNX -from nexa.general import pull_model from nexa.utils import nexa_prompt +from nexa.general import pull_model logging.basicConfig(level=logging.INFO) @@ -26,26 +26,32 @@ class NexaVoiceInference: Args: model_path (str): Path or identifier for the model in Nexa Model Hub. + local_path (str): Local path of the model. output_dir (str): Output directory for transcriptions. sampling_rate (int): Sampling rate for audio processing. streamlit (bool): Run the inference in Streamlit UI. 
""" - def __init__(self, model_path, **kwargs): + def __init__(self, model_path, local_path=None, **kwargs): self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) + self.downloaded_onnx_folder = local_path self.params = {"output_dir": "transcriptions", "sampling_rate": 16000} self.params.update(kwargs) self.model = None self.processor = None def run(self): - downloaded_onnx_folder = pull_model(self.model_path) + if self.downloaded_onnx_folder is None: + self.downloaded_onnx_folder, run_type = pull_model(self.model_path) - if downloaded_onnx_folder is None: - logging.error("Failed to download the model. Exiting.") - return + if self.downloaded_onnx_folder is None: + logging.error( + f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.", + exc_info=True, + ) + exit(1) - self._load_model(downloaded_onnx_folder) + self._load_model(self.downloaded_onnx_folder) self._dialogue_mode() def _load_model(self, model_path): @@ -99,7 +105,7 @@ def _transcribe_audio(self, audio_path): )[0] self._save_transcription(transcription) - logging.info(f"Transcription: {transcription}") + print(f"Transcription: {transcription}") except Exception as e: logging.error(f"Error during audio transcription: {e}") diff --git a/nexa/onnx/server/nexa_service.py b/nexa/onnx/server/nexa_service.py index 648d0368..387589d5 100644 --- a/nexa/onnx/server/nexa_service.py +++ b/nexa/onnx/server/nexa_service.py @@ -60,7 +60,7 @@ async def nexa_run_text_generation_preperation(model_path): global model, tokenizer, streamer, chat_template, is_chat_mode # Step 1: Check if the model_path is a key in NEXA_RUN_MODEL_MAP_ONNX, if so, get the full path full_model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path) - downloaded_onnx_folder = pull_model(full_model_path) + downloaded_onnx_folder, run_type = pull_model(full_model_path) logging.info(f"Downloaded ONNX folder: {downloaded_onnx_folder}") # Step 2: Load the model and tokenizer diff --git a/nexa/onnx/streamlit/streamlit_image_chat.py b/nexa/onnx/streamlit/streamlit_image_chat.py index 4e632c69..724c8647 100644 --- a/nexa/onnx/streamlit/streamlit_image_chat.py +++ b/nexa/onnx/streamlit/streamlit_image_chat.py @@ -12,8 +12,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaImageInference(model_path) - nexa_model.downloaded_onnx_folder = pull_model(nexa_model.model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaImageInference(model_path=model_path, local_path=local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. Please check the model path.") diff --git a/nexa/onnx/streamlit/streamlit_text_chat.py b/nexa/onnx/streamlit/streamlit_text_chat.py index 35d78913..872f2774 100644 --- a/nexa/onnx/streamlit/streamlit_text_chat.py +++ b/nexa/onnx/streamlit/streamlit_text_chat.py @@ -12,8 +12,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaTextInference(model_path) - nexa_model.downloaded_onnx_folder = pull_model(nexa_model.model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaTextInference(model_path=model_path, local_path=local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. 
Please check the model path.") diff --git a/nexa/onnx/streamlit/streamlit_tts.py b/nexa/onnx/streamlit/streamlit_tts.py index 5f30d5a7..1e60cc7a 100644 --- a/nexa/onnx/streamlit/streamlit_tts.py +++ b/nexa/onnx/streamlit/streamlit_tts.py @@ -4,7 +4,7 @@ import soundfile as sf import streamlit as st - +from nexa.general import pull_model from nexa.onnx.nexa_inference_tts import NexaTTSInference default_model = sys.argv[1] @@ -12,7 +12,8 @@ @st.cache_resource def load_model(model_path: str): - return NexaTTSInference(model_path) + local_path, run_type = pull_model(model_path) + return NexaTTSInference(model_path=model_path, local_path=local_path) def generate_audio(nexa_model: NexaTTSInference, text): diff --git a/nexa/onnx/streamlit/streamlit_voice_chat.py b/nexa/onnx/streamlit/streamlit_voice_chat.py index 991ec61b..b3316edb 100644 --- a/nexa/onnx/streamlit/streamlit_voice_chat.py +++ b/nexa/onnx/streamlit/streamlit_voice_chat.py @@ -15,8 +15,8 @@ @st.cache_resource def load_model(model_path): - nexa_model = NexaVoiceInference(model_path) - nexa_model.downloaded_onnx_folder = pull_model(nexa_model.model_path) + local_path, run_type = pull_model(model_path) + nexa_model = NexaVoiceInference(model_path=model_path, local_path=local_path) if nexa_model.downloaded_onnx_folder is None: st.error("Failed to download the model. Please check the model path.") diff --git a/tests/test_image_generation.py b/tests/test_image_generation.py index 7e749dc6..292089f3 100644 --- a/tests/test_image_generation.py +++ b/tests/test_image_generation.py @@ -4,6 +4,7 @@ sd = NexaImageInference( model_path="sd1-4", + local_path=None, wtype="q4_0", ) @@ -31,6 +32,6 @@ def test_img_to_img(): ) # Main execution -# if __name__ == "__main__": - # test_txt_to_img() - # test_img_to_img() +if __name__ == "__main__": + test_txt_to_img() + test_img_to_img() diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index e3ceed30..e060adc2 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -3,6 +3,7 @@ model = NexaTextInference( model_path="gemma", + local_path=None, verbose=False, n_gpu_layers=-1 if is_gpu_available() else 0, chat_format="llama-2", @@ -63,12 +64,12 @@ def test_create_embedding(): print("Embeddings:\n", embeddings) # Main execution -# if __name__ == "__main__": -# print("=== Testing 1 ===") -# test1() -# print("=== Testing 2 ===") -# test2() -# print("=== Testing 3 ===") -# test3() -# print("=== Testing 4 ===") -# test4() \ No newline at end of file +if __name__ == "__main__": + print("=== Testing 1 ===") + test_text_generation() + print("=== Testing 2 ===") + test_streaming() + print("=== Testing 3 ===") + test_create_chat_completion() + print("=== Testing 4 ===") + test_create_embedding() \ No newline at end of file diff --git a/tests/test_vlm_generation.py b/tests/test_vlm_generation.py new file mode 100644 index 00000000..1e00fc6e --- /dev/null +++ b/tests/test_vlm_generation.py @@ -0,0 +1,63 @@ +from nexa.gguf import NexaVLMInference +from tempfile import TemporaryDirectory +from .utils import download_model +import os + +vlm = NexaVLMInference( + model_path="llava-phi-3-mini:q4_0", + local_path=None, +) + +# Test create_chat_completion +def test_create_chat_completion(): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ] + completion = vlm.create_chat_completion( + messages=messages, + max_tokens=50, + temperature=0.7, + top_p=0.95, + top_k=40, + stream=False + 
) + + assert isinstance(completion, dict) + assert "choices" in completion + assert len(completion["choices"]) > 0 + assert "message" in completion["choices"][0] + assert "content" in completion["choices"][0]["message"] + print("create_chat_completion test passed") + +# Test _chat method +def test_chat(): + with TemporaryDirectory() as temp_dir: + # Download a sample image + img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + img_path = download_model(img_url, temp_dir) + + # Test _chat with image + chat_output = vlm._chat("Describe this image", image_path=img_path) + + # Check if the output is an iterator + assert hasattr(chat_output, '__iter__') + + # Collect the output + output_text = "" + for chunk in chat_output: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + assert "delta" in chunk["choices"][0] + delta = chunk["choices"][0]["delta"] + if "content" in delta: + output_text += delta["content"] + + assert len(output_text) > 0 + print("_chat test with image passed") + +if __name__ == "__main__": + print("=== Testing 1 ===") + test_create_chat_completion() + print("=== Testing 2 ===") + test_chat() \ No newline at end of file diff --git a/tests/test_voice_generation.py b/tests/test_voice_generation.py new file mode 100644 index 00000000..5bc1ce3b --- /dev/null +++ b/tests/test_voice_generation.py @@ -0,0 +1,57 @@ +import os +import tempfile +from pathlib import Path +from nexa.gguf import NexaVoiceInference +from unittest.mock import patch, MagicMock + +def test_voice_generation(): + # Initialize the NexaVoiceInference + voice_inference = NexaVoiceInference( + model_path="faster-whisper-tiny", + local_path=None, + output_dir=tempfile.mkdtemp(), # Use a temporary directory for output + beam_size=5, + language="en", + task="transcribe", + temperature=0.0, + compute_type="float32" + ) + + # Create a mock audio file + mock_audio_path = Path(tempfile.mkdtemp()) / "test_audio.wav" + mock_audio_path.touch() # Create an empty file + + # Mock the WhisperModel's transcribe method + mock_segments = [ + MagicMock(text="This is a test transcription."), + MagicMock(text=" It works perfectly.") + ] + mock_transcribe = MagicMock(return_value=(mock_segments, None)) + + # Test _transcribe_audio method + with patch.object(voice_inference.model, 'transcribe', mock_transcribe): + voice_inference._transcribe_audio(str(mock_audio_path)) + + # Assertions + mock_transcribe.assert_called_once_with( + str(mock_audio_path), + beam_size=5, + language="en", + task="transcribe", + temperature=0.0, + vad_filter=True + ) + + # Check if the transcription was saved + transcription_files = list(Path(voice_inference.params["output_dir"]).glob("transcription_*.txt")) + assert len(transcription_files) == 1, "Transcription file was not created" + + # Check the content of the transcription file + with open(transcription_files[0], 'r') as f: + content = f.read() + assert content == "This is a test transcription. It works perfectly.", "Transcription content is incorrect" + + print("Voice generation test passed successfully!") + +if __name__ == "__main__": + test_voice_generation() \ No newline at end of file
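
The hunks above all converge on one pattern: `pull_model` now returns a `(local_path, run_type)` pair, each ONNX runner grows an optional `local_path` argument, and the download only happens inside `run()` when no local path was supplied. Below is a minimal sketch of that flow, not code from this patch: `OnnxRunnerSketch` and `_load_model` are illustrative stand-ins, and the error branch uses `self.model_path` because the bare `model_path` name the new f-strings reference is not bound inside `run()`.

import logging
import sys

from nexa.constants import NEXA_RUN_MODEL_MAP_ONNX
from nexa.general import pull_model  # returns (local_path, run_type) after this patch


class OnnxRunnerSketch:
    """Illustrative stand-in for the ONNX text/image/voice/TTS runners."""

    def __init__(self, model_path, local_path=None, **kwargs):
        # Map short names to hub identifiers, as the real classes do.
        self.model_path = NEXA_RUN_MODEL_MAP_ONNX.get(model_path, model_path)
        # Callers that already pulled the model (the CLI, the Streamlit pages)
        # pass local_path so run() can skip the download entirely.
        self.downloaded_onnx_folder = local_path

    def run(self):
        if self.downloaded_onnx_folder is None:
            # run_type is unused on this lazy path; only local_path matters here.
            self.downloaded_onnx_folder, _run_type = pull_model(self.model_path)

        if self.downloaded_onnx_folder is None:
            # Assumes pull_model still yields an unpackable pair on failure,
            # with the local path set to None.
            logging.error(
                f"Model ({self.model_path}) is not applicable. "
                "Please refer to our docs for proper usage."
            )
            sys.exit(1)

        self._load_model(self.downloaded_onnx_folder)

    def _load_model(self, folder):
        # The real classes build the ORT session / tokenizer / processor here.
        ...

The Streamlit pages sidestep the lazy branch: their `@st.cache_resource` `load_model` helpers call `pull_model` once and hand the resulting `local_path` to the constructor, so Streamlit reruns reuse the cached download instead of pulling again. The new and updated GGUF tests go the other way and pass `local_path=None` explicitly, matching the widened constructor signatures while leaving model resolution to the library.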