From b6f7f8e3fc7e7aacdf5df86df09cb2ed42d791f4 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 19 Nov 2024 16:43:52 +0800 Subject: [PATCH] support modelscope hub --- CLI.md | 32 ++++--- README.md | 15 ++- SERVER.md | 3 +- nexa/cli/entry.py | 57 +++++++---- nexa/constants.py | 1 + nexa/general.py | 156 +++++++++++++++++++++++++------ nexa/gguf/server/nexa_service.py | 28 ++++-- requirements.txt | 1 + 8 files changed, 222 insertions(+), 71 deletions(-) diff --git a/CLI.md b/CLI.md index 06ed3a33..5f219047 100644 --- a/CLI.md +++ b/CLI.md @@ -46,11 +46,12 @@ nexa pull MODEL_PATH usage: nexa pull [-h] model_path positional arguments: - model_path Path or identifier for the model in Nexa Model Hub, or Hugging Face repo ID when using -hf flag + model_path Path or identifier for the model in Nexa Model Hub, Hugging Face repo ID when using -hf flag, or ModelScope model ID when using -ms flag options: -h, --help show this help message and exit -hf, --huggingface Pull model from Hugging Face Hub + -ms, --modelscope Pull model from ModelScope Hub -o, --output_path OUTPUT_PATH Custom output path for the pulled model ``` @@ -102,7 +103,7 @@ You can run any model shown in `nexa list` command. ``` nexa run MODEL_PATH -usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -112,8 +113,9 @@ options: -pf, --profiling Enable profiling logs for the inference process -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf -lp, --local_path Indicate that the model path provided is the local path - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub Text generation options: -t, --temperature TEMPERATURE @@ -137,7 +139,7 @@ nexa run llama2 ``` nexa run MODEL_PATH -usage: nexa run [-h] [-i2i] [-ns NUM_INFERENCE_STEPS] [-np NUM_IMAGES_PER_PROMPT] [-H HEIGHT] [-W WIDTH] [-g GUIDANCE_SCALE] [-o OUTPUT] [-s RANDOM_SEED] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-i2i] [-ns NUM_INFERENCE_STEPS] [-np NUM_IMAGES_PER_PROMPT] [-H HEIGHT] [-W WIDTH] [-g GUIDANCE_SCALE] [-o OUTPUT] [-s RANDOM_SEED] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -146,8 +148,9 @@ options: -h, --help show this help message and exit -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf -lp, --local_path Indicate that the model path provided is the local path - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub Image generation options: -i2i, --img2img Whether to run image-to-image generation @@ -182,7 +185,7 @@ nexa run sd1-4 ``` nexa run MODEL_PATH -usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-t TEMPERATURE] [-m MAX_NEW_TOKENS] [-k TOP_K] [-p TOP_P] [-sw [STOP_WORDS ...]] [-pf] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -192,8 +195,9 @@ options: -pf, --profiling Enable profiling logs for the inference process -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf -lp, --local_path Indicate that the model path provided is the local path - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub VLM generation options: -t, --temperature TEMPERATURE @@ -217,7 +221,7 @@ nexa run nanollava ``` nexa run MODEL_PATH -usage: nexa run [-h] [-o OUTPUT_DIR] [-b BEAM_SIZE] [-l LANGUAGE] [--task TASK] [-t TEMPERATURE] [-c COMPUTE_TYPE] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa run [-h] [-o OUTPUT_DIR] [-b BEAM_SIZE] [-l LANGUAGE] [--task TASK] [-t TEMPERATURE] [-c COMPUTE_TYPE] [-st] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -226,8 +230,9 @@ options: -h, --help show this help message and exit -st, --streamlit Run the inference in Streamlit UI, can be used with -lp or -hf -lp, --local_path Indicate that the model path provided is the local path - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub Automatic Speech Recognition options: -b, --beam_size BEAM_SIZE @@ -251,7 +256,7 @@ nexa run faster-whisper-tiny ``` nexa embed MODEL_PATH -usage: nexa embed [-h] [-lp] [-hf] [-n] [-nt] model_path prompt +usage: nexa embed [-h] [-lp] [-hf] [-ms] [-n] [-nt] model_path prompt positional arguments: model_path Path or identifier for the model in Nexa Model Hub @@ -261,6 +266,7 @@ options: -h, --help show this help message and exit -lp, --local_path Indicate that the model path provided is the local path -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub -n, --normalize Normalize the embeddings -nt, --no_truncate Not truncate the embeddings ``` @@ -318,6 +324,7 @@ options: --only_copy Only copy tensors (ignores ftype, allow_requantize, and quantize_output_tensor) --pure Quantize all tensors to the default type --keep_split Quantize to the same number of shards + -ms --modelscope Load model from ModelScope Hub ``` #### Example @@ -341,7 +348,7 @@ Start a local server using models on your local computer. ``` nexa server MODEL_PATH -usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] model_path +usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] [-lp] [-mt {NLP, COMPUTER_VISION, MULTIMODAL, AUDIO}] [-hf] [-ms] model_path positional arguments: model_path Path or identifier for the model in S3 @@ -349,8 +356,9 @@ positional arguments: options: -h, --help show this help message and exit -lp, --local_path Indicate that the model path provided is the local path - -mt, --model_type Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] + -mt, --model_type Indicate the model running type, must be used with -lp or -hf or -ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] -hf, --huggingface Load model from Hugging Face Hub + -ms, --modelscope Load model from ModelScope Hub --host HOST Host to bind the server to --port PORT Port to bind the server to --reload Enable automatic reloading on code changes diff --git a/README.md b/README.md index dc161b6e..1acdca1e 100644 --- a/README.md +++ b/README.md @@ -254,19 +254,26 @@ Supported model examples (full list at [Model Hub](https://nexa.ai/models)): | [all-MiniLM-L12-v2](https://nexa.ai/sentence-transformers/all-MiniLM-L12-v2/gguf-fp16/readme) | Embedding | GGUF | `nexa embed all-MiniLM-L12-v2:fp16` | | [bark-small](https://nexa.ai/suno/bark-small/gguf-fp16/readme) | Text-to-Speech | GGUF | `nexa run bark-small:fp16` | -## Run Models from 🤗 HuggingFace -You can pull, convert (to .gguf), quantize and run [llama.cpp supported](https://github.com/ggerganov/llama.cpp#description) text generation models from HF with Nexa SDK. +## Run Models from 🤗 HuggingFace or 🤖 ModelScope +You can pull, convert (to .gguf), quantize and run [llama.cpp supported](https://github.com/ggerganov/llama.cpp#description) text generation models from HF or MS with Nexa SDK. ### Run .gguf File -Use `nexa run -hf ` to run models with provided .gguf files: +Use `nexa run -hf ` or `nexa run -ms ` to run models with provided .gguf files: ```bash nexa run -hf Qwen/Qwen2.5-Coder-7B-Instruct-GGUF ``` +```bash +nexa run -ms Qwen/Qwen2.5-Coder-7B-Instruct-GGUF +``` > **Note:** You will be prompted to select a single .gguf file. If your desired quantization version has multiple split files (like fp16-00001-of-00004), please use Nexa's conversion tool (see below) to convert and quantize the model locally. ### Convert .safetensors Files -Install [Nexa Python package](https://github.com/NexaAI/nexa-sdk?tab=readme-ov-file#install-option-2-python-package), and install Nexa conversion tool with `pip install "nexaai[convert]"`, then convert models with `nexa convert `: +Install [Nexa Python package](https://github.com/NexaAI/nexa-sdk?tab=readme-ov-file#install-option-2-python-package), and install Nexa conversion tool with `pip install "nexaai[convert]"`, then convert models from huggingface with `nexa convert `: ```bash nexa convert HuggingFaceTB/SmolLM2-135M-Instruct ``` +Or you can convert models from ModelScope with `nexa convert -ms `: +```bash +nexa convert -ms Qwen/Qwen2.5-7B-Instruct +``` > **Note:** Check our [leaderboard](https://nexa.ai/leaderboard) for performance benchmarks of different quantized versions of mainstream language models and [HuggingFace docs](https://huggingface.co/docs/optimum/en/concept_guides/quantization) to learn about quantization options. 📋 You can view downloaded and converted models with `nexa list` diff --git a/SERVER.md b/SERVER.md index 77c9b0b5..b75efa45 100644 --- a/SERVER.md +++ b/SERVER.md @@ -9,8 +9,9 @@ usage: nexa server [-h] [--host HOST] [--port PORT] [--reload] model_path ### Options: - `-lp, --local_path`: Indicate that the model path provided is the local path -- `-mt, --model_type`: Indicate the model running type, must be used with -lp or -hf, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] +- `-mt, --model_type`: Indicate the model running type, must be used with -lp or -hf or ms, choose from [NLP, COMPUTER_VISION, MULTIMODAL, AUDIO] - `-hf, --huggingface`: Load model from Hugging Face Hub +- `-ms, --modelscope`: Load model from ModelScope Hub - `--host`: Host to bind the server to - `--port`: Port to bind the server to - `--reload`: Enable automatic reloading on code changes diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index 84325c53..b4ea222d 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -40,17 +40,18 @@ def run_ggml_inference(args): is_local_path = kwargs.pop("local_path", False) model_type = kwargs.pop("model_type", None) hf = kwargs.pop('huggingface', False) + ms = kwargs.pop('modelscope', False) run_type = None if model_type: run_type = ModelType[model_type].value - elif is_local_path or hf: + elif is_local_path or hf or ms: run_type = ModelType["NLP"].value local_path = None projector_local_path = None - if is_local_path or hf: + if is_local_path or hf or ms: if is_local_path: local_path = os.path.abspath(model_path) model_path = local_path @@ -69,13 +70,13 @@ def run_ggml_inference(args): if not os.path.isdir(local_path): print("Error: For Audio models with --local_path, the provided path must be a directory containing all related files.") return - else: # hf case + else: # hf or ms case # TODO: remove this after adding support for Multimodal model in CLI if run_type == "Multimodal" or run_type == "Audio" or run_type == "TTS": print("Running multimodal model or audio model or TTS model from Hugging Face is currently not supported in CLI mode. Please use SDK to run Multimodal model or Audio model or TTS model.") return from nexa.general import pull_model - local_path, _ = pull_model(model_path, hf=True, run_type=run_type) + local_path, _ = pull_model(model_path, hf=hf, ms=ms, run_type=run_type) else: # Model Hub from nexa.general import pull_model local_path, run_type = pull_model(model_path) @@ -83,7 +84,7 @@ def run_ggml_inference(args): stop_words = kwargs.pop("stop_words", None) try: - if (is_local_path or hf) and not model_type: + if (is_local_path or hf or ms) and not model_type: print("No model type specified. Running with default model type: NLP") print("You can specify a different model type using the -mt flag") @@ -149,11 +150,12 @@ def run_ggml_server(args): is_local_path = kwargs.pop("local_path", False) model_type = kwargs.pop("model_type", None) hf = kwargs.pop('huggingface', False) + ms = kwargs.pop('modelscope', False) run_type = None if model_type: run_type = ModelType[model_type].value - elif is_local_path or hf: + elif is_local_path or hf or ms: run_type = ModelType["NLP"].value projector_local_path = None @@ -173,7 +175,7 @@ def run_ggml_server(args): print("Error: For Audio models with --local_path, the provided path must be a directory containing all related files.") return - if (is_local_path or hf) and not model_type: + if (is_local_path or hf or ms) and not model_type: print("No model type specified. Running with default model type: NLP") print("You can specify a different model type using the -mt flag") @@ -182,6 +184,7 @@ def run_ggml_server(args): is_local_path_arg=is_local_path, model_type_arg=run_type, huggingface=hf, + modelscope=ms, projector_local_path_arg=projector_local_path, **kwargs ) @@ -269,17 +272,18 @@ def run_embedding_generation(args): prompt = kwargs.pop("prompt") is_local_path = kwargs.pop("local_path", False) hf = kwargs.pop('huggingface', False) + ms = kwargs.pop('modelscope', False) normalize = kwargs.pop('normalize', False) no_truncate = kwargs.pop('no_truncate', False) local_path = None - if is_local_path or hf: + if is_local_path or hf or ms: if is_local_path: local_path = os.path.abspath(model_path) model_path = local_path - else: # hf case + else: # hf or ms case from nexa.general import pull_model - local_path, _ = pull_model(model_path, hf=True, run_type="Text Embedding") + local_path, _ = pull_model(model_path, hf=hf, ms=ms, run_type="Text Embedding") else: # Model Hub from nexa.general import pull_model local_path, _ = pull_model(model_path) @@ -368,9 +372,13 @@ def run_convert(args): # Valid GGUF file, proceed as is pass else: - # Try downloading from HF if path isn't a valid local directory/file - from nexa.general import download_repo_from_hf - success, local_path = download_repo_from_hf(input_path) + # Try downloading from HF or MS if path isn't a valid local directory/file + if args.modelscope: + from nexa.general import download_repo_from_ms + success, local_path = download_repo_from_ms(input_path) + else: + from nexa.general import download_repo_from_hf + success, local_path = download_repo_from_hf(input_path) if success: input_path = local_path @@ -427,6 +435,7 @@ def main(): run_parser.add_argument("-lp", "--local_path", action="store_true", help="Indicate that the model path provided is the local path") run_parser.add_argument("-mt", "--model_type", type=str, choices=[e.name for e in ModelType], help="Indicate the model running type (default: NLP)") run_parser.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub") + run_parser.add_argument("-ms", "--modelscope", action="store_true", help="Load model from ModelScope Hub") # Text generation/vlm arguments text_group = run_parser.add_argument_group('Text generation/VLM options') @@ -505,6 +514,7 @@ def main(): embed_parser.add_argument("prompt", type=str, help="The prompt to generate an embedding for") embed_parser.add_argument("-lp", "--local_path", action="store_true", help="Indicate that the model path provided is the local path") embed_parser.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub") + embed_parser.add_argument("-ms", "--modelscope", action="store_true", help="Load model from ModelScope Hub") embed_parser.add_argument("-n", "--normalize", action="store_true", help="Normalize the embeddings") embed_parser.add_argument("-nt", "--no_truncate", action="store_true", help="Not truncate the embeddings") @@ -525,6 +535,7 @@ def main(): convert_hf_parser.add_argument("--no_tensor_first_split", action="store_true", help="Disable tensor-first splitting") convert_hf_parser.add_argument("--vocab_only", action="store_true", help="Only process vocabulary") convert_hf_parser.add_argument("--dry_run", action="store_true", help="Perform a dry run without actual conversion") + convert_hf_parser.add_argument("-ms", "--modelscope", action="store_true", help="Download model from ModelScope Hub") quantization_parser = convert_parser.add_argument_group('Quantization options') quantization_parser.add_argument("--nthread", type=int, default=4, help="Number of threads to use (default: 4)") @@ -540,8 +551,9 @@ def main(): server_parser = subparsers.add_parser("server", help="Run the Nexa AI Text Generation Service") server_parser.add_argument("model_path", type=str, nargs='?', help="Path or identifier for the model in Nexa Model Hub") server_parser.add_argument("-lp", "--local_path", action="store_true", help="Indicate that the model path provided is the local path") - server_parser.add_argument("-mt", "--model_type", type=str, choices=[e.name for e in ModelType], help="Indicate the model running type, must be used with -lp or -hf") + server_parser.add_argument("-mt", "--model_type", type=str, choices=[e.name for e in ModelType], help="Indicate the model running type, must be used with -lp, -hf or -ms") server_parser.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub") + server_parser.add_argument("-ms", "--modelscope", action="store_true", help="Load model from ModelScope Hub") server_parser.add_argument("--host", type=str, default="localhost", help="Host to bind the server to") server_parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") server_parser.add_argument("--reload", action="store_true", help="Enable automatic reloading on code changes") @@ -551,6 +563,7 @@ def main(): pull_parser = subparsers.add_parser("pull", help="Pull a model from official or hub.") pull_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub") pull_parser.add_argument("-hf", "--huggingface", action="store_true", help="Pull model from Hugging Face Hub") + pull_parser.add_argument("-ms", "--modelscope", action="store_true", help="Pull model from ModelScope Hub") pull_parser.add_argument("-o", "--output_path", type=str, help="Custom output path for the pulled model") remove_parser = subparsers.add_parser("remove", help="Remove a model from local machine.") @@ -584,11 +597,23 @@ def main(): if args.local_path and args.huggingface: print("Error: --local_path and --huggingface flags cannot be used together") return + if args.local_path and args.modelscope: + print("Error: --local_path and --modelscope flags cannot be used together") + return + if args.huggingface and args.modelscope: + print("Error: --huggingface and --modelscope flags cannot be used together") + return run_ggml_inference(args) elif args.command == "server": if args.local_path and args.huggingface: print("Error: --local_path and --huggingface flags cannot be used together") return + if args.local_path and args.modelscope: + print("Error: --local_path and --modelscope flags cannot be used together") + return + if args.huggingface and args.modelscope: + print("Error: --huggingface and --modelscope flags cannot be used together") + return run_ggml_server(args) elif args.command == "onnx": run_onnx_inference(args) @@ -601,6 +626,7 @@ def main(): import os hf = getattr(args, 'huggingface', False) + ms = getattr(args, 'modelscope', False) local_download_path = None if args.output_path: @@ -608,8 +634,7 @@ def main(): os.makedirs(args.output_path, exist_ok=True) print(f"Created output directory: {args.output_path}") local_download_path = os.path.abspath(args.output_path) - - pull_model(args.model_path, hf, local_download_path=local_download_path) + pull_model(args.model_path, hf, ms, local_download_path=local_download_path) elif args.command == "convert": run_convert(args) elif args.command == "remove": diff --git a/nexa/constants.py b/nexa/constants.py index b034f73d..24acd195 100644 --- a/nexa/constants.py +++ b/nexa/constants.py @@ -9,6 +9,7 @@ NEXA_MODEL_EVAL_RESULTS_PATH = NEXA_CACHE_ROOT / "eval" NEXA_MODELS_HUB_OFFICIAL_DIR = NEXA_MODELS_HUB_DIR / "official" NEXA_MODELS_HUB_HF_DIR = NEXA_MODELS_HUB_DIR / "huggingface" +NEXA_MODELS_HUB_MS_DIR = NEXA_MODELS_HUB_DIR / "modelscope" NEXA_MODEL_LIST_PATH = NEXA_MODELS_HUB_DIR / "model_list.json" # URLs and buckets diff --git a/nexa/general.py b/nexa/general.py index 5c92f52a..3190c49d 100644 --- a/nexa/general.py +++ b/nexa/general.py @@ -18,6 +18,7 @@ NEXA_MODELS_HUB_DIR, NEXA_MODELS_HUB_OFFICIAL_DIR, NEXA_MODELS_HUB_HF_DIR, + NEXA_MODELS_HUB_MS_DIR, NEXA_OFFICIAL_BUCKET, NEXA_RUN_MODEL_MAP, NEXA_TOKEN_PATH, @@ -107,12 +108,14 @@ def get_user_info(token): return None -def pull_model(model_path, hf = False, **kwargs): +def pull_model(model_path, hf = False, ms = False, **kwargs): model_path = NEXA_RUN_MODEL_MAP.get(model_path, model_path) try: if hf == True: result = pull_model_from_hf(model_path, **kwargs) + elif ms == True: + result = pull_model_from_ms(model_path, **kwargs) else: if is_model_exists(model_path): location, run_type = get_model_info(model_path) @@ -126,11 +129,11 @@ def pull_model(model_path, hf = False, **kwargs): if result["success"]: # Only add to model list if not using custom download path - model_path = model_path if not hf else f"{model_path}:{result['local_path'].split('/')[-1]}" + model_path = model_path if not (hf or ms) else f"{model_path}:{result['local_path'].split('/')[-1]}" if not kwargs.get('local_download_path'): add_model_to_list(model_path, result["local_path"], result["model_type"], result["run_type"]) - if hf: + if hf or ms: print(f"Successfully pulled model {model_path} to {result['local_path']}") else: print(f"Successfully pulled model {model_path} to {result['local_path']}, run_type: {result['run_type']}") @@ -232,7 +235,7 @@ def pull_model_from_official(model_path, **kwargs): } def pull_model_from_hf(repo_id, run_type = "NLP", **kwargs): - repo_id, filename = select_gguf_in_hf_repo(repo_id) + repo_id, filename = select_gguf_from_repo(repo_id, 'huggingface') success, model_path = download_gguf_from_hf(repo_id, filename, **kwargs) # For beta version, we only support NLP gguf models @@ -244,6 +247,19 @@ def pull_model_from_hf(repo_id, run_type = "NLP", **kwargs): } +def pull_model_from_ms(repo_id, run_type = "NLP", **kwargs): + repo_id, filename = select_gguf_from_repo(repo_id, 'modelscope') + success, model_path = download_gguf_from_ms(repo_id, filename, **kwargs) + + # For beta version, we only support NLP gguf models + return { + "success": success, + "local_path": model_path, + "model_type": "gguf", + "run_type": run_type + } + + def get_run_type_from_model_path(model_path): model_name, _ = model_path.split(":") return NEXA_OFFICIAL_MODELS_TYPE.get(model_name, ModelType.NLP).value @@ -469,6 +485,32 @@ def download_repo_from_hf(repo_id): print(f"Failed to download the repository: {e}") return False, None +def download_repo_from_ms(repo_id): + try: + from modelscope import snapshot_download + from pathlib import Path + except ImportError: + print("The modelscope package is required. Please install it with `pip install modelscope`.") + return False, None + + # Define the local directory to save the model + local_dir = NEXA_MODELS_HUB_MS_DIR / Path(repo_id) + local_dir.mkdir(parents=True, exist_ok=True) + + try: + # Download the entire repository + repo_path = snapshot_download( + model_id=repo_id, + local_dir=local_dir, + revision="master" + ) + + print(f"Successfully downloaded repository '{repo_id}' to {repo_path}") + return True, repo_path + except Exception as e: + print(f"Failed to download the repository: {e}") + return False, None + def download_gguf_from_hf(repo_id, filename, **kwargs): try: from huggingface_hub import hf_hub_download @@ -508,6 +550,44 @@ def download_gguf_from_hf(repo_id, filename, **kwargs): print(f"Failed to download the model: {e}") return False, None +def download_gguf_from_ms(repo_id, filename, **kwargs): + from pathlib import Path + import shutil + try: + from modelscope.hub.file_download import model_file_download + except ImportError: + print("The modelscope package is required. Please install it with `pip install modelscope`.") + return None + + # Get custom download path from kwargs if present + local_download_path = kwargs.get('local_download_path') + base_download_dir = Path(local_download_path) if local_download_path else NEXA_MODELS_HUB_MS_DIR + local_dir = base_download_dir / Path(repo_id) + local_dir.mkdir(parents=True, exist_ok=True) + + # Download the model + try: + model_path = model_file_download( + model_id=repo_id, + file_path=filename, + local_dir=local_dir, + local_files_only=False, + ) + # If using custom download path, move the file and cleanup + if local_download_path: + model_file = Path(model_path) + target_path = base_download_dir / filename + shutil.move(str(model_file), str(target_path)) + # Get the organization directory (first part of repo_id) + org_dir = base_download_dir / repo_id.split('/')[0] + shutil.rmtree(org_dir) + return True, str(target_path) + + return True, model_path + except Exception as e: + print(f"Failed to download the model: {e}") + return False, None + def is_model_exists(model_name): if not NEXA_MODEL_LIST_PATH.exists(): return False @@ -660,42 +740,60 @@ def clean(): except Exception as e: print(f"An error occurred while cleaning the directory: {e}") -def select_gguf_in_hf_repo(repo_id: str) -> Tuple[str, str]: +def select_gguf_from_repo(repo_id: str, model_hub: str) -> Tuple[str, str]: """ - Lists all files ending with .gguf in the given Hugging Face repository, + Lists all files ending with .gguf in the given (HuggingFace or ModelScope) repository, prompts the user to select one, and returns the repo_id and the selected filename. Args: - repo_id (str): The Hugging Face repository ID. + repo_id (str): The repository ID. + model_hub (str): huggingface or modelscope Returns: Tuple[str, str]: A tuple containing the repo_id and the selected filename. """ - try: - from huggingface_hub import HfFileSystem - from huggingface_hub.utils import validate_repo_id - from pathlib import Path - except ImportError: - print("The huggingface-hub package is required. Please install it with `pip install huggingface-hub`.") - exit(1) + if model_hub == 'huggingface': + try: + from huggingface_hub import HfFileSystem + from huggingface_hub.utils import validate_repo_id + from pathlib import Path + except ImportError: + print("The huggingface-hub package is required. Please install it with `pip install huggingface-hub`.") + exit(1) - validate_repo_id(repo_id) - hffs = HfFileSystem() + validate_repo_id(repo_id) + hffs = HfFileSystem() - try: - files = [ - file["name"] if isinstance(file, dict) else file - for file in hffs.ls(repo_id, recursive=True) - ] - except Exception as e: - print(f"Error accessing repository '{repo_id}'. Please make sure you have access to the Hugging Face repository first.") - exit(1) + try: + files = [ + file["name"] if isinstance(file, dict) else file + for file in hffs.ls(repo_id, recursive=True) + ] + except Exception as e: + print(f"Error accessing repository '{repo_id}'. Please make sure you have access to the Hugging Face repository first.") + exit(1) + + # Remove the repo prefix from files + file_list = [] + for file in files: + rel_path = Path(file).relative_to(repo_id) + file_list.append(str(rel_path)) + elif model_hub == 'modelscope': + try: + from modelscope.hub.api import HubApi + except ImportError: + print("The modelscope package is required. Please install it with `pip install modelscope`.") + exit(1) - # Remove the repo prefix from files - file_list = [] - for file in files: - rel_path = Path(file).relative_to(repo_id) - file_list.append(str(rel_path)) + try: + ms_api = HubApi() + infos = ms_api.get_model_files(repo_id, recursive=True) + file_list = [info['Path'] for info in infos] + except Exception as e: + print(f"Error accessing repository '{repo_id}'. Please make sure you have access to the ModelScope repository first.") + exit(1) + else: + raise ValueError("Invalid model hub specified. Supported model hub are 'huggingface' and 'modelscope") # Filter for files ending with .gguf gguf_files = [file for file in file_list if file.endswith('.gguf')] diff --git a/nexa/gguf/server/nexa_service.py b/nexa/gguf/server/nexa_service.py index d0267098..73772e31 100644 --- a/nexa/gguf/server/nexa_service.py +++ b/nexa/gguf/server/nexa_service.py @@ -82,6 +82,7 @@ is_local_path = False model_type = None is_huggingface = False +is_modelscope = False projector_path = None # Request Classes class GenerationRequest(BaseModel): @@ -176,7 +177,7 @@ class EmbeddingRequest(BaseModel): # helper functions async def load_model(): - global model, chat_format, completion_template, model_path, n_ctx, is_local_path, model_type, is_huggingface, projector_path + global model, chat_format, completion_template, model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path if is_local_path: if model_type == "Multimodal": if not projector_path: @@ -185,11 +186,11 @@ async def load_model(): projector_downloaded_path = projector_path else: downloaded_path = model_path - elif is_huggingface: + elif is_huggingface or is_modelscope: # TODO: currently Multimodal models and Audio models are not supported for Hugging Face if model_type == "Multimodal" or model_type == "Audio": raise ValueError("Multimodal and Audio models are not supported for Hugging Face") - downloaded_path, _ = pull_model(model_path, hf=True) + downloaded_path, _ = pull_model(model_path, hf=is_huggingface, ms=is_modelscope) else: if model_path in NEXA_RUN_MODEL_MAP_VLM: # for Multimodal models downloaded_path, _ = pull_model(NEXA_RUN_MODEL_MAP_VLM[model_path]) @@ -333,7 +334,7 @@ def nexa_run_text_generation( logprobs_or_none = None if is_chat_completion: - if is_local_path or is_huggingface: # do not add system prompt if local path or huggingface + if is_local_path or is_huggingface or is_modelscope: # do not add system prompt if local path or huggingface or modelscope messages = [{"role": "user", "content": prompt}] else: messages = chat_completion_system_prompt + [{"role": "user", "content": prompt}] @@ -496,14 +497,15 @@ def image_url_to_base64(image_url: str) -> str: return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" -def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, projector_local_path_arg=None, **kwargs): - global model_path, n_ctx, is_local_path, model_type, is_huggingface, projector_path +def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type_arg=None, huggingface=False, modelscope=False, projector_local_path_arg=None, **kwargs): + global model_path, n_ctx, is_local_path, model_type, is_huggingface, is_modelscope, projector_path is_local_path = is_local_path_arg is_huggingface = huggingface + is_modelscope = modelscope projector_path = projector_local_path_arg - if is_local_path_arg or huggingface: + if is_local_path_arg or huggingface or modelscope: if not model_path_arg: - raise ValueError("model_path must be provided when using --local_path or --huggingface") + raise ValueError("model_path must be provided when using --local_path or --huggingface or --modelscope") if is_local_path_arg and not model_type_arg: raise ValueError("--model_type must be provided when using --local_path") model_path = os.path.abspath(model_path_arg) if is_local_path_arg else model_path_arg @@ -515,6 +517,7 @@ def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type os.environ["IS_LOCAL_PATH"] = str(is_local_path_arg) os.environ["MODEL_TYPE"] = model_type if model_type else "" os.environ["HUGGINGFACE"] = str(huggingface) + os.environ["MODELSCOPE"] = str(modelscope) os.environ["PROJECTOR_PATH"] = projector_path if projector_path else "" n_ctx = kwargs.get("nctx", 2048) host = kwargs.get("host", "localhost") @@ -525,11 +528,12 @@ def run_nexa_ai_service(model_path_arg=None, is_local_path_arg=False, model_type # Endpoints @app.on_event("startup") async def startup_event(): - global model_path, is_local_path, model_type, is_huggingface, projector_path + global model_path, is_local_path, model_type, is_huggingface, is_modelscope, projector_path model_path = os.getenv("MODEL_PATH", "gemma") is_local_path = os.getenv("IS_LOCAL_PATH", "False").lower() == "true" model_type = os.getenv("MODEL_TYPE", None) is_huggingface = os.getenv("HUGGINGFACE", "False").lower() == "true" + is_modelscope = os.getenv("MODELSCOPE", "False").lower() == "true" projector_path = os.getenv("PROJECTOR_PATH", None) await load_model() @@ -859,12 +863,18 @@ async def create_embedding(request: EmbeddingRequest): action="store_true", help="Use a Hugging Face model", ) + parser.add_argument( + "--modelscope", + action="store_true", + help="Use a ModelScope model", + ) args = parser.parse_args() run_nexa_ai_service( args.model_path, is_local_path_arg=args.local_path, model_type_arg=args.model_type, huggingface=args.huggingface, + modelscope=args.modelscope, nctx=args.nctx, host=args.host, port=args.port, diff --git a/requirements.txt b/requirements.txt index 978b8c1d..6e732a0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ pydantic pillow python-multipart huggingface_hub +modelscope # For onnx optimum[onnxruntime] # for CPU version