diff --git a/README.md b/README.md index 8e611922e..9687fbbde 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,10 @@ We are focused to support Llama2 at scale now. If you want any other models, ple ## Dev Log +### 2024-02 + +Sync upstream changes + ### 2023-09 Sync upstream changes diff --git a/docs/arena.md b/docs/arena.md index 979f41db5..2d79b2acf 100644 --- a/docs/arena.md +++ b/docs/arena.md @@ -5,10 +5,11 @@ We invite the entire community to join this benchmarking effort by contributing ## How to add a new model If you want to see a specific model in the arena, you can follow the methods below. -- Method 1: Hosted by LMSYS. - 1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md#how-to-support-a-new-model). - 2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors. +### Method 1: Hosted by 3rd party API providers or yourself +If you have a model hosted by a 3rd party API provider or yourself, please give us the access to an API endpoint. + - We prefer OpenAI-compatible APIs, so we can reuse our [code](https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/api_provider.py) for calling OpenAI models. + - If you have your own API protocol, please follow the [instructions](model_support.md) to add them. Contribute your code by sending a pull request. -- Method 2: Hosted by 3rd party API providers or yourself. - 1. If you have a model hosted by a 3rd party API provider or yourself, please give us an API endpoint. We prefer OpenAI-compatible APIs, so we can reuse our [code](https://github.com/lm-sys/FastChat/blob/33dca5cf12ee602455bfa9b5f4790a07829a2db7/fastchat/serve/gradio_web_server.py#L333-L358) for calling OpenAI models. - 2. You can use FastChat's OpenAI API [server](openai_api.md) to serve your model with OpenAI-compatible APIs and provide us with the endpoint. +### Method 2: Hosted by LMSYS +1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md). +2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors. diff --git a/docs/commands/webserver.md b/docs/commands/webserver.md index 179d3dfe7..df96cf8d2 100644 --- a/docs/commands/webserver.md +++ b/docs/commands/webserver.md @@ -24,10 +24,13 @@ python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://lo cd fastchat_logs/server0 +python3 -m fastchat.serve.huggingface_api_worker --model-info-file ~/elo_results/register_hf_api_models.json + export OPENAI_API_KEY= export ANTHROPIC_API_KEY= +export GCP_PROJECT_ID= -python3 -m fastchat.serve.gradio_web_server_multi --controller http://localhost:21001 --concurrency 10 --add-chatgpt --add-claude --add-palm --anony-only --elo ~/elo_results/elo_results.pkl --leaderboard-table-file ~/elo_results/leaderboard_table.csv --register ~/elo_results/register_oai_models.json --show-terms +python3 -m fastchat.serve.gradio_web_server_multi --controller http://localhost:21001 --concurrency 50 --add-chatgpt --add-claude --add-palm --elo ~/elo_results/elo_results.pkl --leaderboard-table-file ~/elo_results/leaderboard_table.csv --register ~/elo_results/register_oai_models.json --show-terms python3 backup_logs.py ``` diff --git a/docs/lightllm_integration.md b/docs/lightllm_integration.md new file mode 100644 index 000000000..b271a826a --- /dev/null +++ b/docs/lightllm_integration.md @@ -0,0 +1,18 @@ +# LightLLM Integration +You can use [LightLLM](https://github.com/ModelTC/lightllm) as an optimized worker implementation in FastChat. +It offers advanced continuous batching and a much higher (~10x) throughput. +See the supported models [here](https://github.com/ModelTC/lightllm?tab=readme-ov-file#supported-model-list). + +## Instructions +1. Please refer to the [Get started](https://github.com/ModelTC/lightllm?tab=readme-ov-file#get-started) to install LightLLM. Or use [Pre-built image](https://github.com/ModelTC/lightllm?tab=readme-ov-file#container) + +2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the LightLLM worker (`fastchat.serve.lightllm_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same. Refer to [--max_total_token_num](https://github.com/ModelTC/lightllm/blob/4a9824b6b248f4561584b8a48ae126a0c8f5b000/docs/ApiServerArgs.md?plain=1#L23) to understand how to calculate the `--max_total_token_num` argument. + ``` + python3 -m fastchat.serve.lightllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer_mode "auto" --max_total_token_num 154000 + ``` + + If you what to use quantized weight and kv cache for inference, try + + ``` + python3 -m fastchat.serve.lightllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer_mode "auto" --max_total_token_num 154000 --mode triton_int8weight triton_int8kv + ``` diff --git a/docs/mlx_integration.md b/docs/mlx_integration.md new file mode 100644 index 000000000..21642d948 --- /dev/null +++ b/docs/mlx_integration.md @@ -0,0 +1,23 @@ +# Apple MLX Integration + +You can use [Apple MLX](https://github.com/ml-explore/mlx) as an optimized worker implementation in FastChat. + +It runs models efficiently on Apple Silicon + +See the supported models [here](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models). + +Note that for Apple Silicon Macs with less memory, smaller models (or quantized models) are recommended. + +## Instructions + +1. Install MLX. + + ``` + pip install "mlx-lm>=0.0.6" + ``` + +2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the MLX worker (`fastchat.serve.mlx_worker`). Remember to launch a model worker after you have launched the controller ([instructions](../README.md)) + + ``` + python3 -m fastchat.serve.mlx_worker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 + ``` diff --git a/docs/model_support.md b/docs/model_support.md index fa0739128..ba5f5b79b 100644 --- a/docs/model_support.md +++ b/docs/model_support.md @@ -1,15 +1,48 @@ # Model Support +This document describes how to support a new model in FastChat. -## Supported models +## Content +- [Local Models](#local-models) +- [API-Based Models](#api-based-models) + +## Local Models +To support a new local model in FastChat, you need to correctly handle its prompt template and model loading. +The goal is to make the following command run with the correct prompts. + +``` +python3 -m fastchat.serve.cli --model [YOUR_MODEL_PATH] +``` + +You can run this example command to learn the code logic. + +``` +python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +``` + +You can add `--debug` to see the actual prompt sent to the model. + +### Steps + +FastChat uses the `Conversation` class to handle prompt templates and `BaseModelAdapter` class to handle model loading. + +1. Implement a conversation template for the new model at [fastchat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py). You can follow existing examples and use `register_conv_template` to add a new one. Please also add a link to the official reference code if possible. +2. Implement a model adapter for the new model at [fastchat/model/model_adapter.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py). You can follow existing examples and use `register_model_adapter` to add a new one. +3. (Optional) add the model name to the "Supported models" [section](#supported-models) above and add more information in [fastchat/model/model_registry.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_registry.py). + +After these steps, the new model should be compatible with most FastChat features, such as CLI, web UI, model worker, and OpenAI-compatible API server. Please do some testing with these features as well. + +### Supported models - [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) - example: `python3 -m fastchat.serve.cli --model-path meta-llama/Llama-2-7b-chat-hf` - Vicuna, Alpaca, LLaMA, Koala - example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` +- [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) - [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) - [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B) - [BAAI/AquilaChat2-34B](https://huggingface.co/BAAI/AquilaChat2-34B) - [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en#using-huggingface-transformers) +- [argilla/notus-7b-v1](https://huggingface.co/argilla/notus-7b-v1) - [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) - [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven) - example: `python3 -m fastchat.serve.cli --model-path ~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth` @@ -18,13 +51,20 @@ - [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data) - [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) - [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b) +- [deepseek-ai/deepseek-llm-67b-chat](https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat) +- [deepseek-ai/deepseek-coder-33b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct) - [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) - [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b) - [FreedomIntelligence/ReaLM-7b-v1](https://huggingface.co/FreedomIntelligence/Realm-7b) - [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b) +- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta) +- [HuggingFaceH4/zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) - [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) +- [IEITYuan/Yuan2-2B/51B/102B-hf](https://huggingface.co/IEITYuan) - [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat) - [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5) +- [meta-math/MetaMath-7B-V1.0](https://huggingface.co/meta-math/MetaMath-7B-V1.0) +- [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b) - [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) - example: `python3 -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat` - [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT) @@ -34,26 +74,25 @@ - [OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5) - [openchat/openchat_3.5](https://huggingface.co/openchat/openchat_3.5) - [Open-Orca/Mistral-7B-OpenOrca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca) -- [VMware/open-llama-7b-v2-open-instruct](https://huggingface.co/VMware/open-llama-7b-v2-open-instruct) +- [OpenLemur/lemur-70b-chat-v1](https://huggingface.co/OpenLemur/lemur-70b-chat-v1) - [Phind/Phind-CodeLlama-34B-v2](https://huggingface.co/Phind/Phind-CodeLlama-34B-v2) - [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b) - [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) +- [rishiraj/CatPPT](https://huggingface.co/rishiraj/CatPPT) - [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b) - [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b) +- [tenyx/TenyxChat-7B-v1](https://huggingface.co/tenyx/TenyxChat-7B-v1) +- [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) - [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) - [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) - [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b) - [tiiuae/falcon-180B-chat](https://huggingface.co/tiiuae/falcon-180B-chat) - [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged) - [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat) +- [VMware/open-llama-7b-v2-open-instruct](https://huggingface.co/VMware/open-llama-7b-v2-open-instruct) - [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0) - [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0) -- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta) -- [HuggingFaceH4/zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) - [Xwin-LM/Xwin-LM-7B-V0.1](https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1) -- [OpenLemur/lemur-70b-chat-v1](https://huggingface.co/OpenLemur/lemur-70b-chat-v1) -- [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) -- [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b) - Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b) - Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a model above. To activate, must have `peft` in the model path. Note: If @@ -61,29 +100,31 @@ setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker. -## How to support a new model -To support a new model in FastChat, you need to correctly handle its prompt template and model loading. -The goal is to make the following command run with the correct prompts. +## API-Based Models +To support an API-based model, consider learning from the existing OpenAI example. +If the model is compatible with OpenAI APIs, then a configuration file is all that's needed without any additional code. +For custom protocols, implementation of a streaming generator in [fastchat/serve/api_provider.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/api_provider.py) is required, following the provided examples. Currently, FastChat is compatible with OpenAI, Anthropic, Google Vertex AI, Mistral, and Nvidia NGC. +### Steps to Launch a WebUI with an API Model +1. Specify the endpoint information in a JSON configuration file. For instance, create a file named `api_endpoints.json`: +```json +{ + "gpt-3.5-turbo": { + "model_name": "gpt-3.5-turbo", + "api_type": "openai", + "api_base": "https://api.openai.com/v1", + "api_key": "sk-******", + "anony_only": false + } +} ``` -python3 -m fastchat.serve.cli --model [YOUR_MODEL_PATH] -``` - -You can run this example command to learn the code logic. + - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly. + - "anony_only" indicates whether to display this model in anonymous mode only. +2. Launch the Gradio web server with the argument `--register api_endpoints.json`: ``` -python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.gradio_web_server --controller "" --share --register api_endpoints.json ``` -You can add `--debug` to see the actual prompt sent to the model. - -### Steps - -FastChat uses the `Conversation` class to handle prompt templates and `BaseModelAdapter` class to handle model loading. - -1. Implement a conversation template for the new model at [fastchat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py). You can follow existing examples and use `register_conv_template` to add a new one. Please also add a link to the official reference code if possible. -2. Implement a model adapter for the new model at [fastchat/model/model_adapter.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py). You can follow existing examples and use `register_model_adapter` to add a new one. -3. (Optional) add the model name to the "Supported models" [section](#supported-models) above and add more information in [fastchat/model/model_registry.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_registry.py). - -After these steps, the new model should be compatible with most FastChat features, such as CLI, web UI, model worker, and OpenAI-compatible API server. Please do some testing with these features as well. +Now, you can open a browser and interact with the model. diff --git a/docs/openai_api.md b/docs/openai_api.md index f3c0fba93..089b500ff 100644 --- a/docs/openai_api.md +++ b/docs/openai_api.md @@ -8,6 +8,8 @@ The following OpenAI APIs are supported: - Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) - Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) +The REST API can be seamlessly operated from Google Colab, as demonstrated in the [FastChat_API_GoogleColab.ipynb](https://github.com/lm-sys/FastChat/blob/main/playground/FastChat_API_GoogleColab.ipynb) notebook, available in our repository. This notebook provides a practical example of how to utilize the API effectively within the Google Colab environment. + ## RESTful API Server First, launch the controller @@ -32,29 +34,28 @@ Now, let us test the API server. ### OpenAI Official SDK The goal of `openai_api_server.py` is to implement a fully OpenAI-compatible API server, so the models can be used directly with [openai-python](https://github.com/openai/openai-python) library. -First, install openai-python: +First, install OpenAI python package >= 1.0: ```bash pip install --upgrade openai ``` -Then, interact with model vicuna: +Then, interact with the Vicuna model: ```python import openai -# to get proper authentication, make sure to use a valid key that's listed in -# the --api-keys flag. if no flag value is provided, the `api_key` will be ignored. + openai.api_key = "EMPTY" -openai.api_base = "http://localhost:8000/v1" +openai.base_url = "http://localhost:8000/v1/" model = "vicuna-7b-v1.5" prompt = "Once upon a time" # create a completion -completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64) +completion = openai.completions.create(model=model, prompt=prompt, max_tokens=64) # print the completion print(prompt + completion.choices[0].text) # create a chat completion -completion = openai.ChatCompletion.create( +completion = openai.chat.completions.create( model=model, messages=[{"role": "user", "content": "Hello! What is your name?"}] ) diff --git a/docs/third_party_ui.md b/docs/third_party_ui.md new file mode 100644 index 000000000..c0b230150 --- /dev/null +++ b/docs/third_party_ui.md @@ -0,0 +1,24 @@ +# Third Party UI +If you want to host it on your own UI or third party UI, you can launch the [OpenAI compatible server](openai_api.md) and host with a tunnelling service such as Tunnelmole or ngrok, and then enter the credentials appropriately. + +You can find suitable UIs from third party repos: +- [WongSaang's ChatGPT UI](https://github.com/WongSaang/chatgpt-ui) +- [McKayWrigley's Chatbot UI](https://github.com/mckaywrigley/chatbot-ui) + +- Please note that some third-party providers only offer the standard `gpt-3.5-turbo`, `gpt-4`, etc., so you will have to add your own custom model inside the code. [Here is an example of how to create a UI with any custom model name](https://github.com/ztjhz/BetterChatGPT/pull/461). + +##### Using Tunnelmole +Tunnelmole is an open source tunnelling tool. You can find its source code on [Github](https://github.com/robbie-cahill/tunnelmole-client). Here's how you can use Tunnelmole: +1. Install Tunnelmole with `curl -O https://install.tunnelmole.com/9Wtxu/install && sudo bash install`. (On Windows, download [tmole.exe](https://tunnelmole.com/downloads/tmole.exe)). Head over to the [README](https://github.com/robbie-cahill/tunnelmole-client) for other methods such as `npm` or building from source. +2. Run `tmole 7860` (replace `7860` with your listening port if it is different from 7860). The output will display two URLs: one HTTP and one HTTPS. It's best to use the HTTPS URL for better privacy and security. +``` +➜ ~ tmole 7860 +http://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 +https://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 +``` + +##### Using ngrok +ngrok is a popular closed source tunnelling tool. First download and install it from [ngrok.com](https://ngrok.com/downloads). Here's how to use it to expose port 7860. +``` +ngrok http 7860 +``` diff --git a/docs/training.md b/docs/training.md index 077221824..87b87312f 100644 --- a/docs/training.md +++ b/docs/training.md @@ -90,7 +90,7 @@ deepspeed fastchat/train/train_lora_t5.py \ ### Fine-tuning Vicuna-7B with Local NPUs -You can use the following command to train Vicuna-7B with 8 x 910B (60GB). Use `--nproc_per_node` to specify the number of NPUs. +You can use the following command to train Vicuna-7B with 8 x NPUs. Use `--nproc_per_node` to specify the number of NPUs. ```bash torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train.py \ --model_name_or_path ~/vicuna-7b-v1.5-16k \ diff --git a/fastchat/__init__.py b/fastchat/__init__.py index c4feccf55..c971add65 100644 --- a/fastchat/__init__.py +++ b/fastchat/__init__.py @@ -1 +1 @@ -__version__ = "0.2.33" +__version__ = "0.2.36" diff --git a/fastchat/constants.py b/fastchat/constants.py index 53ed55c1c..24e1783af 100644 --- a/fastchat/constants.py +++ b/fastchat/constants.py @@ -15,6 +15,7 @@ CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds." +RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR TRY OTHER MODELS.**" # Maximum input length INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) # Maximum conversation turns diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 9c8b57e13..95576536c 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -5,8 +5,10 @@ If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. """ +import base64 import dataclasses from enum import auto, IntEnum +from io import BytesIO from typing import List, Any, Dict, Union, Tuple @@ -29,6 +31,12 @@ class SeparatorStyle(IntEnum): ROBIN = auto() FALCON_CHAT = auto() CHATGLM3 = auto() + DEEPSEEK_CHAT = auto() + METAMATH = auto() + YUAN2 = auto() + + +IMAGE_PLACEHOLDER_STR = "$$$$" @dataclasses.dataclass @@ -44,6 +52,7 @@ class Conversation: # The names of two roles roles: Tuple[str] = ("USER", "ASSISTANT") # All messages. Each item is (role, message). + # Each message is either a string or a tuple of (string, List[image_url]). messages: List[List[str]] = () # The number of few shot examples offset: int = 0 @@ -72,6 +81,9 @@ def get_prompt(self) -> str: ret = system_prompt + seps[0] for i, (role, message) in enumerate(self.messages): if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message ret += role + ": " + message + seps[i % 2] else: ret += role + ":" @@ -160,6 +172,9 @@ def get_prompt(self) -> str: ret = "" if system_prompt == "" else system_prompt + self.sep + "\n" for role, message in self.messages: if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message ret += role + "\n" + message + self.sep + "\n" else: ret += role + "\n" @@ -170,7 +185,7 @@ def get_prompt(self) -> str: ret += system_prompt for role, message in self.messages: if message: - ret += role + "\n" + " " + message + ret += role + "\n" + message else: ret += role return ret @@ -222,11 +237,52 @@ def get_prompt(self) -> str: ret += role + ": " + message + self.sep else: ret += role + ":" - + return ret + elif self.sep_style == SeparatorStyle.METAMATH: + ret = "" if system_prompt == "" else system_prompt + self.sep + for i, (role, message) in enumerate(self.messages): + # For MetaMath, sep2 is used to prefix the message. + starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2 + ending_sep = self.sep if i % 2 == 0 else "" + if message: + ret += role + starting_sep + message + ending_sep + else: + ret += role + starting_sep + return ret + elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.YUAN2: + seps = [self.sep, self.sep2] + ret = "" + if self.system_message: + ret += system_prompt + seps[1] + for _, message in self.messages: + if message: + ret += message + "" + else: + ret += "" + ret = ret.rstrip("") + seps[0] return ret else: raise ValueError(f"Invalid style: {self.sep_style}") + def get_images(self): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + for image in msg[1]: + images.append(image) + + return images + def set_system_message(self, system_message: str): """Set the system message.""" self.system_message = system_message @@ -243,11 +299,52 @@ def update_last_message(self, message: str): """ self.messages[-1][1] = message + def convert_image_to_base64(self, image): + """Given an image, return the base64 encoded image string.""" + from PIL import Image + import requests + + # Load image if it has not been loaded in yet + if type(image) == str: + if image.startswith("http://") or image.startswith("https://"): + response = requests.get(image) + image = Image.open(BytesIO(response.content)).convert("RGB") + elif "base64" in image: + # OpenAI format is: data:image/jpeg;base64,{base64_encoded_image_str} + return image.split(",")[1] + else: + image = Image.open(image).convert("RGB") + + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 2048, 2048 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if longest_edge != max(image.size): + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + + return img_b64_str + def to_gradio_chatbot(self): """Convert the conversation to gradio chatbot format.""" ret = [] for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: + if type(msg) is tuple: + msg, image = msg + img_b64_str = image[0] # Only one image on gradio at one time + img_str = f'user upload image' + msg = img_str + msg.replace("\n", "").strip() + ret.append([msg, None]) else: ret[-1][-1] = msg @@ -255,7 +352,10 @@ def to_gradio_chatbot(self): def to_openai_api_messages(self): """Convert the conversation to OpenAI chat completion format.""" - ret = [{"role": "system", "content": self.system_message}] + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] for i, (_, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: @@ -265,6 +365,12 @@ def to_openai_api_messages(self): ret.append({"role": "assistant", "content": msg}) return ret + def extract_text_from_messages(self): + return [ + (role, message[0]) if type(message) is tuple else (role, message) + for role, message in self.messages + ] + def copy(self): return Conversation( name=self.name, @@ -285,7 +391,7 @@ def dict(self): "template_name": self.name, "system_message": self.system_message, "roles": self.roles, - "messages": self.messages, + "messages": self.extract_text_from_messages(), "offset": self.offset, } @@ -463,7 +569,7 @@ def get_conv_template(name: str) -> Conversation: register_conv_template( Conversation( name="chatglm3", - system_template="<|system|>\n {system_message}", + system_template="<|system|>\n{system_message}", roles=("<|user|>", "<|assistant|>"), sep_style=SeparatorStyle.CHATGLM3, stop_token_ids=[ @@ -527,10 +633,20 @@ def get_conv_template(name: str) -> Conversation: ) ) +# TenyxChat default template +register_conv_template( + Conversation( + name="tenyxchat", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + # Deepseek code default template register_conv_template( Conversation( - name="deepseek", + name="deepseek-coder", system_template="You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.", roles=("### Instruction:", "### Response:"), sep="\n", @@ -658,6 +774,17 @@ def get_conv_template(name: str) -> Conversation: ) ) +# Perplexity AI template +register_conv_template( + Conversation( + name="pplxai", + system_message="Be precise and concise.", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + # Claude default template register_conv_template( Conversation( @@ -668,6 +795,20 @@ def get_conv_template(name: str) -> Conversation: ) ) +# MetaMath default template +# reference: https://github.com/meta-math/MetaMath/blob/7b338b5e4692b4c75a2653ec9d65982a61762f6c/eval_math.py#L58 +register_conv_template( + Conversation( + name="metamath", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.METAMATH, + sep="\n\n", + sep2="Let's think step by step.", + ) +) + # MPT default template register_conv_template( Conversation( @@ -740,6 +881,15 @@ def get_conv_template(name: str) -> Conversation: ) ) +register_conv_template( + Conversation( + name="gemini", + roles=("user", "model"), + sep_style=None, + sep=None, + ) +) + # BiLLa default template register_conv_template( Conversation( @@ -933,7 +1083,7 @@ def get_conv_template(name: str) -> Conversation: register_conv_template( Conversation( name="mistral", - system_template="[INST]{system_message}\n", + system_template="[INST] {system_message}\n", roles=("[INST]", "[/INST]"), sep_style=SeparatorStyle.LLAMA2, sep=" ", @@ -955,6 +1105,18 @@ def get_conv_template(name: str) -> Conversation: ) ) +register_conv_template( + Conversation( + name="chinese-alpaca2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + register_conv_template( Conversation( name="cutegpt", @@ -1003,6 +1165,21 @@ def get_conv_template(name: str) -> Conversation: ) +# ehartford/dolphin-2.2.1-mistral-7b template +# reference: https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b#training +register_conv_template( + Conversation( + name="dolphin-2.2.1-mistral-7b", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Dolphin, a helpful AI assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + # teknium/OpenHermes-2.5-Mistral-7B template # source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B # reference: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B#prompt-template @@ -1019,6 +1196,21 @@ def get_conv_template(name: str) -> Conversation: ) +# NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO template +# source: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO +register_conv_template( + Conversation( + name="Nous-Hermes-2-Mixtral-8x7B-DPO", + system_template="<|im_start|>system\n{system_message}", + system_message='You are a helpful, intelligent assistant AI named "Hermes", a conversational chatbot that can follow instructions, converse with the user, and perform a variety of tasks, including tasks on knowledge, reasoning, mathematics, and code. Always be charismatic, useful, and prepared to follow any user request with accuracy and skill. You should respond with high quality, fluent, and detailed responses. Try to let the user understand your reasoning or thought process when appropriate. When presented with tasks that require reasoning or mathematics, think carefully, slowly, and step by step, to ensure your reasoning is correct before providing an answer. Utilize the "Examples" section to assist you in performing the task. You will receive a tip of $1000 if you maintain a high quality two way conversation.', + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + # Qwen-chat default template # source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 register_conv_template( @@ -1236,6 +1428,18 @@ def get_conv_template(name: str) -> Conversation: stop_str="<|user|>", ) ) +# xDAN default template +# source: https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1 +register_conv_template( + Conversation( + name="xdan-v1", + system_message="You are a helpful and harmless assistant named xDAN and created by xDAN-AI.Please response and work on questions thinking step by step.", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) # Zephyr template # reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py @@ -1251,6 +1455,34 @@ def get_conv_template(name: str) -> Conversation: ) ) +# CatPPT template +# reference: https://huggingface.co/rishiraj/CatPPT +register_conv_template( + Conversation( + name="catppt", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# TinyLlama template +# reference: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0 +register_conv_template( + Conversation( + name="TinyLlama", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + # Orca-2 template # reference: https://huggingface.co/microsoft/Orca-2-7b register_conv_template( @@ -1265,6 +1497,89 @@ def get_conv_template(name: str) -> Conversation: ) ) +# Deepseek-chat template +# reference: https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="deepseek-chat", + system_message="<|begin▁of▁sentence|>", # must add a bos token before first message + roles=("User", "Assistant"), + sep_style=SeparatorStyle.DEEPSEEK_CHAT, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_str="<|end▁of▁sentence|>", + ) +) + +# Yuan2.0 chat template +# source: https://huggingface.co/IEITYuan/Yuan2-2B-Janus-hf/blob/main/tokenizer_config.json#L6 +register_conv_template( + Conversation( + name="yuan2", + roles=("user", "assistant"), + sep_style=SeparatorStyle.YUAN2, + sep="", + sep2="\n", + stop_token_ids=[ + 77185, + ], # "" + stop_str="", + ) +) + +# Solar-10.7B Chat Template +# Reference: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="solar", + system_message="", + roles=("### User", "### Assistant"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_str="", + ) +) + +# nvidia/Llama2-70B-SteerLM-Chat +register_conv_template( + Conversation( + name="steerlm", + system_message="", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + +# yuan 2.0 template +# reference:https://github.com/IEIT-Yuan/Yuan-2.0 +# reference:https://huggingface.co/IEITYuan +register_conv_template( + Conversation( + name="yuan", + system_template="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="", + ) +) + +# Llava-chatml +# reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361 +register_conv_template( + Conversation( + name="llava-chatml", + system_template="<|im_start|>system\n{system_message}", + system_message="Answer the questions.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str="<|im_end|>", + ) +) + + if __name__ == "__main__": from fastchat.conversation import get_conv_template diff --git a/fastchat/llm_judge/README.md b/fastchat/llm_judge/README.md index 1d2646b13..6737cf8ba 100644 --- a/fastchat/llm_judge/README.md +++ b/fastchat/llm_judge/README.md @@ -59,7 +59,7 @@ You can also specify `--num-gpus-per-model` for model parallelism (needed for la #### Step 2. Generate GPT-4 judgments There are several options to use GPT-4 as a judge, such as pairwise winrate and single-answer grading. -In MT-bench, we recommond single-answer grading as the default mode. +In MT-bench, we recommend single-answer grading as the default mode. This mode asks GPT-4 to grade and give a score to model's answer directly without pairwise comparison. For each turn, GPT-4 will give a score on a scale of 10. We then compute the average score on all turns. @@ -129,6 +129,27 @@ You can use this [colab notebook](https://colab.research.google.com/drive/15O3Y8 +### Other backends +We can also use vLLM for answer generation, which can be faster for the models supported by vLLM. + +1. Launch a vLLM worker +``` +python3 -m fastchat.serve.controller +python3 -m fastchat.serve.vllm_worker --model-path [MODEL-PATH] +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` + - Arguments: + - `[MODEL-PATH]` is the path to the weights, which can be a local folder or a Hugging Face repo ID. + +2. Generate the answers +``` +python gen_api_answer.py --model [MODEL-NAME] --openai-api-base http://localhost:8000/v1 --parallel 50 +``` + - Arguments: + - `[MODEL-NAME]` is the name of the model from Step 1. + - `--parallel` is the number of concurrent API calls to the vLLM worker. + + ## Agreement Computation We released 3.3K human annotations for model responses generated by 6 models in response to 80 MT-bench questions. The dataset is available at [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments). @@ -138,6 +159,7 @@ This Colab [notebook](https://colab.research.google.com/drive/1ctgygDRJhVGUJTQy8 - [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations) - [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments) + ## Citation Please cite the following paper if you find the code or datasets helpful. ``` diff --git a/fastchat/llm_judge/common.py b/fastchat/llm_judge/common.py index 4b598cefb..d2640d601 100644 --- a/fastchat/llm_judge/common.py +++ b/fastchat/llm_judge/common.py @@ -14,7 +14,11 @@ import openai import anthropic -from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST +from fastchat.model.model_adapter import ( + get_conversation_template, + ANTHROPIC_MODEL_LIST, + OPENAI_MODEL_LIST, +) # API setting constants API_MAX_RETRY = 16 @@ -159,10 +163,10 @@ def run_judge_single(question, answer, judge, ref_answer, multi_turn=False): conv.append_message(conv.roles[0], user_prompt) conv.append_message(conv.roles[1], None) - if model in ["gpt-3.5-turbo", "gpt-4"]: - judgment = chat_compeletion_openai(model, conv, temperature=0, max_tokens=2048) + if model in OPENAI_MODEL_LIST: + judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) elif model in ANTHROPIC_MODEL_LIST: - judgment = chat_compeletion_anthropic( + judgment = chat_completion_anthropic( model, conv, temperature=0, max_tokens=1024 ) else: @@ -185,7 +189,7 @@ def run_judge_single(question, answer, judge, ref_answer, multi_turn=False): return rating, user_prompt, judgment -def play_a_match_single(match: MatchPair, output_file: str): +def play_a_match_single(match: MatchSingle, output_file: str): question, model, answer, judge, ref_answer, multi_turn = ( match.question, match.model, @@ -262,14 +266,14 @@ def run_judge_pair(question, answer_a, answer_b, judge, ref_answer, multi_turn=F conv.append_message(conv.roles[0], user_prompt) conv.append_message(conv.roles[1], None) - if model in ["gpt-3.5-turbo", "gpt-4"]: + if model in OPENAI_MODEL_LIST: conv.set_system_message(system_prompt) - judgment = chat_compeletion_openai(model, conv, temperature=0, max_tokens=2048) + judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) elif model in ANTHROPIC_MODEL_LIST: if system_prompt != "You are a helpful assistant.": user_prompt = "[Instruction]\n" + system_prompt + "\n\n" + user_prompt conv.messages[0][1] = user_prompt - judgment = chat_compeletion_anthropic( + judgment = chat_completion_anthropic( model, conv, temperature=0, max_tokens=1024 ) else: @@ -400,7 +404,7 @@ def play_a_match_pair(match: MatchPair, output_file: str): return result -def chat_compeletion_openai(model, conv, temperature, max_tokens, api_dict=None): +def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None): if api_dict is not None: openai.api_base = api_dict["api_base"] openai.api_key = api_dict["api_key"] @@ -424,7 +428,7 @@ def chat_compeletion_openai(model, conv, temperature, max_tokens, api_dict=None) return output -def chat_compeletion_openai_azure(model, conv, temperature, max_tokens, api_dict=None): +def chat_completion_openai_azure(model, conv, temperature, max_tokens, api_dict=None): openai.api_type = "azure" openai.api_version = "2023-07-01-preview" if api_dict is not None: @@ -463,11 +467,16 @@ def chat_compeletion_openai_azure(model, conv, temperature, max_tokens, api_dict return output -def chat_compeletion_anthropic(model, conv, temperature, max_tokens): +def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=None): + if api_dict is not None and "api_key" in api_dict: + api_key = api_dict["api_key"] + else: + api_key = os.environ["ANTHROPIC_API_KEY"] + output = API_ERROR_OUTPUT for _ in range(API_MAX_RETRY): try: - c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) + c = anthropic.Anthropic(api_key=api_key) prompt = conv.get_prompt() response = c.completions.create( model=model, @@ -484,7 +493,7 @@ def chat_compeletion_anthropic(model, conv, temperature, max_tokens): return output.strip() -def chat_compeletion_palm(chat_state, model, conv, temperature, max_tokens): +def chat_completion_palm(chat_state, model, conv, temperature, max_tokens): from fastchat.serve.api_provider import init_palm_chat assert model == "palm-2-chat-bison-001" diff --git a/fastchat/llm_judge/gen_api_answer.py b/fastchat/llm_judge/gen_api_answer.py index b39618546..8f9c62624 100644 --- a/fastchat/llm_judge/gen_api_answer.py +++ b/fastchat/llm_judge/gen_api_answer.py @@ -1,7 +1,7 @@ """Generate answers with GPT-4 Usage: -python3 get_api_answer.py --model gpt-3.5-turbo +python3 gen_api_answer.py --model gpt-3.5-turbo """ import argparse import json @@ -16,9 +16,9 @@ from fastchat.llm_judge.common import ( load_questions, temperature_config, - chat_compeletion_openai, - chat_compeletion_anthropic, - chat_compeletion_palm, + chat_completion_openai, + chat_completion_anthropic, + chat_completion_palm, ) from fastchat.llm_judge.gen_model_answer import reorg_answer_file from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST @@ -50,15 +50,13 @@ def get_answer( conv.append_message(conv.roles[1], None) if model in ANTHROPIC_MODEL_LIST: - output = chat_compeletion_anthropic( - model, conv, temperature, max_tokens - ) + output = chat_completion_anthropic(model, conv, temperature, max_tokens) elif model == "palm-2-chat-bison-001": - chat_state, output = chat_compeletion_palm( + chat_state, output = chat_completion_palm( chat_state, model, conv, temperature, max_tokens ) else: - output = chat_compeletion_openai(model, conv, temperature, max_tokens) + output = chat_completion_openai(model, conv, temperature, max_tokens) conv.update_last_message(output) turns.append(output) diff --git a/fastchat/llm_judge/qa_browser.py b/fastchat/llm_judge/qa_browser.py index e449dee3a..1107756db 100644 --- a/fastchat/llm_judge/qa_browser.py +++ b/fastchat/llm_judge/qa_browser.py @@ -36,7 +36,7 @@ def display_question(category_selector, request: gr.Request): choices = category_selector_map[category_selector] - return gr.Dropdown.update( + return gr.Dropdown( value=choices[0], choices=choices, ) @@ -413,6 +413,8 @@ def build_demo(): ) = load_pairwise_model_judgments(pairwise_model_judgment_file) demo = build_demo() - demo.queue(concurrency_count=10, status_update_rate=10, api_open=False).launch( + demo.queue( + default_concurrency_limit=10, status_update_rate=10, api_open=False + ).launch( server_name=args.host, server_port=args.port, share=args.share, max_threads=200 ) diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index e130de1cb..3b3b3c48b 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -12,7 +12,6 @@ else: from functools import lru_cache as cache -import accelerate import psutil import torch from transformers import ( @@ -33,6 +32,7 @@ from fastchat.model.model_chatglm import generate_stream_chatglm from fastchat.model.model_codet5p import generate_stream_codet5p from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_yuan2 import generate_stream_yuan2 from fastchat.model.model_exllama import generate_stream_exllama from fastchat.model.model_xfastertransformer import generate_stream_xft from fastchat.model.monkey_patch_non_inplace import ( @@ -53,7 +53,24 @@ ANTHROPIC_MODEL_LIST = ( "claude-1", "claude-2", + "claude-2.0", + "claude-2.1", "claude-instant-1", + "claude-instant-1.2", +) + +OPENAI_MODEL_LIST = ( + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-turbo", + "gpt-4-1106-preview", + "gpt-4-0125-preview", ) @@ -177,6 +194,8 @@ def load_model( debug: bool = False, ): """Load a model from Hugging Face.""" + import accelerate + # get model adapter adapter = get_model_adapter(model_path) @@ -317,6 +336,20 @@ def load_model( if dtype is not None: # Overwrite dtype if it is provided in the arguments. kwargs["torch_dtype"] = dtype + if os.environ.get("FASTCHAT_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + try: + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model_path): + model_path = snapshot_download(model_id=model_path, revision=revision) + except ImportError as e: + warnings.warn( + "Use model from www.modelscope.cn need pip install modelscope" + ) + raise e + # Load model model, tokenizer = adapter.load_model(model_path, kwargs) @@ -354,12 +387,13 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str): from fastchat.serve.inference import generate_stream model_type = str(type(model)).lower() + is_peft = "peft" in model_type is_chatglm = "chatglm" in model_type is_falcon = "rwforcausallm" in model_type is_codet5p = "codet5p" in model_type - is_peft = "peft" in model_type is_exllama = "exllama" in model_type is_xft = "xft" in model_type + is_yuan = "yuan" in model_type if is_chatglm: return generate_stream_chatglm @@ -371,6 +405,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str): return generate_stream_exllama elif is_xft: return generate_stream_xft + elif is_yuan: + return generate_stream_yuan2 elif peft_share_base_weights and is_peft: # Return a curried stream function that loads the right adapter @@ -387,7 +423,28 @@ def generate_stream_peft( judge_sent_end: bool = False, ): model.set_adapter(model_path) - for x in generate_stream( + base_model_type = str(type(model.base_model.model)) + is_chatglm = "chatglm" in base_model_type + is_falcon = "rwforcausallm" in base_model_type + is_codet5p = "codet5p" in base_model_type + is_exllama = "exllama" in base_model_type + is_xft = "xft" in base_model_type + is_yuan = "yuan" in base_model_type + + generate_stream_function = generate_stream + if is_chatglm: + generate_stream_function = generate_stream_chatglm + elif is_falcon: + generate_stream_function = generate_stream_falcon + elif is_codet5p: + generate_stream_function = generate_stream_codet5p + elif is_exllama: + generate_stream_function = generate_stream_exllama + elif is_xft: + generate_stream_function = generate_stream_xft + elif is_yuan: + generate_stream_function = generate_stream_yuan2 + for x in generate_stream_function( model, tokenizer, params, @@ -903,6 +960,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("openchat_3.5") +class TenyxChatAdapter(BaseModelAdapter): + """The model adapter for TenyxChat (e.g. tenyx/TenyxChat-7B-v1)""" + + def match(self, model_path: str): + return "tenyxchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tenyxchat") + + class PythiaAdapter(BaseModelAdapter): """The model adapter for any EleutherAI/pythia model""" @@ -1040,12 +1107,7 @@ class ChatGPTAdapter(BaseModelAdapter): """The model adapter for ChatGPT""" def match(self, model_path: str): - return model_path in ( - "gpt-3.5-turbo", - "gpt-3.5-turbo-1106", - "gpt-4", - "gpt-4-turbo", - ) + return model_path in OPENAI_MODEL_LIST def load_model(self, model_path: str, from_pretrained_kwargs: dict): raise NotImplementedError() @@ -1067,6 +1129,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("chatgpt") +class PplxAIAdapter(BaseModelAdapter): + """The model adapter for Perplexity AI""" + + def match(self, model_path: str): + return model_path in ( + "pplx-7b-online", + "pplx-70b-online", + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("pplxai") + + class ClaudeAdapter(BaseModelAdapter): """The model adapter for Claude""" @@ -1106,6 +1184,19 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("bard") +class GeminiAdapter(BaseModelAdapter): + """The model adapter for Gemini""" + + def match(self, model_path: str): + return "gemini" in model_path.lower() or "bard" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemini") + + class BiLLaAdapter(BaseModelAdapter): """The model adapter for Neutralzz/BiLLa-7B-SFT""" @@ -1424,7 +1515,7 @@ class MistralAdapter(BaseModelAdapter): """The model adapter for Mistral AI models""" def match(self, model_path: str): - return "mistral" in model_path.lower() + return "mistral" in model_path.lower() or "mixtral" in model_path.lower() def load_model(self, model_path: str, from_pretrained_kwargs: dict): model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) @@ -1508,6 +1599,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("open-orca") +class DolphinAdapter(OpenOrcaAdapter): + """Model adapter for ehartford/dolphin-2.2.1-mistral-7b""" + + def match(self, model_path: str): + return "dolphin" in model_path.lower() and "mistral" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolphin-2.2.1-mistral-7b") + + class Hermes2Adapter(BaseModelAdapter): """Model adapter for teknium/OpenHermes-2.5-Mistral-7B and teknium/OpenHermes-2-Mistral-7B models""" @@ -1535,6 +1636,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("OpenHermes-2.5-Mistral-7B") +class NousHermes2MixtralAdapter(BaseModelAdapter): + """Model adapter for NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO model""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in [ + "nous-hermes-2-mixtral-8x7b-dpo", + "nous-hermes-2-mixtral-8x7b-sft", + ] + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Nous-Hermes-2-Mixtral-8x7B-DPO") + + class WizardCoderAdapter(BaseModelAdapter): """The model adapter for WizardCoder (e.g., WizardLM/WizardCoder-Python-34B-V1.0)""" @@ -1646,6 +1763,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict): model.config.max_sequence_length = min( model.config.max_position_embeddings, tokenizer.model_max_length ) + model.use_cls_pooling = True + model.eval() return model, tokenizer def get_default_conv_template(self, model_path: str) -> Conversation: @@ -1768,7 +1887,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict): return model, tokenizer def get_default_conv_template(self, model_path: str) -> Conversation: - return get_conv_template("llama2-chinese") + return get_conv_template("chinese-alpaca2") class VigogneAdapter(BaseModelAdapter): @@ -1895,6 +2014,36 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("zephyr") +class NotusAdapter(BaseModelAdapter): + """The model adapter for Notus (e.g. argilla/notus-7b-v1)""" + + def match(self, model_path: str): + return "notus" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class CatPPTAdapter(BaseModelAdapter): + """The model adapter for CatPPT (e.g. rishiraj/CatPPT)""" + + def match(self, model_path: str): + return "catppt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("catppt") + + +class TinyLlamaAdapter(BaseModelAdapter): + """The model adapter for TinyLlama (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0)""" + + def match(self, model_path: str): + return "tinyllama" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("TinyLlama") + + class XwinLMAdapter(BaseModelAdapter): """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)""" @@ -1933,6 +2082,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("metharme") +class XdanAdapter(BaseModelAdapter): + """The model adapter for xDAN-AI (e.g. xDAN-AI/xDAN-L1-Chat-RL-v1)""" + + def match(self, model_path: str): + return "xdan" in model_path.lower() and "v1" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xdan-v1") + + class MicrosoftOrcaAdapter(BaseModelAdapter): """The model adapter for Microsoft/Orca-2 series of models (e.g. Microsoft/Orca-2-7b, Microsoft/Orca-2-13b)""" @@ -1955,6 +2114,171 @@ def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("Yi-34b-chat") +class DeepseekCoderAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's coder models""" + + def match(self, model_path: str): + return "deepseek-coder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-coder") + + +class DeepseekChatAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's chat models""" + + # Note: that this model will require tokenizer version >= 0.13.3 because the tokenizer class is LlamaTokenizerFast + + def match(self, model_path: str): + return "deepseek-llm" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-chat") + + +class Yuan2Adapter(BaseModelAdapter): + """The model adapter for Yuan2.0""" + + def match(self, model_path: str): + return "yuan2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + # from_pretrained_kwargs["torch_dtype"] = torch.bfloat16 + tokenizer = LlamaTokenizer.from_pretrained( + model_path, + add_eos_token=False, + add_bos_token=False, + eos_token="", + eod_token="", + sep_token="", + revision=revision, + ) + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + # device_map='auto', + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan2") + + +class MetaMathAdapter(BaseModelAdapter): + """The model adapter for MetaMath models""" + + def match(self, model_path: str): + return "metamath" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metamath") + + +class BagelAdapter(BaseModelAdapter): + """Model adapter for jondurbin/bagel-* models""" + + def match(self, model_path: str): + return "bagel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("airoboros_v3") + + +class SolarAdapter(BaseModelAdapter): + """The model adapter for upstage/SOLAR-10.7B-Instruct-v1.0""" + + def match(self, model_path: str): + return "solar-" in model_path.lower() and "instruct" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("solar") + + +class SteerLMAdapter(BaseModelAdapter): + """The model adapter for nvidia/Llama2-70B-SteerLM-Chat""" + + def match(self, model_path: str): + return "steerlm-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("steerlm") + + +class LlavaAdapter(BaseModelAdapter): + """The model adapter for liuhaotian/llava-v1.5 series of models""" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + # TODO(chris): Implement huggingface-compatible load_model + pass + + def match(self, model_path: str): + return "llava" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "34b" in model_path: + return get_conv_template("llava-chatml") + + return get_conv_template("vicuna_v1.1") + + +class YuanAdapter(BaseModelAdapter): + """The model adapter for Yuan""" + + def match(self, model_path: str): + return "yuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan") + + # Note: the registration order matters. # The one registered earlier has a higher matching priority. register_model_adapter(PeftModelAdapter) @@ -1971,6 +2295,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation: register_model_adapter(OasstPythiaAdapter) register_model_adapter(OasstLLaMAAdapter) register_model_adapter(OpenChat35Adapter) +register_model_adapter(TenyxChatAdapter) register_model_adapter(StableLMAdapter) register_model_adapter(BaizeAdapter) register_model_adapter(RwkvAdapter) @@ -1978,6 +2303,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation: register_model_adapter(PhoenixAdapter) register_model_adapter(BardAdapter) register_model_adapter(PaLM2Adapter) +register_model_adapter(GeminiAdapter) register_model_adapter(ChatGPTAdapter) register_model_adapter(AzureOpenAIAdapter) register_model_adapter(ClaudeAdapter) @@ -1998,14 +2324,16 @@ def get_default_conv_template(self, model_path: str) -> Conversation: register_model_adapter(TigerBotAdapter) register_model_adapter(BaichuanAdapter) register_model_adapter(XGenAdapter) -register_model_adapter(NousHermesAdapter) register_model_adapter(PythiaAdapter) register_model_adapter(InternLMChatAdapter) register_model_adapter(StarChatAdapter) register_model_adapter(Llama2Adapter) register_model_adapter(CuteGPTAdapter) register_model_adapter(OpenOrcaAdapter) +register_model_adapter(DolphinAdapter) register_model_adapter(Hermes2Adapter) +register_model_adapter(NousHermes2MixtralAdapter) +register_model_adapter(NousHermesAdapter) register_model_adapter(MistralAdapter) register_model_adapter(WizardCoderAdapter) register_model_adapter(QwenChatAdapter) @@ -2021,11 +2349,25 @@ def get_default_conv_template(self, model_path: str) -> Conversation: register_model_adapter(CodeLlamaAdapter) register_model_adapter(Llama2ChangAdapter) register_model_adapter(ZephyrAdapter) +register_model_adapter(NotusAdapter) +register_model_adapter(CatPPTAdapter) +register_model_adapter(TinyLlamaAdapter) register_model_adapter(XwinLMAdapter) register_model_adapter(LemurAdapter) register_model_adapter(PygmalionAdapter) register_model_adapter(MicrosoftOrcaAdapter) +register_model_adapter(XdanAdapter) register_model_adapter(YiAdapter) +register_model_adapter(PplxAIAdapter) +register_model_adapter(DeepseekCoderAdapter) +register_model_adapter(DeepseekChatAdapter) +register_model_adapter(Yuan2Adapter) +register_model_adapter(MetaMathAdapter) +register_model_adapter(BagelAdapter) +register_model_adapter(SolarAdapter) +register_model_adapter(SteerLMAdapter) +register_model_adapter(LlavaAdapter) +register_model_adapter(YuanAdapter) # After all adapters, try the default base adapter. register_model_adapter(BaseModelAdapter) diff --git a/fastchat/model/model_chatglm.py b/fastchat/model/model_chatglm.py index 5d4db62bc..2cbac8bc5 100644 --- a/fastchat/model/model_chatglm.py +++ b/fastchat/model/model_chatglm.py @@ -37,6 +37,31 @@ def process_response(response): return response +def recover_message_list(prompt): + role_token_pattern = "|".join( + [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]] + ) + role = None + last_end_idx = -1 + message_list = [] + for match in re.finditer(role_token_pattern, prompt): + if role: + messge = {} + if role == "<|system|>": + messge["role"] = "system" + elif role == "<|user|>": + messge["role"] = "user" + else: + messge["role"] = "assistant" + messge["content"] = prompt[last_end_idx + 1 : match.start()] + message_list.append(messge) + + role = prompt[match.start() : match.end()] + last_end_idx = match.end() + + return message_list + + @torch.inference_mode() def generate_stream_chatglm( model, @@ -54,7 +79,17 @@ def generate_stream_chatglm( max_new_tokens = int(params.get("max_new_tokens", 256)) echo = params.get("echo", True) - inputs = tokenizer([prompt], return_tensors="pt").to(model.device) + model_type = str(type(model)).lower() + if "peft" in model_type: + model_type = str(type(model.base_model.model)).lower() + + if "chatglm3" in model_type: + message_list = recover_message_list(prompt) + inputs = tokenizer.build_chat_input( + query=message_list[-1]["content"], history=message_list[:-1], role="user" + ).to(model.device) + else: + inputs = tokenizer([prompt], return_tensors="pt").to(model.device) input_echo_len = len(inputs["input_ids"][0]) gen_kwargs = { diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py index 40aee1b4c..433449cdb 100644 --- a/fastchat/model/model_registry.py +++ b/fastchat/model/model_registry.py @@ -1,12 +1,12 @@ """Additional information of the models.""" -from collections import namedtuple +from collections import namedtuple, OrderedDict from typing import List ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) -model_info = {} +model_info = OrderedDict() def register_model_info( @@ -29,159 +29,356 @@ def get_model_info(name: str) -> ModelInfo: register_model_info( - ["gpt-3.5-turbo"], - "GPT-3.5", - "https://openai.com/blog/chatgpt", - "GPT-3.5 by OpenAI", + [ + "IEITYuan/Yuan2-2B-Janus-hf", + "IEITYuan/Yuan2-2B-hf", + "IEITYuan/Yuan2-51B-hf", + "IEITYuan/Yuan2-102B-hf", + ], + "IEIT-Yuan2", + "https://github.com/IEIT-Yuan/Yuan-2.0", + "Yuan2.0 is a new generation Fundamental Large Language Model developed by IEIT System.", ) + register_model_info( - ["gpt-3.5-turbo-1106"], - "GPT-3.5-Turbo-1106", - "https://platform.openai.com/docs/models/gpt-3-5", - "GPT-3.5-Turbo-1106 by OpenAI", + [ + "mixtral-8x7b-instruct-v0.1", + "mistral-medium", + "mistral-7b-instruct-v0.2", + "mistral-7b-instruct", + ], + "Mixtral of experts", + "https://mistral.ai/news/mixtral-of-experts/", + "A Mixture-of-Experts model by Mistral AI", +) + +register_model_info( + [ + "qwen1.5-72b-chat", + "qwen1.5-14b-chat", + "qwen1.5-7b-chat", + "qwen1.5-4b-chat", + "qwen1.5-1.8b-chat", + "qwen1.5-0.5b-chat", + "qwen-14b-chat", + ], + "Qwen 1.5", + "https://qwenlm.github.io/blog/qwen1.5/", + "A large language model by Alibaba Cloud", +) + +register_model_info( + ["qwen-14b-chat"], + "Qwen", + "https://huggingface.co/Qwen", + "A large language model by Alibaba Cloud", +) + +register_model_info( + ["bard-feb-2024", "bard-jan-24-gemini-pro"], + "Bard", + "https://bard.google.com/", + "Bard by Google", ) + register_model_info( - ["gpt-4"], "GPT-4", "https://openai.com/research/gpt-4", "ChatGPT-4 by OpenAI" + ["gemini-pro", "gemini-pro-dev-api"], + "Gemini", + "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", + "Gemini by Google", ) + register_model_info( - ["gpt-4-turbo"], + ["deepseek-llm-67b-chat"], + "DeepSeek LLM", + "https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat", + "An advanced language model by DeepSeek", +) + +register_model_info( + ["stripedhyena-nous-7b"], + "StripedHyena-Nous", + "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B", + "A chat model developed by Together Research and Nous Research.", +) + +register_model_info( + ["solar-10.7b-instruct-v1.0"], + "SOLAR-10.7B-Instruct", + "https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0", + "A model trained using depth up-scaling by Upstage AI", +) + +register_model_info( + ["gpt-4-turbo", "gpt-4-1106-preview", "gpt-4-0125-preview"], "GPT-4-Turbo", "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", "GPT-4-Turbo by OpenAI", ) + register_model_info( - ["claude-2"], + [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0314", + "gpt-3.5-turbo-0613", + ], + "GPT-3.5", + "https://platform.openai.com/docs/models/gpt-3-5", + "GPT-3.5-Turbo by OpenAI", +) + +register_model_info( + ["gpt-4", "gpt-4-0314", "gpt-4-0613"], + "GPT-4", + "https://openai.com/research/gpt-4", + "GPT-4 by OpenAI", +) + +register_model_info( + ["claude-2.1", "claude-2.0"], "Claude", "https://www.anthropic.com/index/claude-2", "Claude 2 by Anthropic", ) + register_model_info( ["claude-1"], "Claude", "https://www.anthropic.com/index/introducing-claude", - "Claude by Anthropic", + "Claude 1 by Anthropic", ) + register_model_info( - ["claude-instant-1"], + ["claude-instant-1", "claude-instant-1.2"], "Claude Instant", "https://www.anthropic.com/index/introducing-claude", "Claude Instant by Anthropic", ) + register_model_info( - ["palm-2"], - "PaLM 2 Chat", - "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", - "PaLM 2 for Chat (chat-bison@001) by Google", + ["nous-hermes-2-mixtral-8x7b-dpo"], + "Nous-Hermes-2-Mixtral-8x7B-DPO", + "https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", + "Nous Hermes finetuned from Mixtral 8x7B", +) + +register_model_info( + ["openchat-3.5-0106", "openchat-3.5"], + "OpenChat 3.5", + "https://github.com/imoneoi/openchat", + "An open model fine-tuned on Mistral-7B using C-RLFT", +) + +register_model_info( + ["deepseek-llm-67b-chat"], + "DeepSeek LLM", + "https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat", + "An advanced language model by DeepSeek", +) + +register_model_info( + ["stripedhyena-nous-7b"], + "StripedHyena-Nous", + "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B", + "A chat model developed by Together Research and Nous Research.", +) + +register_model_info( + ["llama2-70b-steerlm-chat"], + "Llama2-70B-SteerLM-Chat", + "https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat", + "A Llama fine-tuned with SteerLM method by NVIDIA", +) + +register_model_info( + ["pplx-70b-online", "pplx-7b-online"], + "pplx-online-llms", + "https://blog.perplexity.ai/blog/introducing-pplx-online-llms", + "Online LLM API by Perplexity AI", +) + +register_model_info( + ["openhermes-2.5-mistral-7b"], + "OpenHermes-2.5-Mistral-7B", + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B", + "A mistral-based model fine-tuned on 1M GPT-4 outputs", +) + +register_model_info( + ["starling-lm-7b-alpha"], + "Starling-LM-7B-alpha", + "https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha", + "An open model trained using RLAIF by Berkeley", +) + +register_model_info( + ["tulu-2-dpo-70b"], + "Tulu 2", + "https://huggingface.co/allenai/tulu-2-dpo-70b", + "An instruction and RLHF model by UW/AllenAI", +) + +register_model_info( + ["yi-34b-chat", "yi-6b-chat"], + "Yi-Chat", + "https://huggingface.co/01-ai/Yi-34B-Chat", + "A large language model by 01 AI", ) + +register_model_info( + ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"], + "Llama 2", + "https://ai.meta.com/llama/", + "Open foundation and fine-tuned chat models by Meta", +) + register_model_info( [ "vicuna-33b", "vicuna-33b-v1.3", "vicuna-13b", - "vicuna-13b-v1.3", + "vicuna-13b-v1.5", "vicuna-7b", - "vicuna-7b-v1.3", + "vicuna-7b-v1.5", ], "Vicuna", "https://lmsys.org/blog/2023-03-30-vicuna/", - "a chat assistant fine-tuned on user-shared conversations by LMSYS", + "A chat assistant fine-tuned on user-shared conversations by LMSYS", ) + register_model_info( - ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"], - "Llama 2", - "https://ai.meta.com/llama/", - "open foundation and fine-tuned chat models by Meta", + ["chatglm3-6b", "chatglm2-6b", "chatglm-6b"], + "ChatGLM", + "https://chatglm.cn/blog", + "An open bilingual dialogue language model by Tsinghua University", ) + register_model_info( - ["mistral-7b-instruct"], - "Mistral", - "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1", - "a large language model by Mistral AI team", + ["tenyxchat-7b-v1"], + "TenyxChat-7B", + "https://huggingface.co/tenyx/TenyxChat-7B-v1", + "An open model DPO trained on top of OpenChat-3.5 using Tenyx fine-tuning", ) + register_model_info( ["zephyr-7b-beta", "zephyr-7b-alpha"], "Zephyr", "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha", - "a chatbot fine-tuned from Mistral by Hugging Face", + "A chatbot fine-tuned from Mistral by Hugging Face", ) + register_model_info( - ["qwen-14b-chat"], - "Qwen", - "https://huggingface.co/Qwen/Qwen-14B-Chat", - "a large language model by Alibaba Cloud", + ["notus-7b-v1"], + "Notus", + "https://huggingface.co/argilla/notus-7b-v1", + "A chatbot fine-tuned from Zephyr SFT by Argilla", ) + +register_model_info( + ["catppt"], + "CatPPT", + "https://huggingface.co/rishiraj/CatPPT", + "A chatbot fine-tuned from a SLERP merged model by Rishiraj Acharya", +) + register_model_info( - ["codellama-34b-instruct", "codellama-13b-instruct", "codellama-7b-instruct"], + ["TinyLlama"], + "TinyLlama", + "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.", +) + +register_model_info( + [ + "codellama-70b-instruct", + "codellama-34b-instruct", + "codellama-13b-instruct", + "codellama-7b-instruct", + ], "Code Llama", "https://ai.meta.com/blog/code-llama-large-language-model-coding/", - "open foundation models for code by Meta", + "Open foundation models for code by Meta", ) + register_model_info( ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"], "WizardLM", "https://github.com/nlpxucan/WizardLM", - "an instruction-following LLM using evol-instruct by Microsoft", + "An instruction-following LLM using evol-instruct by Microsoft", ) + register_model_info( ["wizardcoder-15b-v1.0"], "WizardLM", "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder", "Empowering Code Large Language Models with Evol-Instruct", ) + register_model_info( ["mpt-7b-chat", "mpt-30b-chat"], "MPT-Chat", "https://www.mosaicml.com/blog/mpt-30b", - "a chatbot fine-tuned from MPT by MosaicML", + "A chatbot fine-tuned from MPT by MosaicML", ) + register_model_info( ["guanaco-33b", "guanaco-65b"], "Guanaco", "https://github.com/artidoro/qlora", - "a model fine-tuned with QLoRA by UW", + "A model fine-tuned with QLoRA by UW", ) + register_model_info( ["gpt4all-13b-snoozy"], "GPT4All-Snoozy", "https://github.com/nomic-ai/gpt4all", - "a finetuned LLaMA model on assistant style data by Nomic AI", + "A finetuned LLaMA model on assistant style data by Nomic AI", ) + register_model_info( ["koala-13b"], "Koala", "https://bair.berkeley.edu/blog/2023/04/03/koala", - "a dialogue model for academic research by BAIR", + "A dialogue model for academic research by BAIR", ) + register_model_info( ["RWKV-4-Raven-14B"], "RWKV-4-Raven", "https://huggingface.co/BlinkDL/rwkv-4-raven", - "an RNN with transformer-level LLM performance", -) -register_model_info( - ["chatglm-6b", "chatglm2-6b"], - "ChatGLM", - "https://chatglm.cn/blog", - "an open bilingual dialogue language model by Tsinghua University", + "An RNN with transformer-level LLM performance", ) + register_model_info( ["alpaca-13b"], "Alpaca", "https://crfm.stanford.edu/2023/03/13/alpaca.html", - "a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", + "A model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", ) + register_model_info( ["oasst-pythia-12b"], "OpenAssistant (oasst)", "https://open-assistant.io", - "an Open Assistant for everyone by LAION", + "An Open Assistant for everyone by LAION", ) + register_model_info( ["oasst-sft-7-llama-30b"], "OpenAssistant (oasst)", "https://open-assistant.io", - "an Open Assistant for everyone by LAION", + "An Open Assistant for everyone by LAION", ) + +register_model_info( + ["palm-2"], + "PaLM 2 Chat", + "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", + "PaLM 2 for Chat (chat-bison@001) by Google", +) + register_model_info( ["openchat-3.5"], "OpenChat 3.5", @@ -198,68 +395,79 @@ def get_model_info(name: str) -> ModelInfo: ["llama-7b", "llama-13b"], "LLaMA", "https://arxiv.org/abs/2302.13971", - "open and efficient foundation language models by Meta", + "Open and efficient foundation language models by Meta", ) + register_model_info( ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"], "Open LLaMa (Open Instruct)", "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc", "Open LLaMa fine-tuned on instruction-following data by VMware", ) + register_model_info( ["dolly-v2-12b"], "Dolly", "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm", - "an instruction-tuned open large language model by Databricks", + "An instruction-tuned open large language model by Databricks", ) + register_model_info( ["stablelm-tuned-alpha-7b"], "StableLM", "https://github.com/stability-AI/stableLM", "Stability AI language models", ) + register_model_info( ["codet5p-6b"], "CodeT5p-6b", "https://huggingface.co/Salesforce/codet5p-6b", "Code completion model released by Salesforce", ) + register_model_info( ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"], "FastChat-T5", "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0", - "a chat assistant fine-tuned from FLAN-T5 by LMSYS", + "A chat assistant fine-tuned from FLAN-T5 by LMSYS", ) + register_model_info( ["phoenix-inst-chat-7b"], "Phoenix-7B", "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b", - "a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", + "A multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", ) + register_model_info( ["realm-7b-v1"], "ReaLM", "https://github.com/FreedomIntelligence/ReaLM", "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.", ) + register_model_info( ["billa-7b-sft"], "BiLLa-7B-SFT", "https://huggingface.co/Neutralzz/BiLLa-7B-SFT", - "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", + "An instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", ) + register_model_info( ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"], "h2oGPT-GM-7b", "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", - "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", + "An instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", ) + register_model_info( ["baize-v2-7b", "baize-v2-13b"], "Baize v2", "https://github.com/project-baize/baize-chatbot#v2", "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.", ) + register_model_info( [ "airoboros-l2-7b-2.1", @@ -269,8 +477,20 @@ def get_model_info(name: str) -> ModelInfo: ], "airoboros", "https://huggingface.co/jondurbin/airoboros-l2-70b-2.1", - "an instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4", + "An instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4", ) + +register_model_info( + [ + "spicyboros-7b-2.2", + "spicyboros-13b-2.2", + "spicyboros-70b-2.2", + ], + "spicyboros", + "https://huggingface.co/jondurbin/spicyboros-70b-2.2", + "De-aligned versions of the airoboros models", +) + register_model_info( [ "spicyboros-7b-2.2", @@ -287,18 +507,21 @@ def get_model_info(name: str) -> ModelInfo: "https://huggingface.co/OptimalScale/robin-7b-v2-delta", "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.", ) + register_model_info( ["manticore-13b-chat"], "Manticore 13B Chat", "https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg", "A chatbot fine-tuned from LlaMa across several CoT and chat datasets.", ) + register_model_info( ["redpajama-incite-7b-chat"], "RedPajama-INCITE-7B-Chat", "https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat", "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together", ) + register_model_info( [ "falcon-7b", @@ -312,30 +535,42 @@ def get_model_info(name: str) -> ModelInfo: "https://huggingface.co/tiiuae/falcon-180B", "TII's flagship series of large language models", ) + register_model_info( ["tigerbot-7b-sft"], "Tigerbot", "https://huggingface.co/TigerResearch/tigerbot-7b-sft", - "TigerBot is a large-scale language model (LLM) with multiple languages and tasks.", + "A large-scale language model (LLM) with multiple languages and tasks.", ) + register_model_info( ["internlm-chat-7b", "internlm-chat-7b-8k"], "InternLM", "https://huggingface.co/internlm/internlm-chat-7b", - "InternLM is a multi-language large-scale language model (LLM), developed by SHLAB.", + "A multi-language large-scale language model (LLM), developed by SHLAB.", ) + register_model_info( ["Qwen-7B-Chat"], "Qwen", "https://huggingface.co/Qwen/Qwen-7B-Chat", - "Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.", + "A multi-language large-scale language model (LLM), developed by Damo Academy.", ) + register_model_info( ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"], "Llama2-Chinese", "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat", - "Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.", + "A multi-language large-scale language model (LLM), developed by FlagAlpha.", ) + +register_model_info( + ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"], + "Chinese-Alpaca", + "https://huggingface.co/hfl/chinese-alpaca-2-13b", + "New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.", +) + register_model_info( ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"], "Chinese-Alpaca", @@ -346,13 +581,108 @@ def get_model_info(name: str) -> ModelInfo: ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"], "Vigogne-Instruct", "https://huggingface.co/bofenghuang/vigogne-2-7b-instruct", - "Vigogne-Instruct is a French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang", + "A French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang", ) + register_model_info( ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"], "Vigogne-Chat", "https://huggingface.co/bofenghuang/vigogne-2-7b-chat", - "Vigogne-Chat is a French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang", + "A French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang", +) + +register_model_info( + ["stable-vicuna-13B-HF"], + "stable-vicuna", + "https://huggingface.co/TheBloke/stable-vicuna-13B-HF", + "A Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.", +) + +register_model_info( + ["deluxe-chat-v1", "deluxe-chat-v1.1", "deluxe-chat-v1.2"], + "DeluxeChat", + "", + "Deluxe Chat", +) + +register_model_info( + [ + "Xwin-LM-7B-V0.1", + "Xwin-LM-13B-V0.1", + "Xwin-LM-70B-V0.1", + "Xwin-LM-7B-V0.2", + "Xwin-LM-13B-V0.2", + ], + "Xwin-LM", + "https://github.com/Xwin-LM/Xwin-LM", + "Chat models developed by Xwin-LM team", +) + +register_model_info( + ["lemur-70b-chat"], + "Lemur-Chat", + "https://huggingface.co/OpenLemur/lemur-70b-chat-v1", + "An openly accessible language model optimized for both natural language and coding capabilities ", +) + +register_model_info( + ["Mistral-7B-OpenOrca"], + "Open-Orca", + "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca", + "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)", +) + +register_model_info( + ["dolphin-2.2.1-mistral-7b"], + "dolphin-mistral", + "https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b", + "An uncensored fine-tuned Mistral 7B", +) + +register_model_info( + [ + "AquilaChat-7B", + "AquilaChat2-7B", + "AquilaChat2-34B", + ], + "Aquila-Chat", + "https://huggingface.co/BAAI/AquilaChat2-34B", + "Chat models developed by BAAI team", +) + +register_model_info( + ["xDAN-L1-Chat-RL-v1"], + "xDAN-L1-Chat", + "https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1", + "A large language chat model created by xDAN-AI.", +) + +register_model_info( + ["MetaMath-70B-V1.0", "MetaMath-7B-V1.0"], + "MetaMath", + "https://huggingface.co/meta-math", + "A finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.", +) + +register_model_info( + ["Yuan2-2B-hf", "Yuan2-51B-hf", "Yuan2-102B-hf"], + "IEIYuan", + "https://huggingface.co/IEITYuan", + "A Basemodel developed by IEI.", +) + +register_model_info( + [ + "llava-v1.6-34b", + "llava-v1.6-vicuna-13b", + "llava-v1.6-vicuna-7b", + "llava-v1.6-mistral-7b", + "llava-v1.5-13b", + "llava-v1.5-7b", + ], + "LLaVA", + "https://github.com/haotian-liu/LLaVA", + "an open large language and vision assistant", ) register_model_info( ["stable-vicuna-13B-HF"], diff --git a/fastchat/model/model_yuan2.py b/fastchat/model/model_yuan2.py new file mode 100644 index 000000000..25b3e13f8 --- /dev/null +++ b/fastchat/model/model_yuan2.py @@ -0,0 +1,139 @@ +import gc +from threading import Thread +from typing import Iterable + +import torch +import transformers +from transformers import TextIteratorStreamer, GenerationConfig + +from fastchat.utils import is_partial_stop + + +@torch.inference_mode() +def generate_stream_yuan2( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 0)) + top_k = int(params.get("top_k", 1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 512)) + stop_str = params.get("stop", "") + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer("")["input_ids"][0]) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] # truncate from the left + attention_mask = attention_mask[-max_src_len:] # truncate from the left + input_echo_len = len(input_ids) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1.2, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + ) + + generation_kwargs = dict( + inputs=input_ids, + attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py index b2a4d25d4..99e93a40a 100644 --- a/fastchat/protocol/openai_api_protocol.py +++ b/fastchat/protocol/openai_api_protocol.py @@ -57,7 +57,11 @@ class LogProbs(BaseModel): class ChatCompletionRequest(BaseModel): model: str - messages: Union[str, List[Dict[str, str]]] + messages: Union[ + str, + List[Dict[str, str]], + List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]], + ] temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 top_k: Optional[int] = -1 diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index 3dbb8a690..1e319f0a2 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -1,16 +1,93 @@ """Call API providers.""" +import json import os import random import time +import requests + from fastchat.utils import build_logger -from fastchat.constants import WORKER_API_TIMEOUT logger = build_logger("gradio_web_server", "gradio_web_server.log") +def get_api_provider_stream_iter( + conv, + model_name, + model_api_dict, + temperature, + top_p, + max_new_tokens, +): + if model_api_dict["api_type"] == "openai": + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "anthropic": + prompt = conv.get_prompt() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "gemini": + stream_iter = gemini_api_stream_iter( + model_api_dict["model_name"], + conv, + temperature, + top_p, + max_new_tokens, + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "bard": + prompt = conv.to_openai_api_messages() + stream_iter = bard_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "mistral": + prompt = conv.to_openai_api_messages() + stream_iter = mistral_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "nvidia": + prompt = conv.to_openai_api_messages() + stream_iter = nvidia_api_stream_iter( + model_name, + prompt, + temperature, + top_p, + max_new_tokens, + model_api_dict["api_base"], + ) + elif model_api_dict["api_type"] == "ai2": + prompt = conv.to_openai_api_messages() + stream_iter = ai2_api_stream_iter( + model_name, + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + else: + raise NotImplementedError() + + return stream_iter + + def openai_api_stream_iter( model_name, messages, @@ -22,8 +99,19 @@ def openai_api_stream_iter( ): import openai - openai.api_base = api_base or "https://api.openai.com/v1" - openai.api_key = api_key or os.environ["OPENAI_API_KEY"] + api_key = api_key or os.environ["OPENAI_API_KEY"] + + if "azure" in model_name: + client = openai.AzureOpenAI( + api_version="2023-07-01-preview", + azure_endpoint=api_base or "https://api.openai.com/v1", + api_key=api_key, + ) + else: + client = openai.OpenAI( + base_url=api_base or "https://api.openai.com/v1", api_key=api_key + ) + if model_name == "gpt-4-turbo": model_name = "gpt-4-1106-preview" @@ -37,7 +125,7 @@ def openai_api_stream_iter( } logger.info(f"==== request ====\n{gen_params}") - res = openai.ChatCompletion.create( + res = client.chat.completions.create( model=model_name, messages=messages, temperature=temperature, @@ -46,12 +134,13 @@ def openai_api_stream_iter( ) text = "" for chunk in res: - text += chunk["choices"][0]["delta"].get("content", "") - data = { - "text": text, - "error_code": 0, - } - yield data + if len(chunk.choices) > 0: + text += chunk.choices[0].delta.content or "" + data = { + "text": text, + "error_code": 0, + } + yield data def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): @@ -88,43 +177,278 @@ def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_to yield data -def init_palm_chat(model_name): - import vertexai # pip3 install google-cloud-aiplatform - from vertexai.preview.language_models import ChatModel - - project_id = os.environ["GCP_PROJECT_ID"] - location = "us-central1" - vertexai.init(project=project_id, location=location) - - chat_model = ChatModel.from_pretrained(model_name) - chat = chat_model.start_chat(examples=[]) - return chat +def gemini_api_stream_iter( + model_name, conv, temperature, top_p, max_new_tokens, api_key=None +): + import google.generativeai as genai # pip install google-generativeai + if api_key is None: + api_key = os.environ["GEMINI_API_KEY"] + genai.configure(api_key=api_key) -def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens): - parameters = { + generation_config = { "temperature": temperature, - "top_p": top_p, "max_output_tokens": max_new_tokens, + "top_p": top_p, } - gen_params = { - "model": "palm-2", - "prompt": message, + params = { + "model": model_name, + "prompt": conv, } - gen_params.update(parameters) - logger.info(f"==== request ====\n{gen_params}") + params.update(generation_config) + logger.info(f"==== request ====\n{params}") - response = chat.send_message(message, **parameters) - content = response.text + safety_settings = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + ] + model = genai.GenerativeModel( + model_name=model_name, + generation_config=generation_config, + safety_settings=safety_settings, + ) + history = [] + for role, message in conv.messages[:-2]: + history.append({"role": role, "parts": message}) + convo = model.start_chat(history=history) + response = convo.send_message(conv.messages[-2][1], stream=True) + try: + text = "" + for chunk in response: + text += chunk.text + data = { + "text": text, + "error_code": 0, + } + yield data + except Exception as e: + logger.error(f"==== error ====\n{e}") + reason = chunk.candidates + yield { + "text": f"**API REQUEST ERROR** Reason: {reason}.", + "error_code": 1, + } + + +def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None): + del top_p # not supported + del temperature # not supported + + if api_key is None: + api_key = os.environ["BARD_API_KEY"] + + # convert conv to conv_bard + conv_bard = [] + for turn in conv: + if turn["role"] == "user": + conv_bard.append({"author": "0", "content": turn["content"]}) + elif turn["role"] == "assistant": + conv_bard.append({"author": "1", "content": turn["content"]}) + else: + raise ValueError(f"Unsupported role: {turn['role']}") + + params = { + "model": model_name, + "prompt": conv_bard, + } + logger.info(f"==== request ====\n{params}") + + try: + res = requests.post( + f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}", + json={ + "prompt": { + "messages": conv_bard, + }, + }, + timeout=30, + ) + except Exception as e: + logger.error(f"==== error ====\n{e}") + yield { + "text": f"**API REQUEST ERROR** Reason: {e}.", + "error_code": 1, + } + + if res.status_code != 200: + logger.error(f"==== error ==== ({res.status_code}): {res.text}") + yield { + "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.", + "error_code": 1, + } + + response_json = res.json() + if "candidates" not in response_json: + logger.error(f"==== error ==== response blocked: {response_json}") + reason = response_json["filters"][0]["reason"] + yield { + "text": f"**API REQUEST ERROR** Reason: {reason}.", + "error_code": 1, + } + + response = response_json["candidates"][0]["content"] pos = 0 - while pos < len(content): - # This is a fancy way to simulate token generation latency combined - # with a Poisson process. - pos += random.randint(10, 20) - time.sleep(random.expovariate(50)) + while pos < len(response): + # simulate token streaming + pos += random.randint(3, 6) + time.sleep(0.002) data = { - "text": content[:pos], + "text": response[:pos], "error_code": 0, } yield data + + +def ai2_api_stream_iter( + model_name, + model_id, + messages, + temperature, + top_p, + max_new_tokens, + api_key=None, + api_base=None, +): + # get keys and needed values + ai2_key = api_key or os.environ.get("AI2_API_KEY") + api_base = api_base or "https://inferd.allen.ai/api/v1/infer" + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling: + # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157 + if temperature == 0.0 and top_p < 1.0: + raise ValueError("top_p must be 1 when temperature is 0.0") + + res = requests.post( + api_base, + stream=True, + headers={"Authorization": f"Bearer {ai2_key}"}, + json={ + "model_id": model_id, + # This input format is specific to the Tulu2 model. Other models + # may require different input formats. See the model's schema + # documentation on InferD for more information. + "input": { + "messages": messages, + "opts": { + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + "logprobs": 1, # increase for more choices + }, + }, + }, + timeout=5, + ) + + if res.status_code != 200: + logger.error(f"unexpected response ({res.status_code}): {res.text}") + raise ValueError("unexpected response from InferD", res) + + text = "" + for line in res.iter_lines(): + if line: + part = json.loads(line) + if "result" in part and "output" in part["result"]: + for t in part["result"]["output"]["text"]: + text += t + else: + logger.error(f"unexpected part: {part}") + raise ValueError("empty result in InferD response") + + data = { + "text": text, + "error_code": 0, + } + yield data + + +def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + from mistralai.client import MistralClient + from mistralai.models.chat_completion import ChatMessage + + api_key = os.environ["MISTRAL_API_KEY"] + + client = MistralClient(api_key=api_key) + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + new_messages = [ + ChatMessage(role=message["role"], content=message["content"]) + for message in messages + ] + + res = client.chat_stream( + model=model_name, + temperature=temperature, + messages=new_messages, + max_tokens=max_new_tokens, + top_p=top_p, + ) + + text = "" + for chunk in res: + if chunk.choices[0].delta.content is not None: + text += chunk.choices[0].delta.content + data = { + "text": text, + "error_code": 0, + } + yield data + + +def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base): + assert model_name in ["llama2-70b-steerlm-chat", "yi-34b-chat"] + + api_key = os.environ["NVIDIA_API_KEY"] + headers = { + "Authorization": f"Bearer {api_key}", + "accept": "text/event-stream", + "content-type": "application/json", + } + # nvidia api does not accept 0 temperature + if temp == 0.0: + temp = 0.0001 + + payload = { + "messages": messages, + "temperature": temp, + "top_p": top_p, + "max_tokens": max_tokens, + "seed": 42, + "stream": True, + } + logger.info(f"==== request ====\n{payload}") + + response = requests.post( + api_base, headers=headers, json=payload, stream=True, timeout=1 + ) + text = "" + for line in response.iter_lines(): + if line: + data = line.decode("utf-8") + if data.endswith("[DONE]"): + break + data = json.loads(data[6:])["choices"][0]["delta"]["content"] + text += data + yield {"text": text, "error_code": 0} diff --git a/fastchat/serve/base_model_worker.py b/fastchat/serve/base_model_worker.py index 514cc8221..2fe322990 100644 --- a/fastchat/serve/base_model_worker.py +++ b/fastchat/serve/base_model_worker.py @@ -34,6 +34,7 @@ def __init__( model_names: List[str], limit_worker_concurrency: int, conv_template: str = None, + multimodal: bool = False, ): global logger, worker @@ -46,6 +47,7 @@ def __init__( self.limit_worker_concurrency = limit_worker_concurrency self.conv = self.make_conv_template(conv_template, model_path) self.conv.sep_style = int(self.conv.sep_style) + self.multimodal = multimodal self.tokenizer = None self.context_len = None self.call_ct = 0 @@ -92,6 +94,7 @@ def register_to_controller(self): "worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status(), + "multimodal": self.multimodal, } r = requests.post(url, json=data) assert r.status_code == 200 @@ -126,18 +129,18 @@ def send_heart_beat(self): self.register_to_controller() def get_queue_length(self): - if ( - self.semaphore is None - or self.semaphore._value is None - or self.semaphore._waiters is None - ): + if self.semaphore is None: return 0 else: - return ( - self.limit_worker_concurrency - - self.semaphore._value - + len(self.semaphore._waiters) + sempahore_value = ( + self.semaphore._value + if self.semaphore._value is not None + else self.limit_worker_concurrency ) + waiter_count = ( + 0 if self.semaphore._waiters is None else len(self.semaphore._waiters) + ) + return self.limit_worker_concurrency - sempahore_value + waiter_count def get_status(self): return { diff --git a/fastchat/serve/call_monitor.py b/fastchat/serve/call_monitor.py new file mode 100644 index 000000000..eb8bf2aea --- /dev/null +++ b/fastchat/serve/call_monitor.py @@ -0,0 +1,219 @@ +import json +import os +import glob +import time + +from fastapi import FastAPI +import hashlib +import asyncio + +REFRESH_INTERVAL_SEC = 60 +LOG_DIR = "/home/vicuna/fastchat_logs/server0" +# LOG_DIR = "/home/vicuna/tmp/test_env" + + +class Monitor: + """Monitor the number of calls to each model.""" + + def __init__(self, log_dir: str): + self.log_dir = log_dir + self.model_call = {} + self.user_call = {} + self.model_call_limit_global = { + "gpt-4-1106-preview": 300, + "gpt-4-0125-preview": 300, + } + self.model_call_day_limit_per_user = {"gpt-4-1106-preview": 10} + + async def update_stats(self, num_file=1) -> None: + while True: + # find the latest num_file log under log_dir + json_files = glob.glob(os.path.join(self.log_dir, "*.json")) + json_files.sort(key=os.path.getctime, reverse=True) + json_files = json_files[:num_file] + + model_call = {} + user_call = {} + for json_file in json_files: + for line in open(json_file, "r", encoding="utf-8"): + obj = json.loads(line) + if obj["type"] != "chat": + continue + if obj["model"] not in model_call: + model_call[obj["model"]] = [] + model_call[obj["model"]].append( + {"tstamp": obj["tstamp"], "user_id": obj["ip"]} + ) + if obj["ip"] not in user_call: + user_call[obj["ip"]] = [] + user_call[obj["ip"]].append( + {"tstamp": obj["tstamp"], "model": obj["model"]} + ) + + self.model_call = model_call + self.model_call_stats_hour = self.get_model_call_stats(top_k=None) + self.model_call_stats_day = self.get_model_call_stats( + top_k=None, most_recent_min=24 * 60 + ) + + self.user_call = user_call + self.user_call_stats_hour = self.get_user_call_stats(top_k=None) + self.user_call_stats_day = self.get_user_call_stats( + top_k=None, most_recent_min=24 * 60 + ) + await asyncio.sleep(REFRESH_INTERVAL_SEC) + + def get_model_call_limit(self, model: str) -> int: + if model not in self.model_call_limit_global: + return -1 + return self.model_call_limit_global[model] + + def update_model_call_limit(self, model: str, limit: int) -> bool: + if model not in self.model_call_limit_global: + return False + self.model_call_limit_global[model] = limit + return True + + def is_model_limit_reached(self, model: str) -> bool: + if model not in self.model_call_limit_global: + return False + if model not in self.model_call_stats_hour: + return False + # check if the model call limit is reached + if self.model_call_stats_hour[model] >= self.model_call_limit_global[model]: + return True + return False + + def is_user_limit_reached(self, model: str, user_id: str) -> bool: + if model not in self.model_call_day_limit_per_user: + return False + if user_id not in self.user_call_stats_day: + return False + if model not in self.user_call_stats_day[user_id]["call_dict"]: + return False + # check if the user call limit is reached + if ( + self.user_call_stats_day[user_id]["call_dict"][model] + >= self.model_call_day_limit_per_user[model] + ): + return True + return False + + def get_model_call_stats( + self, target_model=None, most_recent_min: int = 60, top_k: int = 20 + ) -> dict: + model_call_stats = {} + for model, reqs in self.model_call.items(): + if target_model is not None and model != target_model: + continue + model_call = [] + for req in reqs: + if req["tstamp"] < time.time() - most_recent_min * 60: + continue + model_call.append(req["tstamp"]) + model_call_stats[model] = len(model_call) + if top_k is not None: + top_k_model = sorted( + model_call_stats, key=lambda x: model_call_stats[x], reverse=True + )[:top_k] + model_call_stats = {model: model_call_stats[model] for model in top_k_model} + return model_call_stats + + def get_user_call_stats( + self, target_model=None, most_recent_min: int = 60, top_k: int = 20 + ) -> dict: + user_call_stats = {} + for user_id, reqs in self.user_call.items(): + user_model_call = {"call_dict": {}} + for req in reqs: + if req["tstamp"] < time.time() - most_recent_min * 60: + continue + if target_model is not None and req["model"] != target_model: + continue + if req["model"] not in user_model_call["call_dict"]: + user_model_call["call_dict"][req["model"]] = 0 + user_model_call["call_dict"][req["model"]] += 1 + + user_model_call["total_calls"] = sum(user_model_call["call_dict"].values()) + if user_model_call["total_calls"] > 0: + user_call_stats[user_id] = user_model_call + if top_k is not None: + top_k_user = sorted( + user_call_stats, + key=lambda x: user_call_stats[x]["total_calls"], + reverse=True, + )[:top_k] + user_call_stats = { + user_id: user_call_stats[user_id] for user_id in top_k_user + } + return user_call_stats + + def get_num_users(self, most_recent_min: int = 60) -> int: + user_call_stats = self.get_user_call_stats( + most_recent_min=most_recent_min, top_k=None + ) + return len(user_call_stats) + + +monitor = Monitor(log_dir=LOG_DIR) +app = FastAPI() + + +@app.on_event("startup") +async def app_startup(): + asyncio.create_task(monitor.update_stats(2)) + + +@app.get("/get_model_call_limit/{model}") +async def get_model_call_limit(model: str): + return {"model_call_limit": {model: monitor.get_model_call_limit(model)}} + + +@app.get("/update_model_call_limit/{model}/{limit}") +async def update_model_call_limit(model: str, limit: int): + if not monitor.update_model_call_limit(model, limit): + return {"success": False} + return {"success": True} + + +@app.get("/is_limit_reached") +async def is_limit_reached(model: str, user_id: str): + if monitor.is_model_limit_reached(model): + return { + "is_limit_reached": True, + "reason": f"MODEL_HOURLY_LIMIT ({model}): {monitor.get_model_call_limit(model)}", + } + if monitor.is_user_limit_reached(model, user_id): + return { + "is_limit_reached": True, + "reason": f"USER_DAILY_LIMIT ({model}): {monitor.model_call_day_limit_per_user[model]}", + } + return {"is_limit_reached": False} + + +@app.get("/get_num_users_hr") +async def get_num_users(): + return {"num_users": len(monitor.user_call_stats_hour)} + + +@app.get("/get_num_users_day") +async def get_num_users_day(): + return {"num_users": len(monitor.user_call_stats_day)} + + +@app.get("/get_user_call_stats") +async def get_user_call_stats( + model: str = None, most_recent_min: int = 60, top_k: int = None +): + return { + "user_call_stats": monitor.get_user_call_stats(model, most_recent_min, top_k) + } + + +@app.get("/get_model_call_stats") +async def get_model_call_stats( + model: str = None, most_recent_min: int = 60, top_k: int = None +): + return { + "model_call_stats": monitor.get_model_call_stats(model, most_recent_min, top_k) + } diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py index a67da62c4..42d928403 100644 --- a/fastchat/serve/controller.py +++ b/fastchat/serve/controller.py @@ -52,6 +52,7 @@ class WorkerInfo: queue_length: int check_heart_beat: bool last_heart_beat: str + multimodal: bool def heart_beat_controller(controller): @@ -72,7 +73,11 @@ def __init__(self, dispatch_method: str): self.heart_beat_thread.start() def register_worker( - self, worker_name: str, check_heart_beat: bool, worker_status: dict + self, + worker_name: str, + check_heart_beat: bool, + worker_status: dict, + multimodal: bool, ): if worker_name not in self.worker_info: logger.info(f"Register a new worker: {worker_name}") @@ -90,6 +95,7 @@ def register_worker( worker_status["queue_length"], check_heart_beat, time.time(), + multimodal, ) logger.info(f"Register done: {worker_name}, {worker_status}") @@ -116,7 +122,9 @@ def refresh_all_workers(self): self.worker_info = {} for w_name, w_info in old_info.items(): - if not self.register_worker(w_name, w_info.check_heart_beat, None): + if not self.register_worker( + w_name, w_info.check_heart_beat, None, w_info.multimodal + ): logger.info(f"Remove stale worker: {w_name}") def list_models(self): @@ -127,6 +135,24 @@ def list_models(self): return list(model_names) + def list_multimodal_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + if w_info.multimodal: + model_names.update(w_info.model_names) + + return list(model_names) + + def list_language_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + if not w_info.multimodal: + model_names.update(w_info.model_names) + + return list(model_names) + def get_worker_address(self, model_name: str): if self.dispatch_method == DispatchMethod.LOTTERY: worker_names = [] @@ -263,7 +289,10 @@ def worker_api_generate_stream(self, params): async def register_worker(request: Request): data = await request.json() controller.register_worker( - data["worker_name"], data["check_heart_beat"], data.get("worker_status", None) + data["worker_name"], + data["check_heart_beat"], + data.get("worker_status", None), + data.get("multimodal", False), ) @@ -278,6 +307,18 @@ async def list_models(): return {"models": models} +@app.post("/list_multimodal_models") +async def list_multimodal_models(): + models = controller.list_multimodal_models() + return {"models": models} + + +@app.post("/list_language_models") +async def list_language_models(): + models = controller.list_language_models() + return {"models": models} + + @app.post("/get_worker_address") async def get_worker_address(request: Request): data = await request.json() diff --git a/fastchat/serve/example_images/distracted.jpg b/fastchat/serve/example_images/distracted.jpg new file mode 100644 index 000000000..382c888a0 Binary files /dev/null and b/fastchat/serve/example_images/distracted.jpg differ diff --git a/fastchat/serve/example_images/fridge.jpg b/fastchat/serve/example_images/fridge.jpg new file mode 100644 index 000000000..8ed943e8b Binary files /dev/null and b/fastchat/serve/example_images/fridge.jpg differ diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py index 48e49deef..c9d8aba6b 100644 --- a/fastchat/serve/gradio_block_arena_anony.py +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -27,8 +27,8 @@ disable_btn, invisible_btn, acknowledgment_md, - ip_expiration_dict, get_ip, + get_model_description_md, ) from fastchat.utils import ( build_logger, @@ -54,8 +54,8 @@ def load_demo_side_by_side_anony(models_, url_params): states = (None,) * num_sides selector_updates = ( - gr.Markdown.update(visible=True), - gr.Markdown.update(visible=True), + gr.Markdown(visible=True), + gr.Markdown(visible=True), ) return states + selector_updates @@ -73,13 +73,13 @@ def vote_last_response(states, vote_type, model_selectors, request: gr.Request): fout.write(json.dumps(data) + "\n") if ":" not in model_selectors[0]: - for i in range(15): + for i in range(5): names = ( "### Model A: " + states[0].model_name, "### Model B: " + states[1].model_name, ) yield names + ("",) + (disable_btn,) * 4 - time.sleep(0.2) + time.sleep(0.1) else: names = ( "### Model A: " + states[0].model_name, @@ -160,29 +160,58 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re SAMPLING_WEIGHTS = { # tier 0 "gpt-4": 4, + "gpt-4-0314": 4, + "gpt-4-0613": 4, "gpt-4-turbo": 4, - "gpt-3.5-turbo": 2, + "gpt-4-1106-preview": 4, + "gpt-4-0125-preview": 4, + "gpt-3.5-turbo-0613": 2, "gpt-3.5-turbo-1106": 2, - "claude-2": 8, + "gpt-3.5-turbo-0125": 4, + "claude-2.1": 4, + "claude-2.0": 2, "claude-1": 2, - "claude-instant-1": 8, + "claude-instant-1": 2, + "gemini-pro": 4, + "gemini-pro-dev-api": 4, + "bard-jan-24-gemini-pro": 4, + "bard-feb-2024": 4, + "mixtral-8x7b-instruct-v0.1": 4, + "mistral-medium": 4, + "qwen1.5-72b-chat": 4, + "qwen1.5-7b-chat": 2, + "qwen1.5-4b-chat": 2, + "nous-hermes-2-mixtral-8x7b-dpo": 2, + "deepseek-llm-67b-chat": 2, + "stripedhyena-nous-7b": 2, + "openchat-3.5-0106": 2, + "mistral-7b-instruct-v0.2": 2, + "solar-10.7b-instruct-v1.0": 2, + "dolphin-2.2.1-mistral-7b": 2, + "starling-lm-7b-alpha": 2, + "tulu-2-dpo-70b": 2, + "yi-34b-chat": 2, "zephyr-7b-beta": 2, - "openchat-3.5": 2, # tier 1 - "deluxe-chat-v1.1": 2, - "palm-2": 1.5, - "llama-2-70b-chat": 1.5, - "llama-2-13b-chat": 1.5, + "deluxe-chat-v1.2": 4, + "llama-2-70b-chat": 4, + "llama-2-13b-chat": 2, + "llama-2-7b-chat": 2, + "mistral-7b-instruct": 2, "codellama-34b-instruct": 1.5, - "vicuna-33b": 8, + "vicuna-33b": 2, "vicuna-13b": 1.5, - "wizardlm-70b": 1.5, "wizardlm-13b": 1.5, "qwen-14b-chat": 1.5, - "mistral-7b-instruct": 1.5, # tier 2 + "pplx-7b-online": 1, + "pplx-70b-online": 1, + "openhermes-2.5-mistral-7b": 1.0, + "llama2-70b-steerlm-chat": 1.0, + "chatglm3-6b": 1.0, + "openchat-3.5": 1.0, + "wizardlm-70b": 1.0, "vicuna-7b": 1.0, - "llama-2-7b-chat": 1.0, "chatglm2-6b": 1.0, # deprecated "zephyr-7b-alpha": 1.5, @@ -201,18 +230,162 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re "llama-13b": 0.1, "chatglm-6b": 0.5, "deluxe-chat-v1": 4, + "palm-2": 1.5, } # target model sampling weights will be boosted. BATTLE_TARGETS = { - "gpt-4": {"claude-2"}, - "gpt-4-turbo": {"gpt-4", "gpt-3.5-turbo"}, - "gpt-3.5-turbo": {"claude-instant-1", "gpt-4", "claude-2"}, - "claude-2": {"gpt-4", "gpt-3.5-turbo", "claude-1"}, - "claude-1": {"claude-2", "gpt-4", "gpt-3.5-turbo"}, - "claude-instant-1": {"gpt-3.5-turbo", "claude-2"}, - "deluxe-chat-v1.1": {"gpt-4"}, - "openchat-3.5": {"gpt-3.5-turbo", "llama-2-70b-chat", "zephyr-7b-beta"}, + "gpt-4": {"gpt-4-0314", "claude-2.1", "gpt-4-1106-preview"}, + "gpt-4-0613": {"gpt-4-0314", "claude-2.1", "gpt-4-1106-preview"}, + "gpt-4-0314": { + "gpt-4-1106-preview", + "gpt-4-0613", + "claude-2.1", + "gpt-3.5-turbo-0613", + }, + "gpt-4-1106-preview": { + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "claude-2.1", + "bard-feb-2024", + }, + "gpt-4-0125-preview": { + "gpt-4-1106-preview", + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "claude-2.1", + "mistral-medium", + "bard-feb-2024", + }, + "gpt-3.5-turbo-0613": {"claude-instant-1", "gpt-4-0613", "claude-2.1"}, + "gpt-3.5-turbo-1106": {"gpt-4-0613", "claude-instant-1", "gpt-3.5-turbo-0613"}, + "gpt-3.5-turbo-0125": { + "gpt-4-0613", + "gpt-4-1106-preview", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "mixtral-8x7b-instruct-v0.1", + }, + "qwen1.5-72b-chat": { + "gpt-3.5-turbo-0125", + "gpt-4-0613", + "gpt-4-1106-preview", + "llama-2-70b-chat", + "mixtral-8x7b-instruct-v0.1", + "mistral-medium", + "yi-34b-chat", + }, + "qwen1.5-7b-chat": { + "gpt-3.5-turbo-0125", + "starling-lm-7b-alpha", + "llama-2-70b-chat", + "openchat-3.5", + "mixtral-8x7b-instruct-v0.1", + }, + "qwen1.5-4b-chat": { + "llama-2-70b-chat", + "llama-2-13b-chat", + "llama-2-7b-chat", + "openchat-3.5", + }, + "openchat-3.5-0106": { + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-0613", + "llama-2-70b-chat", + "openchat-3.5", + "mixtral-8x7b-instruct-v0.1", + }, + "nous-hermes-2-mixtral-8x7b-dpo": { + "gpt-4-1106-preview", + "claude-2.1", + "mistral-medium", + "gpt-3.5-turbo-0613", + "mixtral-8x7b-instruct-v0.1", + }, + "mistral-7b-instruct-v0.2": { + "llama-2-70b-chat", + "mixtral-8x7b-instruct-v0.1", + "starling-lm-7b-alpha", + "openhermes-2.5-mistral-7b", + }, + "solar-10.7b-instruct-v1.0": { + "mixtral-8x7b-instruct-v0.1", + "gpt-3.5-turbo-0613", + "llama-2-70b-chat", + }, + "mistral-medium": { + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-0613", + "gpt-4-1106-preview", + "mixtral-8x7b-instruct-v0.1", + "bard-feb-2024", + }, + "mixtral-8x7b-instruct-v0.1": { + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-0613", + "gpt-4-1106-preview", + "llama-2-70b-chat", + }, + "claude-2.1": {"gpt-4-1106-preview", "gpt-4-0613", "claude-1"}, + "claude-2.0": {"gpt-4-1106-preview", "gpt-4-0613", "claude-1"}, + "claude-1": {"claude-2.1", "gpt-4-0613", "gpt-3.5-turbo-0613"}, + "claude-instant-1": {"gpt-3.5-turbo-0125", "claude-2.1"}, + "gemini-pro": {"gpt-4-1106-preview", "gpt-4-0613", "gpt-3.5-turbo-0613"}, + "gemini-pro-dev-api": { + "gpt-4-1106-preview", + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "bard-feb-2024", + }, + "bard-jan-24-gemini-pro": { + "gpt-4-1106-preview", + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "gemini-pro-dev-api", + }, + "bard-feb-2024": { + "gpt-4-1106-preview", + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "bard-jan-24-gemini-pro", + }, + "deepseek-llm-67b-chat": { + "gpt-4-1106-preview", + "gpt-4-turbo", + "gpt-3.5-turbo-0613", + }, + "llama2-70b-steerlm-chat": { + "llama-2-70b-chat", + "tulu-2-dpo-70b", + "yi-34b-chat", + }, + "stripedhyena-nous-7b": { + "starling-lm-7b-alpha", + "openhermes-2.5-mistral-7b", + "mistral-7b-instruct", + "llama-2-7b-chat", + }, + "deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-1106-preview"}, + "deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-1106-preview"}, + "pplx-7b-online": {"gpt-3.5-turbo-0125", "llama-2-70b-chat"}, + "pplx-70b-online": {"gpt-3.5-turbo-0125", "llama-2-70b-chat"}, + "openhermes-2.5-mistral-7b": { + "gpt-3.5-turbo-0613", + "openchat-3.5", + "zephyr-7b-beta", + }, + "dolphin-2.2.1-mistral-7b": { + "gpt-3.5-turbo-0613", + "vicuna-33b", + "starling-lm-7b-alpha", + "openhermes-2.5-mistral-7b", + }, + "starling-lm-7b-alpha": {"gpt-3.5-turbo-0613", "openchat-3.5", "zephyr-7b-beta"}, + "tulu-2-dpo-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, + "yi-34b-chat": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, + "openchat-3.5": {"gpt-3.5-turbo-0613", "llama-2-70b-chat", "zephyr-7b-beta"}, + "chatglm3-6b": {"yi-34b-chat", "qwen-14b-chat"}, "qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"}, "zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"}, "zephyr-7b-beta": { @@ -221,7 +394,7 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re "llama-2-7b-chat", "wizardlm-13b", }, - "llama-2-70b-chat": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"}, + "llama-2-70b-chat": {"gpt-3.5-turbo-0125", "claude-instant-1"}, "llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"}, "llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"}, "mistral-7b-instruct": { @@ -229,14 +402,27 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re "llama-2-13b-chat", "llama-2-70b-chat", }, - "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo", "claude-instant-1"}, + "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo-0613", "claude-instant-1"}, "vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"}, "vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"}, - "wizardlm-70b": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"}, - "palm-2": {"llama-2-13b-chat", "gpt-3.5-turbo"}, + "wizardlm-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, } -SAMPLING_BOOST_MODELS = ["openchat-3.5", "gpt-4-turbo", "gpt-3.5-turbo-1106"] +SAMPLING_BOOST_MODELS = [ + # "claude-2.1", + # "gpt-4-0613", + # "gpt-4-0314", + # "gpt-4-1106-preview", + # "gpt-4-0125-preview", + "gpt-3.5-turbo-0125", + # "mistral-medium", + "nous-hermes-2-mixtral-8x7b-dpo", + "openchat-3.5-0106", + "qwen1.5-72b-chat", + "qwen1.5-7b-chat", + "qwen1.5-4b-chat", + # "mistral-7b-instruct-v0.2", +] # outage models won't be sampled. OUTAGE_MODELS = [] @@ -263,6 +449,8 @@ def get_battle_pair(): model_weights = model_weights / total_weight chosen_idx = np.random.choice(len(models), p=model_weights) chosen_model = models[chosen_idx] + # for p, w in zip(models, model_weights): + # print(p, w) rival_models = [] rival_weights = [] @@ -353,10 +541,10 @@ def add_text( states[i].conv.append_message(states[i].conv.roles[1], None) states[i].skip_next = False - slow_model_msg = "" + hint_msg = "" for i in range(num_sides): if "deluxe" in states[i].model_name: - slow_model_msg = SLOW_MODEL_MSG + hint_msg = SLOW_MODEL_MSG return ( states + [x.to_gradio_chatbot() for x in states] @@ -365,7 +553,7 @@ def add_text( disable_btn, ] * 6 - + [slow_model_msg] + + [hint_msg] ) @@ -399,16 +587,25 @@ def bot_response_multi( top_p, max_new_tokens, request, + apply_rate_limit=False, ) ) + is_gemini = [] + for i in range(num_sides): + is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"]) chatbots = [None] * num_sides + iters = 0 while True: stop = True + iters += 1 for i in range(num_sides): try: - ret = next(gen[i]) - states[i], chatbots[i] = ret[0], ret[1] + # yield gemini fewer times as its chunk size is larger + # otherwise, gemini will stream too fast + if not is_gemini[i] or (iters % 30 == 1 or iters < 3): + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] stop = False except StopIteration: pass @@ -419,7 +616,7 @@ def bot_response_multi( def build_side_by_side_ui_anony(models): notice_markdown = """ -# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild +# ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | ## 📜 Rules @@ -427,12 +624,11 @@ def build_side_by_side_ui_anony(models): - You can continue chatting until you identify a winner. - Vote won't be counted if model identity is revealed during conversation. -## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) -We use **100K** human votes to compile an Elo-based LLM leaderboard. +## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) +We collect **200K+** human votes to compute an Elo-based LLM leaderboard. Find out who is the 🥇LLM Champion! ## 👇 Chat now! - """ states = [gr.State() for _ in range(num_sides)] @@ -441,44 +637,51 @@ def build_side_by_side_ui_anony(models): gr.Markdown(notice_markdown, elem_id="notice_markdown") - with gr.Box(elem_id="share-region-anony"): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") with gr.Row(): for i in range(num_sides): label = "Model A" if i == 0 else "Model B" with gr.Column(): chatbots[i] = gr.Chatbot( - label=label, elem_id=f"chatbot", height=550 + label=label, + elem_id="chatbot", + height=550, + show_copy_button=True, ) with gr.Row(): for i in range(num_sides): with gr.Column(): - model_selectors[i] = gr.Markdown(anony_names[i]) + model_selectors[i] = gr.Markdown( + anony_names[i], elem_id="model_selector_md" + ) with gr.Row(): slow_warning = gr.Markdown("", elem_id="notice_markdown") - with gr.Row(): - leftvote_btn = gr.Button( - value="👈 A is better", visible=False, interactive=False - ) - rightvote_btn = gr.Button( - value="👉 B is better", visible=False, interactive=False - ) - tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) - bothbad_btn = gr.Button( - value="👎 Both are bad", visible=False, interactive=False - ) + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) with gr.Row(): - with gr.Column(scale=20): - textbox = gr.Textbox( - show_label=False, - placeholder="👉 Enter your prompt and press ENTER", - container=False, - elem_id="input_box", - ) - with gr.Column(scale=1, min_width=50): - send_btn = gr.Button(value="Send", variant="primary") + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) with gr.Row() as button_row: clear_btn = gr.Button(value="🎲 New Round", interactive=False) @@ -504,14 +707,14 @@ def build_side_by_side_ui_anony(models): ) max_output_tokens = gr.Slider( minimum=16, - maximum=1024, - value=512, + maximum=2048, + value=1024, step=64, interactive=True, label="Max output tokens", ) - gr.Markdown(acknowledgment_md) + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") # Register listeners btn_list = [ @@ -577,7 +780,7 @@ def build_side_by_side_ui_anony(models): return [a, b, c, d]; } """ - share_btn.click(share_click, states + model_selectors, [], _js=share_js) + share_btn.click(share_click, states + model_selectors, [], js=share_js) textbox.submit( add_text, diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py index c13283495..9774c3dea 100644 --- a/fastchat/serve/gradio_block_arena_named.py +++ b/fastchat/serve/gradio_block_arena_named.py @@ -25,9 +25,8 @@ disable_btn, invisible_btn, acknowledgment_md, - get_model_description_md, - ip_expiration_dict, get_ip, + get_model_description_md, ) from fastchat.utils import ( build_logger, @@ -58,8 +57,8 @@ def load_demo_side_by_side_named(models, url_params): model_right = model_left selector_updates = ( - gr.Dropdown.update(choices=models, value=model_left, visible=True), - gr.Dropdown.update(choices=models, value=model_right, visible=True), + gr.Dropdown(choices=models, value=model_left, visible=True), + gr.Dropdown(choices=models, value=model_right, visible=True), ) return states + selector_updates @@ -242,13 +241,22 @@ def bot_response_multi( ) ) + is_gemini = [] + for i in range(num_sides): + is_gemini.append(states[i].model_name in ["gemini-pro", "gemini-pro-dev-api"]) + chatbots = [None] * num_sides + iters = 0 while True: stop = True + iters += 1 for i in range(num_sides): try: - ret = next(gen[i]) - states[i], chatbots[i] = ret[0], ret[1] + # yield gemini fewer times as its chunk size is larger + # otherwise, gemini will stream too fast + if not is_gemini[i] or (iters % 30 == 1 or iters < 3): + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] stop = False except StopIteration: pass @@ -264,12 +272,12 @@ def flash_buttons(): ] for i in range(4): yield btn_updates[i % 2] - time.sleep(0.5) + time.sleep(0.3) def build_side_by_side_ui_named(models): notice_markdown = """ -# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild +# ⚔️ Chatbot Arena: Benchmarking LLMs in the Wild | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | ## 📜 Rules @@ -284,12 +292,9 @@ def build_side_by_side_ui_named(models): model_selectors = [None] * num_sides chatbots = [None] * num_sides - model_description_md = get_model_description_md(models) - notice = gr.Markdown( - notice_markdown + model_description_md, elem_id="notice_markdown" - ) + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") - with gr.Box(elem_id="share-region-named"): + with gr.Group(elem_id="share-region-named"): with gr.Row(): for i in range(num_sides): with gr.Column(): @@ -300,41 +305,47 @@ def build_side_by_side_ui_named(models): show_label=False, container=False, ) + with gr.Row(): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") with gr.Row(): for i in range(num_sides): label = "Model A" if i == 0 else "Model B" with gr.Column(): chatbots[i] = gr.Chatbot( - label=label, elem_id=f"chatbot", height=550 + label=label, + elem_id=f"chatbot", + height=550, + show_copy_button=True, ) - with gr.Row(): - leftvote_btn = gr.Button( - value="👈 A is better", visible=False, interactive=False - ) - rightvote_btn = gr.Button( - value="👉 B is better", visible=False, interactive=False - ) - tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) - bothbad_btn = gr.Button( - value="👎 Both are bad", visible=False, interactive=False - ) + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) with gr.Row(): - with gr.Column(scale=20): - textbox = gr.Textbox( - show_label=False, - placeholder="Enter your prompt here and press ENTER", - container=False, - elem_id="input_box", - ) - with gr.Column(scale=1, min_width=50): - send_btn = gr.Button(value="Send", variant="primary") + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) with gr.Row() as button_row: - regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) share_btn = gr.Button(value="📷 Share") with gr.Accordion("Parameters", open=False) as parameter_row: @@ -356,14 +367,14 @@ def build_side_by_side_ui_named(models): ) max_output_tokens = gr.Slider( minimum=16, - maximum=1024, - value=512, + maximum=2048, + value=1024, step=64, interactive=True, label="Max output tokens", ) - gr.Markdown(acknowledgment_md) + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") # Register listeners btn_list = [ @@ -425,7 +436,7 @@ def build_side_by_side_ui_named(models): return [a, b, c, d]; } """ - share_btn.click(share_click, states + model_selectors, [], _js=share_js) + share_btn.click(share_click, states + model_selectors, [], js=share_js) for i in range(num_sides): model_selectors[i].change( diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py new file mode 100644 index 000000000..5ddf138e8 --- /dev/null +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -0,0 +1,222 @@ +""" +The gradio demo server for chatting with a large multimodal model. + +Usage: +python3 -m fastchat.serve.controller +python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf +python3 -m fastchat.serve.gradio_web_server_multi --share --multimodal +""" + +import json +import os + +import gradio as gr +import numpy as np + +from fastchat.serve.gradio_web_server import ( + upvote_last_response, + downvote_last_response, + flag_last_response, + get_model_description_md, + acknowledgment_md, + bot_response, + add_text, + clear_history, + regenerate, + get_ip, + disable_btn, +) +from fastchat.utils import ( + build_logger, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + + +def get_vqa_sample(): + random_sample = np.random.choice(vqa_samples) + question, path = random_sample["question"], random_sample["path"] + return question, path + + +def clear_history_example(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history_example. ip: {ip}") + state = None + return (state, []) + (disable_btn,) * 5 + + +def build_single_vision_language_model_ui( + models, add_promotion_links=False, random_questions=None +): + promotion = ( + """ +| [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | +""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Open Large Vision-Language Models +{promotion} +""" + + state = gr.State() + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + with gr.Row(): + with gr.Column(scale=3): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + container=False, + render=False, + elem_id="input_box", + ) + imagebox = gr.Image(type="pil", sources=["upload", "clipboard"]) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + examples = gr.Examples( + examples=[ + [ + f"{cur_dir}/example_images/fridge.jpg", + "How can I prepare a delicious meal using these ingredients?", + ], + [ + f"{cur_dir}/example_images/distracted.jpg", + "What might the woman on the right be thinking about?", + ], + ], + inputs=[imagebox, textbox], + ) + + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + + with gr.Column(scale=8): + chatbot = gr.Chatbot( + elem_id="chatbot", label="Scroll down and start chatting", height=550 + ) + + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row(elem_id="buttons"): + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + if add_promotion_links: + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click( + regenerate, state, [state, chatbot, textbox, imagebox] + btn_list + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) + + model_selector.change( + clear_history, None, [state, chatbot, textbox, imagebox] + btn_list + ) + imagebox.upload(clear_history_example, None, [state, chatbot] + btn_list) + examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list) + + textbox.submit( + add_text, + [state, model_selector, textbox, imagebox], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + send_btn.click( + add_text, + [state, model_selector, textbox, imagebox], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ) + + return [state, model_selector] diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index ba7a4aa4c..843e2a5b1 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -5,6 +5,7 @@ import argparse from collections import defaultdict import datetime +import hashlib import json import os import random @@ -14,32 +15,30 @@ import gradio as gr import requests -from fastchat.conversation import SeparatorStyle from fastchat.constants import ( LOGDIR, WORKER_API_TIMEOUT, ErrorCode, MODERATION_MSG, CONVERSATION_LIMIT_MSG, + RATE_LIMIT_MSG, SERVER_ERROR_MSG, INPUT_CHAR_LEN_LIMIT, CONVERSATION_TURN_LIMIT, SESSION_EXPIRATION_TIME, ) -from fastchat.model.model_adapter import get_conversation_template -from fastchat.model.model_registry import get_model_info, model_info -from fastchat.serve.api_provider import ( - anthropic_api_stream_iter, - openai_api_stream_iter, - palm_api_stream_iter, - init_palm_chat, +from fastchat.model.model_adapter import ( + get_conversation_template, ) +from fastchat.model.model_registry import get_model_info, model_info +from fastchat.serve.api_provider import get_api_provider_stream_iter from fastchat.utils import ( build_logger, - moderation_filter, get_window_url_params_js, get_window_url_params_with_tos_js, + moderation_filter, parse_gradio_auth_creds, + load_image, ) @@ -47,37 +46,53 @@ headers = {"User-Agent": "FastChat Client"} -no_change_btn = gr.Button.update() -enable_btn = gr.Button.update(interactive=True, visible=True) -disable_btn = gr.Button.update(interactive=False) -invisible_btn = gr.Button.update(interactive=False, visible=False) +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True, visible=True) +disable_btn = gr.Button(interactive=False) +invisible_btn = gr.Button(interactive=False, visible=False) controller_url = None enable_moderation = False acknowledgment_md = """ +### Terms of Service + +Users are required to agree to the following terms before using the service: + +The service is a research preview. It only provides limited safety measures and may generate offensive content. +It must not be used for any illegal, harmful, violent, racist, or sexual purposes. +Please do not upload any private information. +The service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license. +Additionally, Bard is offered on LMSys for research purposes only. To access the Bard product, please visit its [website](http://bard.google.com). + ### Acknowledgment -
-

We thank Kaggle, MBZUAI, AnyScale, and HuggingFace for their sponsorship.

- Image 1 - Image 2 - Image 3 - Image 4 +We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/). + + """ -ip_expiration_dict = defaultdict(lambda: 0) - -# Information about custom OpenAI compatible API models. -# JSON file format: +# JSON file format of API-based models: # { -# "vicuna-7b": { -# "model_name": "vicuna-7b-v1.5", -# "api_base": "http://8.8.8.55:5555/v1", -# "api_key": "password" -# }, +# "gpt-3.5-turbo": { +# "model_name": "gpt-3.5-turbo", +# "api_type": "openai", +# "api_base": "https://api.openai.com/v1", +# "api_key": "sk-******", +# "anony_only": false +# } # } -openai_compatible_models_info = {} +# +# - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly. +# - "anony_only" indicates whether to display this model in anonymous mode only. + +api_endpoint_info = {} class State: @@ -87,11 +102,6 @@ def __init__(self, model_name): self.skip_next = False self.model_name = model_name - if model_name == "palm-2": - # According to release note, "chat-bison@001" is PaLM 2 for chat. - # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023 - self.palm_chat = init_palm_chat("chat-bison@001") - def to_gradio_chatbot(self): return self.conv.to_gradio_chatbot() @@ -118,42 +128,50 @@ def get_conv_log_filename(): return name -def get_model_list( - controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm -): +def get_model_list(controller_url, register_api_endpoint_file, multimodal): + global api_endpoint_info + + # Add models from the controller if controller_url: ret = requests.post(controller_url + "/refresh_all_workers") assert ret.status_code == 200 - ret = requests.post(controller_url + "/list_models") - models = ret.json()["models"] + + if multimodal: + ret = requests.post(controller_url + "/list_multimodal_models") + models = ret.json()["models"] + else: + ret = requests.post(controller_url + "/list_language_models") + models = ret.json()["models"] else: models = [] - # Add API providers - if register_openai_compatible_models: - global openai_compatible_models_info - openai_compatible_models_info = json.load( - open(register_openai_compatible_models) - ) - models += list(openai_compatible_models_info.keys()) - - if add_chatgpt: - models += ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"] - if add_claude: - models += ["claude-2", "claude-instant-1"] - if add_palm: - models += ["palm-2"] + # Add models from the API providers + if register_api_endpoint_file: + api_endpoint_info = json.load(open(register_api_endpoint_file)) + for mdl, mdl_dict in api_endpoint_info.items(): + mdl_multimodal = mdl_dict.get("multimodal", False) + if multimodal and mdl_multimodal: + models += [mdl] + elif not multimodal and not mdl_multimodal: + models += [mdl] + + # Remove anonymous models models = list(set(models)) + visible_models = models.copy() + for mdl in visible_models: + if mdl not in api_endpoint_info: + continue + mdl_dict = api_endpoint_info[mdl] + if mdl_dict["anony_only"]: + visible_models.remove(mdl) - if "deluxe-chat-v1" in models: - del models[models.index("deluxe-chat-v1")] - if "deluxe-chat-v1.1" in models: - del models[models.index("deluxe-chat-v1.1")] - - priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)} + # Sort models and add descriptions + priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)} models.sort(key=lambda x: priority.get(x, x)) - logger.info(f"Models: {models}") - return models + visible_models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"All models: {models}") + logger.info(f"Visible models: {visible_models}") + return visible_models, models def load_demo_single(models, url_params): @@ -163,10 +181,7 @@ def load_demo_single(models, url_params): if model in models: selected_model = model - dropdown_update = gr.Dropdown.update( - choices=models, value=selected_model, visible=True - ) - + dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True) state = None return state, dropdown_update @@ -176,15 +191,10 @@ def load_demo(url_params, request: gr.Request): ip = get_ip(request) logger.info(f"load_demo. ip: {ip}. params: {url_params}") - ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME if args.model_list_mode == "reload": - models = get_model_list( - controller_url, - args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, + models, all_models = get_model_list( + controller_url, args.register_api_endpoint_file, False ) return load_demo_single(models, url_params) @@ -227,14 +237,14 @@ def regenerate(state, request: gr.Request): ip = get_ip(request) logger.info(f"regenerate. ip: {ip}") state.conv.update_last_message(None) - return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def clear_history(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history. ip: {ip}") state = None - return (state, [], "") + (disable_btn,) * 5 + return (state, [], "", None) + (disable_btn,) * 5 def get_ip(request: gr.Request): @@ -245,7 +255,22 @@ def get_ip(request: gr.Request): return ip -def add_text(state, model_selector, text, request: gr.Request): +def _prepare_text_with_image(state, text, image): + if image is not None: + if len(state.conv.get_images()) > 0: + # reset convo with new image + state.conv = get_conversation_template(state.model_name) + + image = state.conv.convert_image_to_base64( + image + ) # PIL type is not JSON serializable + + text = text, [image] + + return text + + +def add_text(state, model_selector, text, image, request: gr.Request): ip = get_ip(request) logger.info(f"add_text. ip: {ip}. len: {len(text)}") @@ -262,8 +287,7 @@ def add_text(state, model_selector, text, request: gr.Request): # overwrite the original text text = MODERATION_MSG - conv = state.conv - if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. text: {text}") state.skip_next = True return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + ( @@ -271,20 +295,10 @@ def add_text(state, model_selector, text, request: gr.Request): ) * 5 text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off - conv.append_message(conv.roles[0], text) - conv.append_message(conv.roles[1], None) - return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 - - -def post_process_code(code): - sep = "\n```" - if sep in code: - blocks = code.split(sep) - if len(blocks) % 2 == 1: - for i in range(1, len(blocks), 2): - blocks[i] = blocks[i].replace("\\_", "_") - code = sep.join(blocks) - return code + text = _prepare_text_with_image(state, text, image) + state.conv.append_message(state.conv.roles[0], text) + state.conv.append_message(state.conv.roles[1], None) + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def model_worker_stream_iter( @@ -296,6 +310,7 @@ def model_worker_stream_iter( repetition_penalty, top_p, max_new_tokens, + images, ): # Make requests gen_params = { @@ -309,8 +324,12 @@ def model_worker_stream_iter( "stop_token_ids": conv.stop_token_ids, "echo": False, } + logger.info(f"==== request ====\n{gen_params}") + if len(images) > 0: + gen_params["images"] = images + # Stream output response = requests.post( worker_addr + "/worker_generate_stream", @@ -325,7 +344,27 @@ def model_worker_stream_iter( yield data -def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request): +def is_limit_reached(model_name, ip): + monitor_url = "http://localhost:9090" + try: + ret = requests.get( + f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1 + ) + obj = ret.json() + return obj + except Exception as e: + logger.info(f"monitor error: {e}") + return None + + +def bot_response( + state, + temperature, + top_p, + max_new_tokens, + request: gr.Request, + apply_rate_limit=True, +): ip = get_ip(request) logger.info(f"bot_response. ip: {ip}") start_tstamp = time.time() @@ -339,34 +378,22 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return + if apply_rate_limit: + ret = is_limit_reached(state.model_name, ip) + if ret is not None and ret["is_limit_reached"]: + error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"] + logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}") + state.conv.update_last_message(error_msg) + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + conv, model_name = state.conv, state.model_name - if model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]: - prompt = conv.to_openai_api_messages() - stream_iter = openai_api_stream_iter( - model_name, prompt, temperature, top_p, max_new_tokens - ) - elif model_name in ["claude-2", "claude-1", "claude-instant-1"]: - prompt = conv.get_prompt() - stream_iter = anthropic_api_stream_iter( - model_name, prompt, temperature, top_p, max_new_tokens - ) - elif model_name == "palm-2": - stream_iter = palm_api_stream_iter( - state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens - ) - elif model_name in openai_compatible_models_info: - model_info = openai_compatible_models_info[model_name] - prompt = conv.to_openai_api_messages() - stream_iter = openai_api_stream_iter( - model_info["model_name"], - prompt, - temperature, - top_p, - max_new_tokens, - api_base=model_info["api_base"], - api_key=model_info["api_key"], - ) - else: + model_api_dict = ( + api_endpoint_info[model_name] if model_name in api_endpoint_info else None + ) + images = conv.get_images() + + if model_api_dict is None: # Query worker address ret = requests.post( controller_url + "/get_worker_address", json={"model": model_name} @@ -407,6 +434,16 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) repetition_penalty, top_p, max_new_tokens, + images, + ) + else: + stream_iter = get_api_provider_stream_iter( + conv, + model_name, + model_api_dict, + temperature, + top_p, + max_new_tokens, ) conv.update_last_message("▌") @@ -430,8 +467,6 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) ) return output = data["text"].strip() - if "vicuna" in model_name: - output = post_process_code(output) conv.update_last_message(output) yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 except requests.exceptions.RequestException as e: @@ -464,6 +499,20 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) finish_tstamp = time.time() logger.info(f"{output}") + # We load the image because gradio accepts base64 but that increases file size by ~1.33x + loaded_images = [load_image(image) for image in images] + images_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in loaded_images] + for image, hash_str in zip(loaded_images, images_hash): + t = datetime.datetime.now() + filename = os.path.join( + LOGDIR, + "serve_images", + f"{hash_str}.jpg", + ) + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + image.save(filename) + with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), @@ -478,13 +527,14 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), + "images": images_hash, } fout.write(json.dumps(data) + "\n") block_css = """ -#notice_markdown { - font-size: 110% +#notice_markdown .prose { + font-size: 120% !important; } #notice_markdown th { display: none; @@ -493,8 +543,11 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) padding-top: 6px; padding-bottom: 6px; } -#leaderboard_markdown { - font-size: 110% +#model_description_markdown { + font-size: 120% !important; +} +#leaderboard_markdown .prose { + font-size: 120% !important; } #leaderboard_markdown td { padding-top: 6px; @@ -503,13 +556,22 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request) #leaderboard_dataframe td { line-height: 0.1em; } -#about_markdown { - font-size: 110% +#about_markdown .prose { + font-size: 120% !important; } -#input_box textarea { +#ack_markdown .prose { + font-size: 120% !important; } footer { - display:none !important + display:none !important; +} +.sponsor-image-about img { + margin: 0 20px; + margin-top: 20px; + height: 40px; + max-height: 100%; + width: auto; + float: left; } .image-container { display: flex; @@ -558,9 +620,9 @@ def get_model_description_md(models): def build_about(): - about_markdown = f""" + about_markdown = """ # About Us -Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our code at [GitHub](https://github.com/lm-sys/FastChat) and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey! +Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey! ## Read More - Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/) @@ -577,23 +639,21 @@ def build_about(): - File issues on [GitHub](https://github.com/lm-sys/FastChat) - Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys) -## Sponsors -We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. -Learn more about partnership [here](https://lmsys.org/donations/). - -
- Image 1 - Image 2 - Image 3 - Image 4 +## Acknowledgment +We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support. +We also thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnership [here](https://lmsys.org/donations/). + + """ - - # state = gr.State() gr.Markdown(about_markdown, elem_id="about_markdown") - # return [state] - def build_single_model_ui(models, add_promotion_links=False): promotion = ( @@ -601,6 +661,8 @@ def build_single_model_ui(models, add_promotion_links=False): - | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | - Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/) - Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/) + +## 🤖 Choose any model to chat """ if add_promotion_links else "" @@ -609,38 +671,42 @@ def build_single_model_ui(models, add_promotion_links=False): notice_markdown = f""" # 🏔️ Chat with Open Large Language Models {promotion} - -## Choose any model to chat """ state = gr.State() - model_description_md = get_model_description_md(models) - gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown") - - with gr.Row(elem_id="model_selector_row"): - model_selector = gr.Dropdown( - choices=models, - value=models[0] if len(models) > 0 else "", - interactive=True, - show_label=False, - container=False, - ) - - chatbot = gr.Chatbot( - elem_id="chatbot", - label="Scroll down and start chatting", - height=550, - ) - with gr.Row(): - with gr.Column(scale=20): - textbox = gr.Textbox( + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-named"): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, show_label=False, - placeholder="👉 Enter your prompt and press ENTER", container=False, elem_id="input_box", ) - with gr.Column(scale=1, min_width=50): - send_btn = gr.Button(value="Send", variant="primary") + with gr.Row(): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + chatbot = gr.Chatbot( + elem_id="chatbot", + label="Scroll down and start chatting", + height=550, + show_copy_button=True, + ) + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) with gr.Row() as button_row: upvote_btn = gr.Button(value="👍 Upvote", interactive=False) @@ -668,17 +734,18 @@ def build_single_model_ui(models, add_promotion_links=False): ) max_output_tokens = gr.Slider( minimum=16, - maximum=1024, - value=512, + maximum=2048, + value=1024, step=64, interactive=True, label="Max output tokens", ) if add_promotion_links: - gr.Markdown(acknowledgment_md) + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") # Register listeners + imagebox = gr.State(None) btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] upvote_btn.click( upvote_last_response, @@ -695,17 +762,23 @@ def build_single_model_ui(models, add_promotion_links=False): [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], ) - regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + regenerate_btn.click( + regenerate, state, [state, chatbot, textbox, imagebox] + btn_list + ).then( bot_response, [state, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, ) - clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) - model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list) + model_selector.change( + clear_history, None, [state, chatbot, textbox, imagebox] + btn_list + ) textbox.submit( - add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list + add_text, + [state, model_selector, textbox, imagebox], + [state, chatbot, textbox, imagebox] + btn_list, ).then( bot_response, [state, temperature, top_p, max_output_tokens], @@ -713,8 +786,8 @@ def build_single_model_ui(models, add_promotion_links=False): ) send_btn.click( add_text, - [state, model_selector, textbox], - [state, chatbot, textbox] + btn_list, + [state, model_selector, textbox, imagebox], + [state, chatbot, textbox, imagebox] + btn_list, ).then( bot_response, [state, temperature, top_p, max_output_tokens], @@ -749,7 +822,7 @@ def build_demo(models): state, model_selector, ], - _js=load_js, + js=load_js, ) return demo @@ -794,41 +867,27 @@ def build_demo(models): help="Shows term of use before loading the demo", ) parser.add_argument( - "--add-chatgpt", - action="store_true", - help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", - ) - parser.add_argument( - "--add-claude", - action="store_true", - help="Add Anthropic's Claude models (claude-2, claude-instant-1)", - ) - parser.add_argument( - "--add-palm", - action="store_true", - help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", - ) - parser.add_argument( - "--register-openai-compatible-models", + "--register-api-endpoint-file", type=str, - help="Register custom OpenAI API compatible models by loading them from a JSON file", + help="Register API-based model endpoints from a JSON file", ) parser.add_argument( "--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) args = parser.parse_args() logger.info(f"args: {args}") # Set global variables set_global_vars(args.controller_url, args.moderate) - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, + models, all_models = get_model_list( + args.controller_url, args.register_api_endpoint_file, False ) # Set authorization credentials @@ -839,11 +898,14 @@ def build_demo(models): # Launch the demo demo = build_demo(models) demo.queue( - concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, ).launch( server_name=args.host, server_port=args.port, share=args.share, max_threads=200, auth=auth, + root_path=args.gradio_root_path, ) diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index b918f9d6b..538d7776b 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -9,9 +9,6 @@ import gradio as gr -from fastchat.constants import ( - SESSION_EXPIRATION_TIME, -) from fastchat.serve.gradio_block_arena_anony import ( build_side_by_side_ui_anony, load_demo_side_by_side_anony, @@ -22,6 +19,9 @@ load_demo_side_by_side_named, set_global_vars_named, ) +from fastchat.serve.gradio_block_arena_vision import ( + build_single_vision_language_model_ui, +) from fastchat.serve.gradio_web_server import ( set_global_vars, block_css, @@ -29,7 +29,6 @@ build_about, get_model_list, load_demo_single, - ip_expiration_dict, get_ip, ) from fastchat.serve.monitor.monitor import build_leaderboard_tab @@ -44,74 +43,78 @@ def load_demo(url_params, request: gr.Request): - global models + global models, all_models, vl_models ip = get_ip(request) logger.info(f"load_demo. ip: {ip}. params: {url_params}") - ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME selected = 0 if "arena" in url_params: selected = 0 elif "compare" in url_params: selected = 1 - elif "single" in url_params: + elif "direct" in url_params or "model" in url_params: selected = 2 - elif "leaderboard" in url_params: + elif "vision" in url_params: selected = 3 + elif "leaderboard" in url_params: + selected = 4 if args.model_list_mode == "reload": - if args.anony_only_for_proprietary_model: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - False, - False, - False, - ) - else: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, - ) + models, all_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + False, + ) - single_updates = load_demo_single(models, url_params) + vl_models, all_vl_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + True, + ) - models_anony = list(models) - if args.anony_only_for_proprietary_model: - # Only enable these models in anony battles. - if args.add_chatgpt: - models_anony += [ - "gpt-4", - "gpt-3.5-turbo", - "gpt-4-turbo", - "gpt-3.5-turbo-1106", - ] - if args.add_claude: - models_anony += ["claude-2", "claude-1", "claude-instant-1"] - if args.add_palm: - models_anony += ["palm-2"] - models_anony = list(set(models_anony)) - - side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params) + single_updates = load_demo_single(models, url_params) + side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params) side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + vision_language_updates = load_demo_single(vl_models, url_params) + return ( - (gr.Tabs.update(selected=selected),) + (gr.Tabs(selected=selected),) + single_updates + side_by_side_anony_updates + side_by_side_named_updates + + vision_language_updates ) -def build_demo(models, elo_results_file, leaderboard_table_file): +def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): text_size = gr.themes.sizes.text_md + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + head_js = """ + +""" + if args.ga_id is not None: + head_js += f""" + + + """ + with gr.Blocks( title="Chat with Open Large Language Models", theme=gr.themes.Default(text_size=text_size), css=block_css, + head=head_js, ) as demo: with gr.Tabs() as tabs: with gr.Tab("Arena (battle)", id=0): @@ -124,30 +127,39 @@ def build_demo(models, elo_results_file, leaderboard_table_file): single_model_list = build_single_model_ui( models, add_promotion_links=True ) + + with gr.Tab("Vision Direct Chat", id=3, visible=args.multimodal): + single_vision_language_model_list = ( + build_single_vision_language_model_ui( + vl_models, + add_promotion_links=True, + random_questions=args.random_questions, + ) + ) + if elo_results_file: - with gr.Tab("Leaderboard", id=3): + with gr.Tab("Leaderboard", id=4): build_leaderboard_tab(elo_results_file, leaderboard_table_file) with gr.Tab("About Us", id=4): about = build_about() + with gr.Tab("About Us", id=5): + about = build_about() + url_params = gr.JSON(visible=False) if args.model_list_mode not in ["once", "reload"]: raise ValueError(f"Unknown model list mode: {args.model_list_mode}") - if args.show_terms_of_use: - load_js = get_window_url_params_with_tos_js - else: - load_js = get_window_url_params_js - demo.load( load_demo, [url_params], [tabs] + single_model_list + side_by_side_anony_list - + side_by_side_named_list, - _js=load_js, + + side_by_side_named_list + + single_vision_language_model_list, + js=load_js, ) return demo @@ -192,29 +204,15 @@ def build_demo(models, elo_results_file, leaderboard_table_file): help="Shows term of use before loading the demo", ) parser.add_argument( - "--add-chatgpt", - action="store_true", - help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", - ) - parser.add_argument( - "--add-claude", - action="store_true", - help="Add Anthropic's Claude models (claude-2, claude-instant-1)", + "--multimodal", action="store_true", help="Show multi modal tabs." ) parser.add_argument( - "--add-palm", - action="store_true", - help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", + "--random-questions", type=str, help="Load random questions from a JSON file" ) parser.add_argument( - "--anony-only-for-proprietary-model", - action="store_true", - help="Only add ChatGPT, Claude, Bard under anony battle tab", - ) - parser.add_argument( - "--register-openai-compatible-models", + "--register-api-endpoint-file", type=str, - help="Register custom OpenAI API compatible models by loading them from a JSON file", + help="Register API-based model endpoints from a JSON file", ) parser.add_argument( "--gradio-auth-path", @@ -228,6 +226,17 @@ def build_demo(models, elo_results_file, leaderboard_table_file): parser.add_argument( "--leaderboard-table-file", type=str, help="Load leaderboard results and plots" ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + parser.add_argument( + "--ga-id", + type=str, + help="the Google Analytics ID", + default=None, + ) args = parser.parse_args() logger.info(f"args: {args}") @@ -235,22 +244,17 @@ def build_demo(models, elo_results_file, leaderboard_table_file): set_global_vars(args.controller_url, args.moderate) set_global_vars_named(args.moderate) set_global_vars_anony(args.moderate) - if args.anony_only_for_proprietary_model: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - False, - False, - False, - ) - else: - models = get_model_list( - args.controller_url, - args.register_openai_compatible_models, - args.add_chatgpt, - args.add_claude, - args.add_palm, - ) + models, all_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + False, + ) + + vl_models, all_vl_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + True, + ) # Set authorization credentials auth = None @@ -258,13 +262,21 @@ def build_demo(models, elo_results_file, leaderboard_table_file): auth = parse_gradio_auth_creds(args.gradio_auth_path) # Launch the demo - demo = build_demo(models, args.elo_results_file, args.leaderboard_table_file) + demo = build_demo( + models, + vl_models, + args.elo_results_file, + args.leaderboard_table_file, + ) demo.queue( - concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, ).launch( server_name=args.host, server_port=args.port, share=args.share, max_threads=200, auth=auth, + root_path=args.gradio_root_path, ) diff --git a/fastchat/serve/huggingface_api.py b/fastchat/serve/huggingface_api.py index 2a49bf5f1..8022fbc93 100644 --- a/fastchat/serve/huggingface_api.py +++ b/fastchat/serve/huggingface_api.py @@ -61,7 +61,7 @@ def main(args): add_model_args(parser) parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--repetition_penalty", type=float, default=1.0) - parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--max-new-tokens", type=int, default=1024) parser.add_argument("--debug", action="store_true") parser.add_argument("--message", type=str, default="Hello! Who are you?") args = parser.parse_args() diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py index 2d0611fe5..6ed8e6c8c 100644 --- a/fastchat/serve/huggingface_api_worker.py +++ b/fastchat/serve/huggingface_api_worker.py @@ -4,12 +4,18 @@ Register models in a JSON file with the following format: { "falcon-180b-chat": { - "model_path": "tiiuae/falcon-180B-chat", + "model_name": "falcon-180B-chat", "api_base": "https://api-inference.huggingface.co/models", - "token": "hf_xxx", - "context_length": 2048, - "model_names": "falcon-180b-chat", - "conv_template": null + "model_path": "tiiuae/falcon-180B-chat", + "token": "hf_XXX", + "context_length": 2048 + }, + "zephyr-7b-beta": { + "model_name": "zephyr-7b-beta", + "model_path": "", + "api_base": "xxx", + "token": "hf_XXX", + "context_length": 4096 } } diff --git a/fastchat/serve/lightllm_worker.py b/fastchat/serve/lightllm_worker.py new file mode 100644 index 000000000..ed0e21b68 --- /dev/null +++ b/fastchat/serve/lightllm_worker.py @@ -0,0 +1,512 @@ +""" +A model worker that executes the model based on LightLLM. + +See documentations at docs/lightllm_integration.md +""" + +import argparse +import asyncio +import json +import os +import torch +import uvicorn + +from transformers import AutoConfig + +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) + +from lightllm.server.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.httpserver.manager import HttpServerManager +from lightllm.server.detokenization.manager import start_detokenization_process +from lightllm.server.router.manager import start_router_process +from lightllm.server.req_id_generator import ReqIDGenerator + +from lightllm.utils.net_utils import alloc_can_use_network_port +from lightllm.utils.start_utils import start_submodule_processes +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() +g_id_gen = ReqIDGenerator() + + +class LightLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + tokenizer, + context_len, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: LightLLM worker..." + ) + self.tokenizer = tokenizer + self.context_len = context_len + + self.is_first = True + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + prompt = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + max_new_tokens = params.get("max_new_tokens", 256) + echo = params.get("echo", True) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + + request = params.get("request", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + if self.is_first: + loop = asyncio.get_event_loop() + loop.create_task(httpserver_manager.handle_loop()) + self.is_first = False + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + do_sample=temperature > 0.0, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repetition_penalty=repetition_penalty, + max_new_tokens=max_new_tokens, + stop_sequences=list(stop), + ) + sampling_params.verify() + + results_generator = httpserver_manager.generate( + prompt, sampling_params, request_id, MultimodalParams() + ) + + completion_tokens = 0 + text_outputs = "" + cumulative_logprob = 0.0 + + async for request_output, metadata, finish_status in results_generator: + text_outputs += request_output + completion_tokens += 1 + + partial_stop = any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + if type(finish_status) is bool: # compatibility with old version + finish_reason = "stop" if finish_status else None + else: + finish_reason = finish_status.get_finish_reason() + + if request and await request.is_disconnected(): + await httpserver_manager.abort(request_id) + finish_reason = "abort" + + logprob = metadata.get("logprob", None) + if logprob is not None: + cumulative_logprob += logprob + + prompt_tokens = metadata["prompt_tokens"] + ret = { + "text": prompt + text_outputs if echo else text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": cumulative_logprob, + } + + if finish_reason is not None: + yield ( + json.dumps({**ret, "finish_reason": None}, ensure_ascii=False) + + "\0" + ).encode("utf-8") + yield ( + json.dumps({**ret, "finish_reason": finish_reason}, ensure_ascii=False) + + "\0" + ).encode("utf-8") + + if finish_reason is not None: # In case of abort, we need to break the loop + break + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await httpserver_manager.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = g_id_gen.generate_id() + params["request_id"] = request_id + params["request"] = request + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = g_id_gen.generate_id() + params["request_id"] = request_id + params["request"] = request + output = await worker.generate(params) + release_worker_semaphore() + await httpserver_manager.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + + parser.add_argument( + "--model-path", + dest="model_dir", + type=str, + default=None, + help="the model weight dir path, the app will load config, weights and tokenizer from this dir", + ) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + + parser.add_argument( + "--tokenizer_mode", + type=str, + default="slow", + help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is good for debug and test, + when you want to get best performance, try auto mode""", + ) + parser.add_argument( + "--load_way", + type=str, + default="HF", + help="the way of loading model weights, the default is HF(Huggingface format), llama also supports DS(Deepspeed)", + ) + parser.add_argument( + "--max_total_token_num", + type=int, + default=6000, + help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)", + ) + parser.add_argument( + "--batch_max_tokens", + type=int, + default=None, + help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", + ) + parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id") + parser.add_argument( + "--running_max_req_size", + type=int, + default=1000, + help="the max size for forward requests in the same time", + ) + parser.add_argument( + "--tp", type=int, default=1, help="model tp parral size, the default is 1" + ) + parser.add_argument( + "--max_req_input_len", + type=int, + default=None, + help="the max value for req input tokens num. If None, it will be derived from the config.", + ) + parser.add_argument( + "--max_req_total_len", + type=int, + default=None, + help="the max value for req_input_len + req_output_len. If None, it will be derived from the config.", + ) + parser.add_argument( + "--mode", + type=str, + default=[], + nargs="+", + help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding + | triton_gqa_attention | triton_gqa_flashdecoding] + [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight], + triton_flashdecoding mode is for long context, current support llama llama2 qwen; + triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; + triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; + ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; + ppl_fp16 mode use ppl fast fp16 decode attention kernel; + triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights; + you need to read source code to make sure the supported detail mode for all models""", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) + parser.add_argument( + "--disable_log_stats", + action="store_true", + help="disable logging throughput stats.", + ) + parser.add_argument( + "--log_stats_interval", + type=int, + default=10, + help="log stats interval in second.", + ) + + parser.add_argument( + "--router_token_ratio", + type=float, + default=0.0, + help="token ratio to control router dispatch", + ) + parser.add_argument( + "--router_max_new_token_len", + type=int, + default=1024, + help="the request max new token len for router", + ) + + parser.add_argument( + "--no_skipping_special_tokens", + action="store_true", + help="whether to skip special tokens when decoding", + ) + parser.add_argument( + "--no_spaces_between_special_tokens", + action="store_true", + help="whether to add spaces between special tokens when decoding", + ) + + parser.add_argument( + "--splitfuse_mode", action="store_true", help="use splitfuse mode" + ) + parser.add_argument( + "--splitfuse_block_size", type=int, default=256, help="splitfuse block size" + ) + parser.add_argument( + "--prompt_cache_strs", + type=str, + default=[], + nargs="+", + help="""prompt cache strs""", + ) + parser.add_argument( + "--cache_capacity", + type=int, + default=200, + help="cache server capacity for multimodal resources", + ) + parser.add_argument( + "--cache_reserved_ratio", + type=float, + default=0.5, + help="cache server reserved capacity ratio after clear", + ) + parser.add_argument( + "--return_all_prompt_logprobs", + action="store_true", + help="return all prompt tokens logprobs", + ) + parser.add_argument( + "--long_truncation_mode", + type=str, + choices=[None, "head", "center"], + default=None, + help="""use to select the handle way when input token len > max_req_input_len. + None : raise Exception + head : remove some head tokens to make input token len <= max_req_input_len + center : remove some tokens in center loc to make input token len <= max_req_input_len""", + ) + + args = parser.parse_args() + + # 非splitfuse 模式,不支持 prompt cache 特性 + if not args.splitfuse_mode: + assert len(args.prompt_cache_strs) == 0 + + model_config = AutoConfig.from_pretrained(args.model_dir) + context_length = get_context_length(model_config) + + if args.max_req_input_len is None: + args.max_req_input_len = context_length - 1 + if args.max_req_total_len is None: + args.max_req_total_len = context_length + + assert args.max_req_input_len < args.max_req_total_len + assert args.max_req_total_len <= args.max_total_token_num + + if not args.splitfuse_mode: + # 普通模式下 + if args.batch_max_tokens is None: + batch_max_tokens = int(1 / 6 * args.max_total_token_num) + batch_max_tokens = max(batch_max_tokens, args.max_req_total_len) + args.batch_max_tokens = batch_max_tokens + else: + assert ( + args.batch_max_tokens >= args.max_req_total_len + ), "batch_max_tokens must >= max_req_total_len" + else: + # splitfuse 模式下 + # assert args.batch_max_tokens is not None, "need to set by yourself" + if args.batch_max_tokens is None: + batch_max_tokens = int(1 / 6 * args.max_total_token_num) + batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size) + args.batch_max_tokens = batch_max_tokens + + can_use_ports = alloc_can_use_network_port(num=6 + args.tp) + + assert can_use_ports is not None, "Can not alloc enough free ports." + ( + router_port, + detokenization_port, + httpserver_port, + visual_port, + cache_port, + nccl_port, + ) = can_use_ports[0:6] + args.nccl_port = nccl_port + model_rpc_ports = can_use_ports[6:] + + global httpserver_manager + httpserver_manager = HttpServerManager( + args, + router_port=router_port, + cache_port=cache_port, + visual_port=visual_port, + httpserver_port=httpserver_port, + enable_multimodal=False, + ) + + start_submodule_processes( + start_funcs=[start_router_process, start_detokenization_process], + start_args=[ + (args, router_port, detokenization_port, model_rpc_ports), + (args, detokenization_port, httpserver_port), + ], + ) + worker = LightLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_dir, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + args.conv_template, + httpserver_manager.tokenizer, + context_length, + ) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/mlx_worker.py b/fastchat/serve/mlx_worker.py new file mode 100644 index 000000000..a7e85f848 --- /dev/null +++ b/fastchat/serve/mlx_worker.py @@ -0,0 +1,288 @@ +""" +A model worker using Apple MLX + +https://github.com/ml-explore/mlx-examples/tree/main/llms + +Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py + +You must install MLX python: + +pip install mlx-lm +""" + +import argparse +import asyncio +import atexit +import json +from typing import List +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.concurrency import run_in_threadpool +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +import mlx.core as mx +from mlx_lm import load, generate +from mlx_lm.utils import generate_step + +app = FastAPI() + + +class MLXWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: "MLX", + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..." + ) + + self.model_name = model_path + self.mlx_model, self.mlx_tokenizer = load(model_path) + + self.tokenizer = self.mlx_tokenizer + # self.context_len = get_context_length( + # llm_engine.engine.model_config.hf_config) + self.context_len = 2048 # hard code for now -- not sure how to get in MLX + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + print("Stop patterns: ", stop) + + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + tokens = [] + skip = 0 + + context_mlx = mx.array(self.tokenizer.encode(context)) + + finish_reason = "length" + + iterator = await run_in_threadpool( + generate_step, context_mlx, self.mlx_model, temperature + ) + + for i in range(max_new_tokens): + (token, _) = await run_in_threadpool(next, iterator) + if token == self.mlx_tokenizer.eos_token_id: + finish_reason = "stop" + break + tokens.append(token.item()) + tokens_decoded = self.mlx_tokenizer.decode(tokens) + last_token_decoded = self.mlx_tokenizer.decode([token.item()]) + skip = len(tokens_decoded) + + partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop) + + if partial_stop: + finish_reason = "stop" + break + + ret = { + "text": tokens_decoded, + "error_code": 0, + "usage": { + "prompt_tokens": len(context), + "completion_tokens": len(tokens), + "total_tokens": len(context) + len(tokens), + }, + "cumulative_logprob": [], + "finish_reason": None, # hard code for now + } + # print(ret) + yield (json.dumps(ret) + "\0").encode() + ret = { + "text": self.mlx_tokenizer.decode(tokens), + "error_code": 0, + "usage": {}, + "cumulative_logprob": [], + "finish_reason": finish_reason, + } + yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode() + yield (json.dumps(ret) + "\0").encode() + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + print("trying to abort but not implemented") + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = uuid.uuid4() + params["request_id"] = str(request_id) + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = uuid.uuid4() + params["request_id"] = str(request_id) + output = await worker.generate(params) + release_worker_semaphore() + # await engine.abort(request_id) + print("Trying to abort but not implemented") + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +worker = None + + +def cleanup_at_exit(): + global worker + print("Cleaning up...") + del worker + + +atexit.register(cleanup_at_exit) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="microsoft/phi-2") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + + args, unknown = parser.parse_known_args() + + if args.model_path: + args.model = args.model_path + + worker = MLXWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + 1024, + False, + "MLX", + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 5e84a4262..683a78556 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -31,7 +31,6 @@ str_to_torch_dtype, ) - worker_id = str(uuid.uuid4())[:8] logger = build_logger("model_worker", f"model_worker_{worker_id}.log") @@ -49,6 +48,7 @@ def __init__( device: str, num_gpus: int, max_gpu_memory: str, + revision: str = None, dtype: Optional[torch.dtype] = None, load_8bit: bool = False, cpu_offloading: bool = False, @@ -76,6 +76,7 @@ def __init__( logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") self.model, self.tokenizer = load_model( model_path, + revision=revision, device=device, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory, @@ -101,6 +102,10 @@ def __init__( self.init_heart_beat() def generate_stream_gate(self, params): + if self.device == "npu": + import torch_npu + + torch_npu.npu.set_device("npu:0") self.call_ct += 1 try: @@ -159,9 +164,13 @@ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): data = model_output.hidden_states[-1].transpose(0, 1) else: data = model_output.hidden_states[-1] - mask = attention_mask.unsqueeze(-1).expand(data.size()).float() - masked_embeddings = data * mask - sum_embeddings = torch.sum(masked_embeddings, dim=1) + + if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling: + sum_embeddings = data[:, 0] + else: + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) token_num = torch.sum(attention_mask).item() return sum_embeddings, token_num @@ -206,10 +215,14 @@ def get_embeddings(self, params): base64_encode = params.get("encoding_format", None) if self.embed_in_truncate: - chunk_embeddings, token_num = self.__process_embed_chunk( + embedding, token_num = self.__process_embed_chunk( input_ids, attention_mask, **model_type_dict ) - embedding = chunk_embeddings / token_num + if ( + not hasattr(self.model, "use_cls_pooling") + or not self.model.use_cls_pooling + ): + embedding = embedding / token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) ret["token_num"] = token_num else: @@ -219,10 +232,41 @@ def get_embeddings(self, params): chunk_input_ids = input_ids[:, i : i + self.context_len] chunk_attention_mask = attention_mask[:, i : i + self.context_len] + # add cls token and mask to get cls embedding + if ( + hasattr(self.model, "use_cls_pooling") + and self.model.use_cls_pooling + ): + cls_tokens = ( + torch.zeros( + (chunk_input_ids.size(0), 1), + dtype=chunk_input_ids.dtype, + device=chunk_input_ids.device, + ) + + tokenizer.cls_token_id + ) + chunk_input_ids = torch.cat( + [cls_tokens, chunk_input_ids], dim=-1 + ) + mask = torch.ones( + (chunk_attention_mask.size(0), 1), + dtype=chunk_attention_mask.dtype, + device=chunk_attention_mask.device, + ) + chunk_attention_mask = torch.cat( + [mask, chunk_attention_mask], dim=-1 + ) + chunk_embeddings, token_num = self.__process_embed_chunk( chunk_input_ids, chunk_attention_mask, **model_type_dict ) - all_embeddings.append(chunk_embeddings) + if ( + hasattr(self.model, "use_cls_pooling") + and self.model.use_cls_pooling + ): + all_embeddings.append(chunk_embeddings * token_num) + else: + all_embeddings.append(chunk_embeddings) all_token_num += token_num all_embeddings_tensor = torch.stack(all_embeddings) @@ -345,6 +389,7 @@ def create_model_worker(): args.model_path, args.model_names, args.limit_worker_concurrency, + revision=args.revision, no_register=args.no_register, device=args.device, num_gpus=args.num_gpus, diff --git a/fastchat/serve/monitor/basic_stats.py b/fastchat/serve/monitor/basic_stats.py index e1934bb07..3c1a8793d 100644 --- a/fastchat/serve/monitor/basic_stats.py +++ b/fastchat/serve/monitor/basic_stats.py @@ -13,50 +13,60 @@ NUM_SERVERS = 14 +LOG_ROOT_DIR = "~/fastchat_logs" def get_log_files(max_num_files=None): - dates = [] - for month in range(4, 12): - for day in range(1, 33): - dates.append(f"2023-{month:02d}-{day:02d}") - + log_root = os.path.expanduser(LOG_ROOT_DIR) filenames = [] - for d in dates: - for i in range(NUM_SERVERS): - name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") - if os.path.exists(name): - filenames.append(name) + for i in range(NUM_SERVERS): + for filename in os.listdir(f"{log_root}/server{i}"): + if filename.endswith("-conv.json"): + filepath = f"{log_root}/server{i}/{filename}" + name_tstamp_tuple = (filepath, os.path.getmtime(filepath)) + filenames.append(name_tstamp_tuple) + # sort by tstamp + filenames = sorted(filenames, key=lambda x: x[1]) + filenames = [x[0] for x in filenames] + max_num_files = max_num_files or len(filenames) filenames = filenames[-max_num_files:] return filenames -def load_log_files(log_files): +def load_log_files(filename): data = [] - for filename in tqdm(log_files, desc="read files"): - for retry in range(5): - try: - lines = open(filename).readlines() - break - except FileNotFoundError: - time.sleep(2) - - for l in lines: - row = json.loads(l) - - data.append( - dict( - type=row["type"], - tstamp=row["tstamp"], - model=row.get("model", ""), - models=row.get("models", ["", ""]), - ) + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + data.append( + dict( + type=row["type"], + tstamp=row["tstamp"], + model=row.get("model", ""), + models=row.get("models", ["", ""]), ) - + ) return data +def load_log_files_parallel(log_files, num_threads=16): + data_all = [] + from multiprocessing import Pool + + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + def get_anony_vote_df(df): anony_vote_df = df[ df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) @@ -77,7 +87,7 @@ def merge_counts(series, on, names): def report_basic_stats(log_files): - df_all = load_log_files(log_files) + df_all = load_log_files_parallel(log_files) df_all = pd.DataFrame(df_all) now_t = df_all["tstamp"].max() df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] diff --git a/fastchat/serve/monitor/clean_battle_data.py b/fastchat/serve/monitor/clean_battle_data.py index 23357d08c..58541c3d0 100644 --- a/fastchat/serve/monitor/clean_battle_data.py +++ b/fastchat/serve/monitor/clean_battle_data.py @@ -27,6 +27,7 @@ "laion", "chatglm", "chatgpt", + "gpt-4", "openai", "anthropic", "claude", @@ -35,31 +36,26 @@ "lamda", "google", "llama", + "qianwan", + "alibaba", + "mistral", + "zhipu", + "KEG lab", + "01.AI", + "AI2", + "Tülu", + "Tulu", "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.", "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.", + "API REQUEST ERROR. Please increase the number of max tokens.", + "**API REQUEST ERROR** Reason: The response was blocked.", + "**API REQUEST ERROR**", ] for i in range(len(IDENTITY_WORDS)): IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower() -def get_log_files(max_num_files=None): - dates = [] - for month in range(4, 12): - for day in range(1, 33): - dates.append(f"2023-{month:02d}-{day:02d}") - - filenames = [] - for d in dates: - for i in range(NUM_SERVERS): - name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") - if os.path.exists(name): - filenames.append(name) - max_num_files = max_num_files or len(filenames) - filenames = filenames[-max_num_files:] - return filenames - - def remove_html(raw): if raw.startswith("

"): return raw[raw.find(": ") + 2 : -len("

\n")] @@ -74,29 +70,54 @@ def to_openai_format(messages): return ret -def replace_model_name(old_name): - return ( - old_name.replace("bard", "palm-2") - .replace("claude-v1", "claude-1") - .replace("claude-instant-v1", "claude-instant-1") - .replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b") - ) +def replace_model_name(old_name, tstamp): + replace_dict = { + "bard": "palm-2", + "claude-v1": "claude-1", + "claude-instant-v1": "claude-instant-1", + "oasst-sft-1-pythia-12b": "oasst-pythia-12b", + "claude-2": "claude-2.0", + } + if old_name in ["gpt-4", "gpt-3.5-turbo"]: + if tstamp > 1687849200: + return old_name + "-0613" + else: + return old_name + "-0314" + if old_name in replace_dict: + return replace_dict[old_name] + return old_name -def clean_battle_data(log_files, exclude_model_names): +def read_file(filename): data = [] - for filename in tqdm(log_files, desc="read files"): - for retry in range(5): - try: - lines = open(filename).readlines() - break - except FileNotFoundError: - time.sleep(2) - - for l in lines: - row = json.loads(l) - if row["type"] in VOTES: - data.append(row) + for retry in range(5): + try: + # lines = open(filename).readlines() + for l in open(filename): + row = json.loads(l) + if row["type"] in VOTES: + data.append(row) + break + except FileNotFoundError: + time.sleep(2) + return data + + +def read_file_parallel(log_files, num_threads=16): + data_all = [] + from multiprocessing import Pool + + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def clean_battle_data( + log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False +): + data = read_file_parallel(log_files, num_threads=16) convert_type = { "leftvote": "model_a", @@ -110,6 +131,7 @@ def clean_battle_data(log_files, exclude_model_names): ct_anony = 0 ct_invalid = 0 ct_leaked_identity = 0 + ct_banned = 0 battles = [] for row in data: if row["models"][0] is None or row["models"][1] is None: @@ -156,7 +178,9 @@ def clean_battle_data(log_files, exclude_model_names): messages = "" for i in range(2): state = row["states"][i] - for role, msg in state["messages"][state["offset"] :]: + for turn_idx, (role, msg) in enumerate( + state["messages"][state["offset"] :] + ): if msg: messages += msg.lower() for word in IDENTITY_WORDS: @@ -169,7 +193,11 @@ def clean_battle_data(log_files, exclude_model_names): continue # Replace bard with palm - models = [replace_model_name(m) for m in models] + models = [replace_model_name(m, row["tstamp"]) for m in models] + # Exclude certain models + if exclude_model_names and any(x in exclude_model_names for x in models): + ct_invalid += 1 + continue # Exclude certain models if any(x in exclude_model_names for x in models): @@ -186,8 +214,16 @@ def clean_battle_data(log_files, exclude_model_names): ip = row["ip"] if ip not in all_ips: - all_ips[ip] = len(all_ips) - user_id = all_ips[ip] + all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)} + all_ips[ip]["count"] += 1 + if sanitize_ip: + user_id = f"arena_user_{all_ips[ip]['sanitized_id']}" + else: + user_id = f"{all_ips[ip]['ip']}" + + if ban_ip_list is not None and ip in ban_ip_list: + ct_banned += 1 + continue # Save the results battles.append( @@ -216,12 +252,19 @@ def clean_battle_data(log_files, exclude_model_names): print( f"#votes: {len(data)}, #invalid votes: {ct_invalid}, " - f"#leaked_identity: {ct_leaked_identity}" + f"#leaked_identity: {ct_leaked_identity} " + f"#banned: {ct_banned} " ) print(f"#battles: {len(battles)}, #anony: {ct_anony}") print(f"#models: {len(all_models)}, {all_models}") print(f"last-updated: {last_updated_datetime}") + if ban_ip_list is not None: + for ban_ip in ban_ip_list: + if ban_ip in all_ips: + del all_ips[ban_ip] + print("Top 30 IPs:") + print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30]) return battles @@ -232,10 +275,16 @@ def clean_battle_data(log_files, exclude_model_names): "--mode", type=str, choices=["simple", "conv_release"], default="simple" ) parser.add_argument("--exclude-model-names", type=str, nargs="+") + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--sanitize-ip", action="store_true", default=False) args = parser.parse_args() log_files = get_log_files(args.max_num_files) - battles = clean_battle_data(log_files, args.exclude_model_names or []) + ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None + + battles = clean_battle_data( + log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip + ) last_updated_tstamp = battles[-1]["tstamp"] cutoff_date = datetime.datetime.fromtimestamp( last_updated_tstamp, tz=timezone("US/Pacific") diff --git a/fastchat/serve/monitor/clean_chat_data.py b/fastchat/serve/monitor/clean_chat_data.py index 7f0c9bd4f..1dd8b594d 100644 --- a/fastchat/serve/monitor/clean_chat_data.py +++ b/fastchat/serve/monitor/clean_chat_data.py @@ -2,7 +2,7 @@ Clean chatbot arena chat log. Usage: -python3 clean_chat_data.py --mode conv_release +python3 clean_chat_data.py """ import argparse import datetime diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py index e95f157c8..d0ff0fb09 100644 --- a/fastchat/serve/monitor/elo_analysis.py +++ b/fastchat/serve/monitor/elo_analysis.py @@ -52,6 +52,41 @@ def get_bootstrap_result(battles, func_compute_elo, num_round=1000): return df[df.median().sort_values(ascending=False).index] +def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000): + from sklearn.linear_model import LogisticRegression + + models = pd.concat([df["model_a"], df["model_b"]]).unique() + models = pd.Series(np.arange(len(models)), index=models) + + # duplicate battles + df = pd.concat([df, df], ignore_index=True) + p = len(models.index) + n = df.shape[0] + + X = np.zeros([n, p]) + X[np.arange(n), models[df["model_a"]]] = +math.log(BASE) + X[np.arange(n), models[df["model_b"]]] = -math.log(BASE) + + # one A win => two A win + Y = np.zeros(n) + Y[df["winner"] == "model_a"] = 1.0 + + # one tie => one A win + one B win + # find tie + tie (both bad) index + tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)") + tie_idx[len(tie_idx) // 2 :] = False + Y[tie_idx] = 1.0 + + lr = LogisticRegression(fit_intercept=False) + lr.fit(X, Y) + + elo_scores = SCALE * lr.coef_[0] + INIT_RATING + # calibrate llama-13b to 800 if applicable + if "llama-13b" in models.index: + elo_scores += 800 - elo_scores[models["llama-13b"]] + return pd.Series(elo_scores, index=models.index).sort_values(ascending=False) + + def get_median_elo_from_bootstrap(bootstrap_df): median = dict(bootstrap_df.quantile(0.5)) median = {k: int(v + 0.5) for k, v in median.items()} @@ -185,12 +220,12 @@ def visualize_average_win_rate(battles, limit_show_number): return fig -def visualize_bootstrap_elo_rating(df, limit_show_number): +def visualize_bootstrap_elo_rating(df, df_final, limit_show_number): bars = ( pd.DataFrame( dict( lower=df.quantile(0.025), - rating=df.quantile(0.5), + rating=df_final, upper=df.quantile(0.975), ) ) @@ -215,7 +250,7 @@ def visualize_bootstrap_elo_rating(df, limit_show_number): return fig -def report_elo_analysis_results(battles_json): +def report_elo_analysis_results(battles_json, rating_system="bt", num_bootstrap=100): battles = pd.DataFrame(battles_json) battles = battles.sort_values(ascending=True, by=["tstamp"]) # Only use anonymous votes @@ -225,24 +260,48 @@ def report_elo_analysis_results(battles_json): # Online update elo_rating_online = compute_elo(battles) - # Bootstrap - bootstrap_df = get_bootstrap_result(battles, compute_elo) - elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) - model_order = list(elo_rating_median.keys()) - model_order.sort(key=lambda k: -elo_rating_median[k]) + if rating_system == "bt": + bootstrap_df = get_bootstrap_result( + battles, compute_elo_mle_with_tie, num_round=num_bootstrap + ) + elo_rating_final = compute_elo_mle_with_tie(battles) + elif rating_system == "elo": + bootstrap_df = get_bootstrap_result( + battles, compute_elo, num_round=num_bootstrap + ) + elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) + elo_rating_final = elo_rating_median + + model_order = list(elo_rating_final.keys()) + model_order.sort(key=lambda k: -elo_rating_final[k]) + + limit_show_number = 25 # limit show number to make plots smaller + model_order = model_order[:limit_show_number] + + # leaderboard_table_df: elo rating, variance, 95% interval, number of battles + leaderboard_table_df = pd.DataFrame( + { + "rating": elo_rating_final, + "variance": bootstrap_df.var(), + "rating_q975": bootstrap_df.quantile(0.975), + "rating_q025": bootstrap_df.quantile(0.025), + "num_battles": battles["model_a"].value_counts() + + battles["model_b"].value_counts(), + } + ) limit_show_number = 25 # limit show number to make plots smaller model_order = model_order[:limit_show_number] # Plots - leaderboard_table = visualize_leaderboard_table(elo_rating_median) + leaderboard_table = visualize_leaderboard_table(elo_rating_final) win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order) battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order) average_win_rate_bar = visualize_average_win_rate( battles_no_ties, limit_show_number ) bootstrap_elo_rating = visualize_bootstrap_elo_rating( - bootstrap_df, limit_show_number + bootstrap_df, elo_rating_final, limit_show_number ) last_updated_tstamp = battles["tstamp"].max() @@ -251,8 +310,9 @@ def report_elo_analysis_results(battles_json): ).strftime("%Y-%m-%d %H:%M:%S %Z") return { + "rating_system": rating_system, "elo_rating_online": elo_rating_online, - "elo_rating_median": elo_rating_median, + "elo_rating_final": elo_rating_final, "leaderboard_table": leaderboard_table, "win_fraction_heatmap": win_fraction_heatmap, "battle_count_heatmap": battle_count_heatmap, @@ -260,6 +320,8 @@ def report_elo_analysis_results(battles_json): "bootstrap_elo_rating": bootstrap_elo_rating, "last_updated_datetime": last_updated_datetime, "last_updated_tstamp": last_updated_tstamp, + "bootstrap_df": bootstrap_df, + "leaderboard_table_df": leaderboard_table_df, } @@ -274,6 +336,11 @@ def pretty_print_elo_rating(rating): parser = argparse.ArgumentParser() parser.add_argument("--clean-battle-file", type=str) parser.add_argument("--max-num-files", type=int) + parser.add_argument("--num-bootstrap", type=int, default=100) + parser.add_argument( + "--rating-system", type=str, choices=["bt", "elo"], default="bt" + ) + parser.add_argument("--exclude-tie", action="store_true", default=False) args = parser.parse_args() np.random.seed(42) @@ -286,12 +353,14 @@ def pretty_print_elo_rating(rating): log_files = get_log_files(args.max_num_files) battles = clean_battle_data(log_files) - results = report_elo_analysis_results(battles) + results = report_elo_analysis_results( + battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap + ) - print("# Online") + print("# Online Elo") pretty_print_elo_rating(results["elo_rating_online"]) print("# Median") - pretty_print_elo_rating(results["elo_rating_median"]) + pretty_print_elo_rating(results["elo_rating_final"]) print(f"last update : {results['last_updated_datetime']}") last_updated_tstamp = results["last_updated_tstamp"] diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py index 580a2c866..1912ef6fe 100644 --- a/fastchat/serve/monitor/monitor.py +++ b/fastchat/serve/monitor/monitor.py @@ -8,11 +8,13 @@ import argparse import ast +import json import pickle import os import threading import time +import pandas as pd import gradio as gr import numpy as np @@ -22,24 +24,52 @@ from fastchat.utils import build_logger, get_window_url_params_js -notebook_url = "https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing" - +notebook_url = ( + "https://colab.research.google.com/drive/1KdwokPjirkTmpO_P1WByFNFiqxWQquwH" +) basic_component_values = [None] * 6 leader_component_values = [None] * 5 -def make_leaderboard_md(elo_results): +def make_default_md(arena_df, elo_results): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + + leaderboard_md = f""" +# 🏆 LMSYS Chatbot Arena Leaderboard +| [Vote](https://chat.lmsys.org) | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +LMSYS [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) is a crowdsourced open platform for LLM evals. +We've collected over **200,000** human preference votes to rank LLMs with the Elo ranking system. +""" + return leaderboard_md + + +def make_arena_leaderboard_md(arena_df): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + leaderboard_md = f""" -# 🏆 Chatbot Arena Leaderboard -| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | +Total #models: **{total_models}**. Total #votes: **{total_votes}**. Last updated: Feb 2, 2024. + +Contribute your vote 🗳️ at [chat.lmsys.org](https://chat.lmsys.org)! Find more analysis in the [notebook]({notebook_url}). + +⚠️ **Some mobile users reported the leaderboard is not displayed normally, please visit [our HF alternative](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) while we are fixing it**. +""" + return leaderboard_md + -This leaderboard is based on the following three benchmarks. -- [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings. -- [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. -- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks. +def make_full_leaderboard_md(elo_results): + leaderboard_md = """ +Three benchmarks are displayed: **Arena Elo**, **MT-Bench** and **MMLU**. +- [Chatbot Arena](https://chat.lmsys.org/?arena) - a crowdsourced, randomized battle platform based on human preference votes. +- [MT-Bench](https://arxiv.org/abs/2306.05685): a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. +- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot): a test to measure a model's multitask accuracy on 57 tasks. -💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023. +💻 Code: The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). +The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). +Higher values are better for all benchmarks. Empty cells mean not available. """ return leaderboard_md @@ -53,12 +83,17 @@ def make_leaderboard_md_live(elo_results): return leaderboard_md -def update_elo_components(max_num_files, elo_results_file): +def update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names +): log_files = get_log_files(max_num_files) # Leaderboard if elo_results_file is None: # Do live update - battles = clean_battle_data(log_files, []) + ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None + battles = clean_battle_data( + log_files, exclude_model_names, ban_ip_list=ban_ip_list + ) elo_results = report_elo_analysis_results(battles) leader_component_values[0] = make_leaderboard_md_live(elo_results) @@ -91,10 +126,14 @@ def update_elo_components(max_num_files, elo_results_file): basic_component_values[5] = md4 -def update_worker(max_num_files, interval, elo_results_file): +def update_worker( + max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names +): while True: tic = time.time() - update_elo_components(max_num_files, elo_results_file) + update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names + ) durtaion = time.time() - tic print(f"update duration: {durtaion:.2f} s") time.sleep(max(interval - durtaion, 0)) @@ -166,90 +205,186 @@ def build_basic_stats_tab(): return [md0, plot_1, md1, md2, md3, md4] -def build_leaderboard_tab(elo_results_file, leaderboard_table_file): +def get_full_table(arena_df, model_table_df): + values = [] + for i in range(len(model_table_df)): + row = [] + model_key = model_table_df.iloc[i]["key"] + model_name = model_table_df.iloc[i]["Model"] + # model display name + row.append(model_name) + if model_key in arena_df.index: + idx = arena_df.index.get_loc(model_key) + row.append(round(arena_df.iloc[idx]["rating"])) + else: + row.append(np.nan) + row.append(model_table_df.iloc[i]["MT-bench (score)"]) + row.append(model_table_df.iloc[i]["MMLU"]) + # Organization + row.append(model_table_df.iloc[i]["Organization"]) + # license + row.append(model_table_df.iloc[i]["License"]) + + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + return values + + +def get_arena_table(arena_df, model_table_df): + # sort by rating + arena_df = arena_df.sort_values(by=["rating"], ascending=False) + values = [] + for i in range(len(arena_df)): + row = [] + model_key = arena_df.index[i] + model_name = model_table_df[model_table_df["key"] == model_key]["Model"].values[ + 0 + ] + + # rank + row.append(i + 1) + # model display name + row.append(model_name) + # elo rating + row.append(round(arena_df.iloc[i]["rating"])) + upper_diff = round(arena_df.iloc[i]["rating_q975"] - arena_df.iloc[i]["rating"]) + lower_diff = round(arena_df.iloc[i]["rating"] - arena_df.iloc[i]["rating_q025"]) + row.append(f"+{upper_diff}/-{lower_diff}") + # num battles + row.append(round(arena_df.iloc[i]["num_battles"])) + # Organization + row.append( + model_table_df[model_table_df["key"] == model_key]["Organization"].values[0] + ) + # license + row.append( + model_table_df[model_table_df["key"] == model_key]["License"].values[0] + ) + + values.append(row) + return values + + +def build_leaderboard_tab(elo_results_file, leaderboard_table_file, show_plot=False): if elo_results_file is None: # Do live update - md = "Loading ..." + default_md = "Loading ..." p1 = p2 = p3 = p4 = None else: with open(elo_results_file, "rb") as fin: elo_results = pickle.load(fin) - md = make_leaderboard_md(elo_results) p1 = elo_results["win_fraction_heatmap"] p2 = elo_results["battle_count_heatmap"] p3 = elo_results["bootstrap_elo_rating"] p4 = elo_results["average_win_rate_bar"] + arena_df = elo_results["leaderboard_table_df"] + default_md = make_default_md(arena_df, elo_results) - md_1 = gr.Markdown(md, elem_id="leaderboard_markdown") - + md_1 = gr.Markdown(default_md, elem_id="leaderboard_markdown") if leaderboard_table_file: data = load_leaderboard_table_csv(leaderboard_table_file) - headers = [ - "Model", - "Arena Elo rating", - "MT-bench (score)", - "MMLU", - "License", - ] - values = [] - for item in data: - row = [] - for key in headers: - value = item[key] - row.append(value) - values.append(row) - values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) - - headers[1] = "⭐ " + headers[1] - headers[2] = "📈 " + headers[2] - - gr.Dataframe( - headers=headers, - datatype=["markdown", "number", "number", "number", "str"], - value=values, - elem_id="leaderboard_dataframe", - ) - gr.Markdown( - """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis! - If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). - """, - elem_id="leaderboard_markdown", - ) + model_table_df = pd.DataFrame(data) + + with gr.Tabs() as tabs: + # arena table + arena_table_vals = get_arena_table(arena_df, model_table_df) + with gr.Tab("Arena Elo", id=0): + md = make_arena_leaderboard_md(arena_df) + gr.Markdown(md, elem_id="leaderboard_markdown") + gr.Dataframe( + headers=[ + "Rank", + "🤖 Model", + "⭐ Arena Elo", + "📊 95% CI", + "🗳️ Votes", + "Organization", + "License", + ], + datatype=[ + "str", + "markdown", + "number", + "str", + "number", + "str", + "str", + ], + value=arena_table_vals, + elem_id="arena_leaderboard_dataframe", + height=700, + column_widths=[50, 200, 100, 100, 100, 150, 150], + wrap=True, + ) + with gr.Tab("Full Leaderboard", id=1): + md = make_full_leaderboard_md(elo_results) + gr.Markdown(md, elem_id="leaderboard_markdown") + full_table_vals = get_full_table(arena_df, model_table_df) + gr.Dataframe( + headers=[ + "🤖 Model", + "⭐ Arena Elo", + "📈 MT-bench", + "📚 MMLU", + "Organization", + "License", + ], + datatype=["markdown", "number", "number", "number", "str", "str"], + value=full_table_vals, + elem_id="full_leaderboard_dataframe", + column_widths=[200, 100, 100, 100, 150, 150], + height=700, + wrap=True, + ) + if not show_plot: + gr.Markdown( + """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis! + If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). + """, + elem_id="leaderboard_markdown", + ) else: pass - leader_component_values[:] = [md, p1, p2, p3, p4] + leader_component_values[:] = [default_md, p1, p2, p3, p4] - """ - with gr.Row(): - with gr.Column(): - gr.Markdown( - "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles" - ) - plot_1 = gr.Plot(p1, show_label=False) - with gr.Column(): - gr.Markdown( - "#### Figure 2: Battle Count for Each Combination of Models (without Ties)" - ) - plot_2 = gr.Plot(p2, show_label=False) - with gr.Row(): - with gr.Column(): - gr.Markdown( - "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)" - ) - plot_3 = gr.Plot(p3, show_label=False) - with gr.Column(): - gr.Markdown( - "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)" - ) - plot_4 = gr.Plot(p4, show_label=False) - """ + if show_plot: + gr.Markdown( + f"""## More Statistics for Chatbot Arena\n +Below are figures for more statistics. The code for generating them is also included in this [notebook]({notebook_url}). +You can find more discussions in this blog [post](https://lmsys.org/blog/2023-12-07-leaderboard/). + """, + elem_id="leaderboard_markdown", + ) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles" + ) + plot_1 = gr.Plot(p1, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Battle Count for Each Combination of Models (without Ties)" + ) + plot_2 = gr.Plot(p2, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)" + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)" + ) + plot_4 = gr.Plot(p4, show_label=False) from fastchat.serve.gradio_web_server import acknowledgment_md - gr.Markdown(acknowledgment_md) + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") - # return [md_1, plot_1, plot_2, plot_3, plot_4] + if show_plot: + return [md_1, plot_1, plot_2, plot_3, plot_4] return [md_1] @@ -266,7 +401,9 @@ def build_demo(elo_results_file, leaderboard_table_file): with gr.Tabs() as tabs: with gr.Tab("Leaderboard", id=0): leader_components = build_leaderboard_tab( - elo_results_file, leaderboard_table_file + elo_results_file, + leaderboard_table_file, + show_plot=True, ) with gr.Tab("Basic Stats", id=1): @@ -293,6 +430,8 @@ def build_demo(elo_results_file, leaderboard_table_file): parser.add_argument("--max-num-files", type=int) parser.add_argument("--elo-results-file", type=str) parser.add_argument("--leaderboard-table-file", type=str) + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--exclude-model-names", type=str, nargs="+") args = parser.parse_args() logger = build_logger("monitor", "monitor.log") @@ -301,13 +440,21 @@ def build_demo(elo_results_file, leaderboard_table_file): if args.elo_results_file is None: # Do live update update_thread = threading.Thread( target=update_worker, - args=(args.max_num_files, args.update_interval, args.elo_results_file), + args=( + args.max_num_files, + args.update_interval, + args.elo_results_file, + args.ban_ip_file, + args.exclude_model_names, + ), ) update_thread.start() demo = build_demo(args.elo_results_file, args.leaderboard_table_file) demo.queue( - concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, ).launch( server_name=args.host, server_port=args.port, share=args.share, max_threads=200 ) diff --git a/fastchat/serve/monitor/summarize_cluster.py b/fastchat/serve/monitor/summarize_cluster.py index 1d5fbcddc..b461a68b2 100644 --- a/fastchat/serve/monitor/summarize_cluster.py +++ b/fastchat/serve/monitor/summarize_cluster.py @@ -6,10 +6,12 @@ import argparse import pickle +import pandas as pd + from fastchat.llm_judge.common import ( - chat_compeletion_openai, - chat_compeletion_openai_azure, - chat_compeletion_anthropic, + chat_completion_openai, + chat_completion_openai_azure, + chat_completion_anthropic, ) from fastchat.conversation import get_conv_template @@ -52,13 +54,13 @@ def truncate_string(s, l): if "azure-" in model: template_name = "chatgpt" - completion_func = chat_compeletion_openai_azure + completion_func = chat_completion_openai_azure elif "gpt" in model: template_name = "chatgpt" - completion_func = chat_compeletion_openai + completion_func = chat_completion_openai elif "claude" in model: template_name = "claude" - completion_func = chat_compeletion_anthropic + completion_func = chat_completion_anthropic conv = get_conv_template(template_name) conv.set_system_message(instruct) @@ -74,3 +76,10 @@ def truncate_string(s, l): print() print(f"topics: {topics}") print(f"percentages: {percentages}") + + # save the informations + df = pd.DataFrame() + df["topic"] = topics + df["percentage"] = percentages + + df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records") diff --git a/fastchat/serve/monitor/topic_clustering.py b/fastchat/serve/monitor/topic_clustering.py index dd15c6edc..3d58e56bf 100644 --- a/fastchat/serve/monitor/topic_clustering.py +++ b/fastchat/serve/monitor/topic_clustering.py @@ -16,6 +16,7 @@ from sklearn.cluster import KMeans, AgglomerativeClustering import torch from tqdm import tqdm +from openai import OpenAI from fastchat.utils import detect_language @@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only): line_texts = [ x["content"] for x in l["conversation"] if x["role"] == "user" ] + elif "turns" in l: + line_texts = l["turns"] for text in line_texts: text = text.strip() @@ -77,14 +80,26 @@ def read_texts(input_file, min_length, max_length, english_only): def get_embeddings(texts, model_name, batch_size): - model = SentenceTransformer(model_name) - embeddings = model.encode( - texts, - batch_size=batch_size, - show_progress_bar=True, - device="cuda", - convert_to_tensor=True, - ) + if model_name == "text-embedding-ada-002": + client = OpenAI() + texts = texts.tolist() + + embeddings = [] + for i in tqdm(range(0, len(texts), batch_size)): + text = texts[i : i + batch_size] + responses = client.embeddings.create(input=text, model=model_name).data + embeddings.extend([data.embedding for data in responses]) + embeddings = torch.tensor(embeddings) + else: + model = SentenceTransformer(model_name) + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=True, + device="cuda", + convert_to_tensor=True, + ) + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) return embeddings.cpu() @@ -218,6 +233,8 @@ def get_cluster_info(texts, labels, topk_indices): ) parser.add_argument("--show-top-k", type=int, default=200) parser.add_argument("--show-cut-off", type=int, default=512) + parser.add_argument("--save-embeddings", action="store_true") + parser.add_argument("--embeddings-file", type=str, default=None) args = parser.parse_args() num_clusters = args.num_clusters @@ -229,7 +246,15 @@ def get_cluster_info(texts, labels, topk_indices): ) print(f"#text: {len(texts)}") - embeddings = get_embeddings(texts, args.model, args.batch_size) + if args.embeddings_file is None: + embeddings = get_embeddings(texts, args.model, args.batch_size) + if args.save_embeddings: + # allow saving embedding to save time and money + torch.save(embeddings, "embeddings.pt") + else: + embeddings = torch.load(args.embeddings_file) + print(f"embeddings shape: {embeddings.shape}") + if args.cluster_alg == "kmeans": centers, labels = run_k_means(embeddings, num_clusters) elif args.cluster_alg == "aggcls": @@ -249,7 +274,7 @@ def get_cluster_info(texts, labels, topk_indices): with open(filename_prefix + "_topk.txt", "w") as fout: fout.write(topk_str) - with open(filename_prefix + "_all.txt", "w") as fout: + with open(filename_prefix + "_all.jsonl", "w") as fout: for i in range(len(centers)): tmp_indices = labels == i tmp_embeddings = embeddings[tmp_indices] diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py index 65fcab977..58bfbba92 100644 --- a/fastchat/serve/openai_api_server.py +++ b/fastchat/serve/openai_api_server.py @@ -10,7 +10,6 @@ import asyncio import argparse import json -import logging import os from typing import Generator, Optional, Union, Dict, List, Any @@ -22,7 +21,11 @@ from fastapi.responses import StreamingResponse, JSONResponse from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer import httpx -from pydantic import BaseSettings + +try: + from pydantic.v1 import BaseSettings +except ImportError: + from pydantic import BaseSettings import shortuuid import tiktoken import uvicorn @@ -61,6 +64,7 @@ APITokenCheckResponse, APITokenCheckResponseItem, ) +from fastchat.utils import build_logger ###### Shale @@ -75,7 +79,7 @@ from fastapi.requests import Request -logger = logging.getLogger(__name__) +logger = build_logger("openai_api_server", "openai_api_server.log") conv_template_map = {} @@ -213,7 +217,12 @@ def check_requests(request) -> Optional[JSONResponse]: if request.top_p is not None and request.top_p > 1: return create_error_response( ErrorCode.PARAM_OUT_OF_RANGE, - f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + f"{request.top_p} is greater than the maximum of 1 - 'top_p'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.", ) if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): return create_error_response( @@ -236,10 +245,20 @@ def process_input(model_name, inp): inp = [inp] elif isinstance(inp, list): if isinstance(inp[0], int): - decoding = tiktoken.model.encoding_for_model(model_name) + try: + decoding = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) inp = [decoding.decode(inp)] elif isinstance(inp[0], list): - decoding = tiktoken.model.encoding_for_model(model_name) + try: + decoding = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) inp = [decoding.decode(text) for text in inp] return inp @@ -295,13 +314,29 @@ async def get_gen_params( prompt = messages elif isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], str): prompt = '. '.join(messages) + images = [] else: for message in messages: msg_role = message["role"] if msg_role == "system": conv.set_system_message(message["content"]) elif msg_role == "user": - conv.append_message(conv.roles[0], message["content"]) + if type(message["content"]) == list: + image_list = [ + item["image_url"]["url"] + for item in message["content"] + if item["type"] == "image_url" + ] + text_list = [ + item["text"] + for item in message["content"] + if item["type"] == "text" + ] + + text = "\n".join(text_list) + conv.append_message(conv.roles[0], (text, image_list)) + else: + conv.append_message(conv.roles[0], message["content"]) elif msg_role == "assistant": conv.append_message(conv.roles[1], message["content"]) else: @@ -310,6 +345,7 @@ async def get_gen_params( # Add a blank message for the assistant. conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() + images = conv.get_images() gen_params = { "model": model_name, @@ -325,6 +361,9 @@ async def get_gen_params( "stop_token_ids": conv.stop_token_ids, } + if len(images) > 0: + gen_params["images"] = images + if best_of is not None: gen_params.update({"best_of": best_of}) if use_beam_search is not None: @@ -455,6 +494,9 @@ async def create_chat_completion(request: ChatCompletionRequest): return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) usage = UsageInfo() for i, content in enumerate(all_tasks): + if isinstance(content, str): + content = json.loads(content) + if content["error_code"] != 0: return create_error_response(content["error_code"], content["text"]) choices.append( diff --git a/fastchat/serve/register_worker.py b/fastchat/serve/register_worker.py index 2c2c40295..aa57117b9 100644 --- a/fastchat/serve/register_worker.py +++ b/fastchat/serve/register_worker.py @@ -14,6 +14,7 @@ parser.add_argument("--controller-address", type=str) parser.add_argument("--worker-name", type=str) parser.add_argument("--check-heart-beat", action="store_true") + parser.add_argument("--multimodal", action="store_true") args = parser.parse_args() url = args.controller_address + "/register_worker" @@ -21,6 +22,7 @@ "worker_name": args.worker_name, "check_heart_beat": args.check_heart_beat, "worker_status": None, + "multimodal": args.multimodal, } r = requests.post(url, json=data) assert r.status_code == 200 diff --git a/fastchat/serve/sglang_worker.py b/fastchat/serve/sglang_worker.py new file mode 100644 index 000000000..b30668433 --- /dev/null +++ b/fastchat/serve/sglang_worker.py @@ -0,0 +1,313 @@ +""" +A model worker that executes the model based on SGLang. + +Usage: +python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 --worker-address http://localhost:30000 +""" + +import argparse +import asyncio +import json +import multiprocessing +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer, get_config +from sglang.srt.utils import load_image, is_multimodal_model + +from fastchat.conversation import IMAGE_PLACEHOLDER_STR +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() + + +@sgl.function +def pipeline(s, prompt, max_tokens): + for p in prompt: + if isinstance(p, str): + s += p + else: + s += sgl.image(p) + s += sgl.gen("response", max_tokens=max_tokens) + + +class SGLWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + tokenizer_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + runtime: sgl.Runtime, + trust_remote_code: bool, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + is_multimodal_model(model_path), + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: SGLang worker..." + ) + + self.tokenizer = get_tokenizer(tokenizer_path) + self.context_len = get_context_length( + get_config(model_path, trust_remote_code=trust_remote_code) + ) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + prompt = params.pop("prompt") + images = params.get("images", []) + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + presence_penalty = float(params.get("presence_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + echo = params.get("echo", True) + + # Handle stop_str + stop = [] + if isinstance(stop_str, str) and stop_str != "": + stop.append(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.extend(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.append(s) + + # make sampling params for sgl.gen + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + # split prompt by image token + split_prompt = prompt.split(IMAGE_PLACEHOLDER_STR) + if prompt.count(IMAGE_PLACEHOLDER_STR) != len(images): + raise ValueError( + "The number of images passed in does not match the number of tokens in the prompt!" + ) + prompt = [] + for i in range(len(split_prompt)): + prompt.append(split_prompt[i]) + if i < len(images): + prompt[-1] = prompt[-1].strip() + prompt.append(load_image(images[i])) + + state = pipeline.run( + prompt, + max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stream=True, + ) + + entire_output = prompt if echo else "" + async for out, meta_info in state.text_async_iter( + var_name="response", return_meta_data=True + ): + partial_stop = any(is_partial_stop(out, i) for i in stop) + + # prevent yielding partial stop sequence + if partial_stop: + continue + + entire_output += out + prompt_tokens = meta_info["prompt_tokens"] + completion_tokens = meta_info["completion_tokens"] + + ret = { + "text": entire_output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "error_code": 0, + } + yield ret + + async def generate_stream_gate(self, params): + try: + async for ret in self.generate_stream(params): + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + async def generate_gate(self, params): + async for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument("--tokenizer-path", type=str, default="") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust-remote-code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + parser.add_argument( + "--multimodal", + action="store_true", + required=False, + default=False, + help="Register this worker as serving a multimodal model.", + ) + + args = parser.parse_args() + + args.tp_size = args.num_gpus if args.num_gpus > 1 else 1 + args.tokenizer_path = ( + args.model_path if args.tokenizer_path == "" else args.tokenizer_path + ) + + multiprocessing.set_start_method("spawn", force=True) + runtime = sgl.Runtime( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + trust_remote_code=args.trust_remote_code, + mem_fraction_static=args.mem_fraction_static, + tp_size=args.tp_size, + log_level="info", + ) + sgl.set_default_backend(runtime) + + worker = SGLWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.tokenizer_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + args.conv_template, + runtime, + args.trust_remote_code, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 46e876b2f..e2be90a24 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -22,7 +22,7 @@ logger, worker_id, ) -from fastchat.utils import get_context_length +from fastchat.utils import get_context_length, is_partial_stop app = FastAPI() @@ -55,6 +55,10 @@ def __init__( f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." ) self.tokenizer = llm_engine.engine.tokenizer + # This is to support vllm >= 0.2.7 where TokenizerGroup was introduced + # and llm_engine.engine.tokenizer was no longer a raw tokenizer + if hasattr(self.tokenizer, "tokenizer"): + self.tokenizer = llm_engine.engine.tokenizer.tokenizer self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) if not no_register: @@ -79,6 +83,8 @@ async def generate_stream(self, params): use_beam_search = params.get("use_beam_search", False) best_of = params.get("best_of", None) + request = params.get("request", None) + # Handle stop_str stop = set() if isinstance(stop_str, str) and stop_str != "": @@ -88,7 +94,9 @@ async def generate_stream(self, params): for tid in stop_token_ids: if tid is not None: - stop.add(self.tokenizer.decode(tid)) + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) # make sampling params in vllm top_p = max(top_p, 1e-5) @@ -119,7 +127,20 @@ async def generate_stream(self, params): else: text_outputs = [output.text for output in request_output.outputs] text_outputs = " ".join(text_outputs) - # Note: usage is not supported yet + + partial_stop = any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + aborted = False + if request and await request.is_disconnected(): + await engine.abort(request_id) + request_output.finished = True + aborted = True + for output in request_output.outputs: + output.finish_reason = "abort" + prompt_tokens = len(request_output.prompt_token_ids) completion_tokens = sum( len(output.token_ids) for output in request_output.outputs @@ -139,8 +160,15 @@ async def generate_stream(self, params): if len(request_output.outputs) == 1 else [output.finish_reason for output in request_output.outputs], } + # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response. + # This aligns with the behavior of model_worker. + if request_output.finished: + yield (json.dumps({**ret, **{"finish_reason": None}}) + "\0").encode() yield (json.dumps(ret) + "\0").encode() + if aborted: + break + async def generate(self, params): async for x in self.generate_stream(params): pass @@ -173,6 +201,7 @@ async def api_generate_stream(request: Request): await acquire_worker_semaphore() request_id = random_uuid() params["request_id"] = request_id + params["request"] = request generator = worker.generate_stream(params) background_tasks = create_background_tasks(request_id) return StreamingResponse(generator, background=background_tasks) @@ -184,6 +213,7 @@ async def api_generate(request: Request): await acquire_worker_semaphore() request_id = random_uuid() params["request_id"] = request_id + params["request"] = request output = await worker.generate(params) release_worker_semaphore() await engine.abort(request_id) diff --git a/fastchat/train/train_baichuan.py b/fastchat/train/train_baichuan.py index 70c6488b5..b6b19b486 100644 --- a/fastchat/train/train_baichuan.py +++ b/fastchat/train/train_baichuan.py @@ -159,7 +159,7 @@ def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, **kwargs) - else: # If the data volume is large, use multithreading for processing with Pool() as p: conversations, conv = p.apply_async( - apply_prompt_template, (sources, tokenizer, systems) + apply_prompt_template, (sources, systems) ).get() input_ids, targets = p.apply_async( tokenize_conversations, (conversations, tokenizer) diff --git a/fastchat/train/train_with_template.py b/fastchat/train/train_with_template.py new file mode 100644 index 000000000..e5c5f353d --- /dev/null +++ b/fastchat/train/train_with_template.py @@ -0,0 +1,400 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import json +import math +import jsonlines +import pathlib +from multiprocessing import Pool +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def apply_prompt_template(sources, template_id, systems=None): + conv = get_conversation_template(template_id) + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + if systems and systems[i]: + conv.set_system_message(systems[i]) + prompt = conv.get_prompt() + conversations.append(prompt) + return conversations, conv + + +def tokenize_conversations(conversations, tokenizer): + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + return input_ids, targets + + +def get_prompt_separator(conv): + if conv.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + user_turn_separator = conv.sep2 + assistant_turn_separator = conv.roles[1] + ": " + + elif conv.sep_style == SeparatorStyle.ADD_COLON_TWO: + user_turn_separator = conv.sep2 + assistant_turn_separator = conv.roles[1] + ": " + + elif conv.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + if conv.sep2 is None: + user_turn_separator = conv.roles[0] + ": " + else: + user_turn_separator = conv.sep2 + + assistant_turn_separator = conv.roles[1] + ": " + + elif conv.sep_style == SeparatorStyle.LLAMA2: + user_turn_separator = conv.sep2 + assistant_turn_separator = conv.roles[1] + " " + + elif conv.sep_style == SeparatorStyle.CHATML: + if conv.sep2 is None: + user_turn_separator = conv.sep + "\n" + else: + user_turn_separator = conv.sep2 + "\n" + + assistant_turn_separator = conv.roles[1] + "\n" + + return user_turn_separator, assistant_turn_separator + + +def mask_targets(conversations, targets, tokenizer, conv): + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + if tokenizer.eos_token is None: + cur_len = 0 + elif tokenizer.eos_token is not None and target[0] != tokenizer.bos_token_id: + cur_len = 0 + elif tokenizer.eos_token is not None and target[0] == tokenizer.bos_token_id: + cur_len = 1 + + target[:cur_len] = IGNORE_TOKEN_ID + user_turn_separator, assistant_turn_separator = get_prompt_separator(conv) + turns = conversation.split(user_turn_separator) + for i, turn in enumerate(turns): + if ( + i < len(turns) - 1 and turn == "" + ): # Last turn is the user_turn_separator + break + + if i != 0: + turn = user_turn_separator + turn + + turn_len = len(tokenizer(turn, add_special_tokens=False).input_ids) + + if assistant_turn_separator in turn: + parts = turn.rsplit(assistant_turn_separator) + parts[0] += assistant_turn_separator + else: + parts = [turn] + + instruction_len = len( + tokenizer(parts[0], add_special_tokens=False).input_ids + ) + + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return targets + + +def preprocess( + sources, tokenizer: transformers.PreTrainedTokenizer, template_id, **kwargs +) -> Dict: + systems = None if not kwargs else kwargs.get("systems", None) + + # If the data volume is small, process it directly in the main thread + if len(sources) <= 1000: + conversations, conv = apply_prompt_template(sources, template_id, systems) + input_ids, targets = tokenize_conversations(conversations, tokenizer) + targets = mask_targets(conversations, targets, tokenizer, conv) + else: # If the data volume is large, use multithreading for processing + with Pool() as p: + conversations, conv = p.apply_async( + apply_prompt_template, (sources, template_id, systems) + ).get() + input_ids, targets = p.apply_async( + tokenize_conversations, (conversations, tokenizer) + ).get() + targets = p.apply_async( + mask_targets, (conversations, targets, tokenizer, conv) + ).get() + p.close() + p.join() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, tokenizer: transformers.PreTrainedTokenizer, template_id + ): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + systems = [example.get("system", "") for example in raw_data] + sources = [example["conversations"] for example in raw_data] + + data_dict = preprocess(sources, tokenizer, template_id, systems=systems) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, tokenizer: transformers.PreTrainedTokenizer, template_id + ): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + self.template_id = template_id + + rank0_print("Formatting inputs...Skip in lazy mode") + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess( + [self.raw_data[i]["conversations"]], + self.tokenizer, + self.template_id, + systems=[self.raw_data[i].get("system", "")], + ) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, + data_args, + template_id, + train_ratio=0.98, +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_ratio = min(train_ratio, 1.0) + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + data_path = data_args.data_path + if data_path.endswith(".json"): + raw_data = json.load(open(data_path, "r")) + elif data_path.endswith(".jsonl"): + with jsonlines.open(data_path, mode="r") as reader: + raw_data = [item for item in reader] + + # Split train/test + np.random.seed(0) + perm = np.random.permutation(len(raw_data)) + split = int(len(perm) * train_ratio) + train_indices = perm[:split] + if train_ratio < 1: + eval_indices = perm[split:] + else: + # if train_ratio==1, we use 5% of data as eval data, make sure trainer will not throw error when eval data is empty + eval_indices = perm[-int(len(perm) * 0.05) :] + train_raw_data = [raw_data[i] for i in train_indices] + eval_raw_data = [raw_data[i] for i in eval_indices] + rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") + + train_dataset = dataset_cls( + train_raw_data, tokenizer=tokenizer, template_id=template_id + ) + eval_dataset = dataset_cls( + eval_raw_data, tokenizer=tokenizer, template_id=template_id + ) + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + ) + # Set RoPE scaling factor + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and training_args.model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + ) + # Tie the weights + model.tie_weights() + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + config=config, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + # NOTE: if the token_id exceed the vocab_size will cause failing in training process! we need add special config and resize the embedding size! + tokenizer.pad_token = tokenizer.unk_token + tokenizer.pad_token_id = tokenizer.unk_token_id + print(f"tokens len: {len(tokenizer)}") + model.resize_token_embeddings(len(tokenizer)) + + template_id = model_args.model_name_or_path + data_module = make_supervised_data_module( + tokenizer=tokenizer, + template_id=template_id, + train_ratio=0.98, + data_args=data_args, + ) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_yuan2.py b/fastchat/train/train_yuan2.py new file mode 100644 index 000000000..6f3c09a14 --- /dev/null +++ b/fastchat/train/train_yuan2.py @@ -0,0 +1,482 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import json +import math +import pathlib +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" + }, + ) + padding_side: str = field( + default="right", metadata={"help": "The padding side in tokenizer"} + ) + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + lazy_preprocess: bool = False + last_response_loss: bool = False + split_example_loss: bool = False + efficient_loss: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def trainer_save_model_safe(trainer: transformers.Trainer): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType, FullStateDictConfig + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + trainer.model, StateDictType.FULL_STATE_DICT, save_policy + ): + trainer.save_model() + + +# add by wpf for yuan test +def right_replace(string, old, new, max=1): + return string[::-1].replace(old[::-1], new[::-1], max)[::-1] + + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, + data_args, +) -> Dict: + conv = get_conversation_template("yuan2") # wpf + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + if data_args.last_response_loss: + a = conversations[0].replace("", "") + a = right_replace(a, "", "") + # a=right_replace(a,"","\n",max=20) + conversations[0] = a + if data_args.split_example_loss: + a = conversations[0].replace("", "") + a = a.split("") + for i in range(int(len(a) / 2)): + if i == 0: + conversations[i] = "" + if i != 0: + conversations.append("") + for j in range(i * 2): + conversations[i] = conversations[i] + a[j] + "" + conversations[i] = ( + conversations[i] + a[i * 2] + "" + a[i * 2 + 1] + "" + ) + + if data_args.efficient_loss: + a = conversations[0].replace("", "") + conversations[0] = a + + print(conversations) + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO #wpf + # Mask targets. Only compute loss on the assistant outputs. + # sep = conv.sep + conv.roles[1] + ": " #wpf + + if data_args.split_example_loss: + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + turns = conversation.split("") + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + + for i, turn in enumerate(turns): + if turn == "": + break + if i == 0 or i == len(turns) - 1: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) + 1 + # parts = turn.split(sep) + # if len(parts) != 2: + # break + # parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = 0 + if i == len(turns) - 1: + instruction_len = turn_len + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + # print("cur_len: ", cur_len) + # print("total_len: ", total_len) + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + if data_args.efficient_loss: + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split("") + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + + for i, turn in enumerate(turns): + if turn == "": + break + if i == 0 or i == len(turns) - 1: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) + 1 + # parts = turn.split(sep) + # if len(parts) != 2: + # break + # parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = 0 + if i % 2 == 0: + instruction_len = turn_len + + # if i != 0 and not tokenizer.legacy: + # # The legacy and non-legacy modes handle special tokens differently + # instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + target[cur_len:] = IGNORE_TOKEN_ID + # print("cur_len: ", cur_len) + # print("total_len: ", total_len) + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + if data_args.last_response_loss: + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split("") + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + + for i, turn in enumerate(turns): + if turn == "": + break + if i == 0 or i == len(turns) - 1: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) + 1 + # parts = turn.split(sep) + # if len(parts) != 2: + # break + # parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = 0 + if i == len(turns) - 1: + instruction_len = turn_len + + # if i != 0 and not tokenizer.legacy: + # # The legacy and non-legacy modes handle special tokens differently + # instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + # if i != 0 and not tokenizer.legacy: + # # The legacy and non-legacy modes handle special tokens differently + # cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + # print("cur_len: ", cur_len) + # print("total_len: ", total_len) + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, data_args, tokenizer: transformers.PreTrainedTokenizer + ): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + sources = [example["conversations"] for example in raw_data] + data_dict = preprocess(sources, tokenizer, data_args) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, data_args, tokenizer: transformers.PreTrainedTokenizer + ): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.data_args = data_args + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess( + [self.raw_data[i]["conversations"]], self.tokenizer, self.data_args + ) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = dataset_cls(train_json, data_args, tokenizer=tokenizer) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = dataset_cls(eval_json, data_args, tokenizer=tokenizer) + else: + eval_dataset = None + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + + # Set RoPE scaling factor + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and training_args.model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + + # Load model and tokenizer + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=training_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side=model_args.padding_side, + use_fast=False, + trust_remote_code=model_args.trust_remote_code, + ) + + if tokenizer.pad_token != tokenizer.unk_token: + tokenizer.pad_token = tokenizer.unk_token + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + + # Load data + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + # Start trainner + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + + # Save model + model.config.use_cache = True + trainer.save_state() + if trainer.is_deepspeed_enabled: + trainer.save_model() + else: + trainer_save_model_safe(trainer) + + +if __name__ == "__main__": + train() diff --git a/fastchat/utils.py b/fastchat/utils.py index b5e3ba543..70f61202f 100644 --- a/fastchat/utils.py +++ b/fastchat/utils.py @@ -2,6 +2,8 @@ Common utilities. """ from asyncio import AbstractEventLoop +from io import BytesIO +import base64 import json import logging import logging.handlers @@ -57,6 +59,9 @@ def build_logger(logger_name, logger_filename): logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) + # Avoid httpx flooding POST logs + logging.getLogger("httpx").setLevel(logging.WARNING) + # if LOGDIR is empty, then don't try output log to local file if LOGDIR != "": os.makedirs(LOGDIR, exist_ok=True) @@ -149,16 +154,21 @@ def oai_moderation(text): """ import openai - openai.api_base = "https://api.openai.com/v1" - openai.api_key = os.environ["OPENAI_API_KEY"] + client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + threshold_dict = { + "sexual": 0.2, + } MAX_RETRY = 3 - for i in range(MAX_RETRY): + for _ in range(MAX_RETRY): try: - res = openai.Moderation.create(input=text) - flagged = res["results"][0]["flagged"] + res = client.moderations.create(input=text) + flagged = res.results[0].flagged + for category, threshold in threshold_dict.items(): + if getattr(res.results[0].category_scores, category) > threshold: + flagged = True break - except (openai.error.OpenAIError, KeyError, IndexError) as e: + except (openai.OpenAIError, KeyError, IndexError) as e: # flag true to be conservative flagged = True print(f"MODERATION ERROR: {e}\nInput: {text}") @@ -166,7 +176,7 @@ def oai_moderation(text): def moderation_filter(text, model_list): - MODEL_KEYWORDS = ["claude"] + MODEL_KEYWORDS = ["claude", "gpt-4", "bard"] for keyword in MODEL_KEYWORDS: for model in model_list: @@ -223,7 +233,7 @@ def pretty_print_semaphore(semaphore): url_params = Object.fromEntries(params); console.log("url_params", url_params); - msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nThe service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license." + msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license." alert(msg); return url_params; @@ -311,9 +321,9 @@ def is_sentence_complete(output: str): # NOTE: The ordering here is important. Some models have two of these and we # have a preference for which value gets used. SEQUENCE_LENGTH_KEYS = [ + "max_position_embeddings", "max_sequence_length", "seq_length", - "max_position_embeddings", "max_seq_len", "model_max_length", ] @@ -347,3 +357,24 @@ def str_to_torch_dtype(dtype: str): return torch.bfloat16 else: raise ValueError(f"Unrecognized dtype: {dtype}") + + +def load_image(image_file): + from PIL import Image + import requests + + image = None + + if image_file.startswith("http://") or image_file.startswith("https://"): + timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) + response = requests.get(image_file, timeout=timeout) + image = Image.open(BytesIO(response.content)) + elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): + image = Image.open(image_file) + elif image_file.startswith("data:"): + image_file = image_file.split(",")[1] + image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + image = Image.open(BytesIO(base64.b64decode(image_file))) + + return image diff --git a/multigpu_inference.sh b/multigpu_inference.sh index eef154da3..de4408ed6 100644 --- a/multigpu_inference.sh +++ b/multigpu_inference.sh @@ -1 +1 @@ -python3 -m fastchat.serve.cli --model-path /data/ml/llm/vicuna-13b-v1.1 --num-gpus 2 \ No newline at end of file +python3 -m fastchat.serve.cli --model-path /data/ml/llm/OpenHermes-2.5-Mistral-7B --num-gpus 2 --max-gpu-memory 8GiB \ No newline at end of file diff --git a/playground/FastChat_API_GoogleColab.ipynb b/playground/FastChat_API_GoogleColab.ipynb new file mode 100644 index 000000000..9fcdf8358 --- /dev/null +++ b/playground/FastChat_API_GoogleColab.ipynb @@ -0,0 +1,347 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# FastChat API using Google Colab\n", + "\n", + "[ggcr](https://github.com/ggcr)" + ], + "metadata": { + "id": "1UDur96B5C7T" + } + }, + { + "cell_type": "code", + "source": [ + "%cd /content/\n", + "\n", + "# clone FastChat\n", + "!git clone https://github.com/lm-sys/FastChat.git\n", + "\n", + "# install dependencies\n", + "%cd FastChat\n", + "!python3 -m pip install -e \".[model_worker,webui]\" --quiet" + ], + "metadata": { + "id": "NQWpzwse8PrC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "See [openai_api.md](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) from FastChat docs.\n", + "\n", + "Because in Google Colab we are limited in resources and executing things in the background is not stable, we will run each API process in a thread and communicate them via explicit addresses:" + ], + "metadata": { + "id": "97181RzwSjha" + } + }, + { + "cell_type": "code", + "source": [ + "import subprocess\n", + "import threading\n", + "\n", + "%cd /content/\n", + "\n", + "# Using 127.0.0.1 because localhost does not work properly in Colab\n", + "\n", + "def run_controller():\n", + " subprocess.run([\"python3\", \"-m\", \"fastchat.serve.controller\", \"--host\", \"127.0.0.1\"])\n", + "\n", + "def run_model_worker():\n", + " subprocess.run([\"python3\", \"-m\", \"fastchat.serve.model_worker\", \"--host\", \"127.0.0.1\", \"--controller-address\", \"http://127.0.0.1:21001\", \"--model-path\", \"lmsys/vicuna-7b-v1.5\", \"--load-8bit\"])\n", + "\n", + "def run_api_server():\n", + " subprocess.run([\"python3\", \"-m\", \"fastchat.serve.openai_api_server\", \"--host\", \"127.0.0.1\", \"--controller-address\", \"http://127.0.0.1:21001\", \"--port\", \"8000\"])\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BrhPP9ZggVL0", + "outputId": "be510360-21ba-4f6f-d6b6-24c710bdff4d" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Start controller thread\n", + "# see `controller.log` on the local storage provided by Colab\n", + "controller_thread = threading.Thread(target=run_controller)\n", + "controller_thread.start()" + ], + "metadata": { + "id": "3S8vDHy3gWUv" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Start model worker thread\n", + "\n", + "# see `controller.log` on the local storage provided by Colab\n", + "# important to wait until the checkpoint shards are fully downloaded\n", + "model_worker_thread = threading.Thread(target=run_model_worker)\n", + "model_worker_thread.start()\n" + ], + "metadata": { + "id": "UAU097ymgbNf" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Start API server thread\n", + "api_server_thread = threading.Thread(target=run_api_server)\n", + "api_server_thread.start()" + ], + "metadata": { + "id": "bTqHMMr1gcQJ" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We now have the API running at http://127.0.0.1:8000/v1/ locally from Google Colab.\n", + "\n", + "We can run the examples from FastChat with curl." + ], + "metadata": { + "id": "iBdjt9I6fuSn" + } + }, + { + "cell_type": "markdown", + "source": [ + "Try chat completion with" + ], + "metadata": { + "id": "KtaxADXqhazs" + } + }, + { + "cell_type": "code", + "source": [ + "!curl http://127.0.0.1:8000/v1/chat/completions \\\n", + " -H \"Content-Type: application/json\" \\\n", + " -d '{ \\\n", + " \"model\": \"vicuna-7b-v1.5\", \\\n", + " \"messages\": [{\"role\": \"user\", \"content\": \"Hello, can you tell me a joke for me?\"}], \\\n", + " \"temperature\": 0.5 \\\n", + " }'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MZGd4y2SfBTT", + "outputId": "066835bb-f7f0-4e16-f54a-2f74b0e2f9d9" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\"id\":\"chatcmpl-3RViU5mrsEBNu8oSxexAEb\",\"object\":\"chat.completion\",\"created\":1705781842,\"model\":\"vicuna-7b-v1.5\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Sure thing! Here's one for you:\\n\\nWhy did the tomato turn red?\\n\\nBecause it saw the salad dressing!\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":50,\"total_tokens\":82,\"completion_tokens\":32}}" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Try embeddings with" + ], + "metadata": { + "id": "umgVIilThc6a" + } + }, + { + "cell_type": "code", + "source": [ + "!curl http://127.0.0.1:8000/v1/embeddings \\\n", + " -H \"Content-Type: application/json\" \\\n", + " -d '{ \\\n", + " \"model\": \"vicuna-7b-v1.5\", \\\n", + " \"input\": \"Hello, can you tell me a joke for me?\" \\\n", + " }'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VraqDkMahfAQ", + "outputId": "18710c2c-1994-4f36-eff1-6aff5a2a83a4" + }, + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"embedding\":[0.0229715034365654,-0.020740192383527756,0.01663232035934925,0.013713006861507893,-0.01602417416870594,-0.006382038351148367,0.011642662808299065,-0.021167458966374397,0.004879815969616175,-0.005442662630230188,0.0034834356047213078,-0.010336925275623798,-0.009551243856549263,0.0005828586872667074,-0.0089940270408988,-0.0018360239919275045,-0.021827373653650284,0.007349758874624968,-0.0011765437666326761,-0.01432803925126791,0.012239773757755756,-0.018455859273672104,0.016475312411785126,-0.006144467741250992,-0.013893244788050652,-0.00961716752499342,0.00827623251825571,0.0013034207513555884,0.006355977617204189,0.007773293182253838,0.0029199880082160234,-0.014487813226878643,-0.01615595631301403,0.007242684718221426,-0.004686516709625721,-0.0034376305993646383,-0.0046915397979319096,0.0007899928605183959,-0.003679676679894328,-0.022176748141646385,-0.005467468872666359,-0.02236158587038517,0.02086811512708664,0.0029669292271137238,-0.0168694406747818,0.025603512302041054,0.009139388799667358,0.02165624313056469,-0.004472456872463226,0.0006205983809195459,0.0011453271145001054,0.014379195868968964,0.01994524523615837,-0.017613859847187996,0.005462903995066881,0.005702079739421606,-0.021057194098830223,-0.021468186751008034,-0.004666909575462341,-0.007595115341246128,-0.009129735641181469,-0.0161031112074852,0.009293882176280022,0.00953285675495863,-0.0013638428645208478,0.0007091081934049726,0.0018222536891698837,0.020376019179821014,0.01186810340732336,-0.013734177686274052,-0.004418510012328625,-0.006746952421963215,-0.0006970430840738118,-0.006644704379141331,-0.04453064501285553,0.003871878841891885,-0.01059865765273571,-0.024984514340758324,0.011757172644138336,-0.016218630596995354,-0.009141125716269016,-0.004623874556273222,-0.009732221253216267,-0.009169373661279678,-0.006947007961571217,-0.005838882178068161,-0.0068959807977080345,-0.000743469747249037,0.008742589503526688,-0.008120769634842873,-0.018119709566235542,-0.004530956968665123,-0.003916825633496046,0.02495340257883072,0.010598400607705116,0.010666633024811745,0.00679260678589344,-0.009019959717988968,-0.004487940575927496,-0.0026543298736214638,0.00286748050712049,0.012851846404373646,0.0012102456530556083,0.014895712956786156,-0.01030716486275196,0.01633496955037117,0.015731101855635643,-0.009079995565116405,0.016830960288643837,0.00940327625721693,-0.0014347939286381006,0.0207867082208395,0.06265891343355179,0.002649270463734865,-0.007526970934122801,0.004714089445769787,0.006397288292646408,-0.0029612022917717695,-0.0015034123789519072,-0.006392269395291805,-0.012309122830629349,0.0040127672255039215,0.001810954650864005,-0.016414696350693703,-0.019156336784362793,0.0003308420709799975,0.007823580875992775,0.0020239183213561773,-0.0024881847202777863,-0.008919963613152504,-0.01775810308754444,-0.012687149457633495,0.0022407048381865025,-0.009261680766940117,0.006048525683581829,0.00518012186512351,0.0029072873294353485,-7.72168641560711e-06,0.012007351964712143,-0.0004918070626445115,0.0013227892341092229,0.006292788311839104,-0.010167273692786694,-0.009050589054822922,0.008057740516960621,0.006250383332371712,0.014853340573608875,0.02723078615963459,-0.02242557890713215,0.04399850592017174,0.00313431303948164,-0.022166002541780472,0.010024639777839184,0.003234871895983815,0.0030383227858692408,0.012888548895716667,0.01507903728634119,0.00479199830442667,-0.0024831658229231834,0.008515636436641216,0.0005489314789883792,0.004214818123728037,0.006590660661458969,-0.012804229743778706,0.011747709475457668,0.002035082783550024,0.0143223125487566,0.0134012121707201,-0.0008568498305976391,0.0025005715433508158,-0.012422841973602772,0.014866000972688198,0.020239505916833878,-0.0034607010893523693,-0.026886560022830963,-0.0023535056971013546,-0.0037942437920719385,0.013139543123543262,0.004902820568531752,0.008357052691280842,-0.011724174953997135,0.005840683821588755,0.009768190793693066,0.00013014259457122535,0.016845345497131348,-0.006546108052134514,-0.00838533416390419,-0.01408461295068264,-0.0022769987117499113,0.010644538328051567,0.002947496483102441,0.02589692734181881,0.012639564462006092,0.004540625493973494,-0.0176566019654274,-0.010239857248961926,0.01839127205312252,0.0031600680667907,0.011127336882054806,0.0036535318940877914,0.015353705734014511,-0.026527339592576027,-0.008746611885726452,0.01886408030986786,0.00887488853186369,-0.0001859961193986237,0.001222877879627049,0.0065072583965957165,-0.009838716126978397,0.008630175143480301,-0.00633110711351037,0.02635054476559162,-0.005968477576971054,-0.013434287160634995,0.01017901673913002,-0.003692896803840995,-0.005410553887486458,-0.006332104559987783,-0.017778540030121803,-0.017085647210478783,-0.005269246641546488,-0.013628004118800163,-0.0005570553475990891,0.010984581895172596,0.000956009142100811,0.009669160470366478,-0.0019082700600847602,-0.05074448138475418,-0.03876679390668869,0.0011635294649749994,-0.012585809454321861,0.008794615045189857,0.00023998660617507994,-0.00455761281773448,-0.0020947649609297514,0.017387693747878075,0.004844747018069029,0.008267332799732685,0.00747610442340374,0.02141532674431801,-0.02262278087437153,-0.014600872062146664,-0.021727152168750763,0.008812149986624718,0.009474638849496841,0.03191479295492172,-0.019652077928185463,0.01944698765873909,0.017112286761403084,0.015296016819775105,0.014461753889918327,-0.019157931208610535,0.009540014900267124,0.004215397406369448,-0.008012793958187103,0.013523118570446968,-0.009407458826899529,-0.029304828494787216,0.012041181325912476,0.015149015933275223,0.0031983305234462023,-0.0003109185490757227,0.03257888928055763,0.007614033296704292,-0.005175750236958265,-0.002383652376011014,0.006435382179915905,0.006068408954888582,-0.007524268701672554,0.02373131737112999,0.004817254841327667,0.005436067469418049,-0.0059105646796524525,-0.005925316829234362,-6.454042886616662e-05,-0.008412199094891548,-0.00655836658552289,-0.0010680218692868948,-0.004262322559952736,0.0015925978077575564,0.00412611523643136,-0.011034490540623665,0.009839101694524288,0.00415002042427659,-0.007727092131972313,-0.010377302765846252,0.0007711391081102192,-0.009322070516645908,0.0035655524116009474,-0.026301125064492226,-0.006197007372975349,0.0006739745149388909,-0.00818476639688015,-0.02090131863951683,-0.002644758205860853,0.006994722411036491,-0.0016304099699482322,0.01705804094672203,-0.016460495069622993,0.017486274242401123,0.013187418691813946,0.0033816162031143904,0.017844069749116898,-0.017695210874080658,-0.011941025033593178,0.009029353968799114,0.0033719318453222513,-0.009064359590411186,0.012252643704414368,0.0011845449917018414,0.003185839159414172,0.003374891821295023,-0.007335654925554991,0.0029391313437372446,0.000280876352917403,0.0048222895711660385,-0.0003767217858694494,-0.045474909245967865,0.004725527483969927,0.0075803473591804504,0.005909985862672329,0.002949362387880683,-0.0036183823831379414,0.0026071954052895308,-0.005563989747315645,-0.012707033194601536,-0.004933884367346764,-0.016659578308463097,-0.0081319659948349,0.012579865753650665,-0.022291865199804306,-0.018159057945013046,-0.0069056968204677105,-0.00018650286074262112,-0.006835494190454483,0.0006484286277554929,0.005561383906751871,0.0062789213843643665,0.029090696945786476,0.002546998206526041,0.009344656951725483,-0.0038842656649649143,-0.012519339099526405,-0.0025535617023706436,-0.003679415676742792,-0.0033875037916004658,0.003728062380105257,-0.014787501655519009,0.0023771373089402914,0.005443841218948364,-0.00957341119647026,-0.015306569635868073,0.0046866778284311295,-0.016635537147521973,-0.01424899697303772,0.001698320615105331,-0.004534294828772545,0.0066452836617827415,0.010703673586249352,0.004293128848075867,-0.009486992843449116,-0.0031507215462625027,0.01611129753291607,-0.015744132921099663,-0.014641146175563335,0.0026989546604454517,0.01565713621675968,-0.005524931009858847,0.006648661568760872,0.0040243822149932384,-0.00559786893427372,-0.014391486532986164,0.026553215458989143,-0.009266120381653309,0.020683180540800095,0.00994131714105606,0.0026739235036075115,0.0038542025722563267,-0.012158502824604511,-0.010751161724328995,-0.00017412402667105198,-0.017064156010746956,-0.010691382922232151,0.00937278475612402,-0.014700417406857014,-0.005352479871362448,0.012342552654445171,0.009191831573843956,-0.011637836694717407,-0.012737436220049858,0.01105053722858429,0.020749129354953766,0.07297933101654053,0.027850160375237465,-0.005428216885775328,-0.019425511360168457,0.0016134463949128985,-0.007674881722778082,0.004896160680800676,-0.006309020332992077,0.0028925116639584303,-0.016418879851698875,-0.012568380683660507,-0.0076565672643482685,-0.002051394898444414,0.011267355643212795,0.01101701334118843,0.02482358179986477,0.011389358900487423,-0.01589033007621765,0.0005615596892312169,-0.027247965335845947,-0.008588980883359909,0.005675439722836018,0.008922569453716278,-0.003106530988588929,0.00925450585782528,-0.00030810333555564284,-0.002115500858053565,-0.007074093911796808,-0.005927231162786484,-0.017885340377688408,-0.016033342108130455,-0.0049004401080310345,0.006337509956210852,0.01978384517133236,0.001572070992551744,-0.0143946073949337,-0.008655560202896595,-0.0011587677290663123,-2.521412170608528e-05,-0.01082194410264492,0.010964666493237019,-0.011412781663239002,0.008038532920181751,0.006299568805843592,-0.008974144235253334,0.006545931100845337,0.0006125871441327035,0.00486041558906436,0.0042688059620559216,0.0018871801439672709,-0.006763682700693607,0.013578971847891808,-0.0020262349862605333,-0.0024552710819989443,-0.01506423857063055,0.0054992204532027245,0.011333892121911049,-0.007717472035437822,-0.005762179847806692,0.0007979075890034437,0.007761630229651928,-0.00952511839568615,-0.010288495570421219,0.014522014185786247,-0.005318223498761654,0.009297103621065617,0.0038411528803408146,0.012293890118598938,0.004698003176599741,-0.007060967851430178,-0.004558722488582134,-0.003963573835790157,0.016085496172308922,0.015816137194633484,0.0027972774114459753,-0.017336538061499596,0.014937583357095718,0.013450084254145622,0.06357342004776001,-0.009506811387836933,0.007877970114350319,0.007048371247947216,0.011457744054496288,0.023370005190372467,0.014203527010977268,-0.004614254459738731,-0.008159955963492393,0.0030794248450547457,-0.0010602197144180536,0.0006093939300626516,-0.010418003425002098,-0.007668149657547474,0.015968769788742065,-0.0015574641292914748,-0.018846578896045685,-0.003667157609015703,0.0019307112088426948,-0.001895931432954967,-0.010295855812728405,0.00023113582574296743,0.007993489503860474,0.0022910244297236204,0.00033837073715403676,-0.005313453264534473,0.0010675875237211585,-0.01549510844051838,0.007410695310682058,0.009646059945225716,-0.012997191399335861,0.010529725812375546,-0.019208982586860657,-0.010029473342001438,-0.013124711811542511,0.029043130576610565,-0.00493550905957818,0.008303387090563774,0.0067044831812381744,0.005133184138685465,-0.008268092758953571,0.0027517518028616905,-0.013479426503181458,-0.01547516044229269,-0.020013773813843727,-0.006451855413615704,0.008133156225085258,-0.006830539554357529,-0.007085484452545643,0.010033013299107552,0.002104497514665127,0.0005678657325915992,0.006996427197009325,-0.00785919837653637,-0.029753299430012703,0.03372034803032875,-0.008247010409832,0.008989491499960423,0.017457574605941772,-0.0059603373520076275,-0.003432418452575803,-0.014526166021823883,0.01819109544157982,-0.007616993971168995,-0.008361894637346268,0.008198246359825134,0.004229682497680187,-0.02080651931464672,0.009076694026589394,-0.006605580914765596,0.0037523536011576653,-0.010452975519001484,-0.012760377489030361,-0.017025675624608994,-0.007993683218955994,0.013692287728190422,0.017206765711307526,0.006106856279075146,0.011746293865144253,-0.009011680260300636,-0.007511272560805082,0.006244495511054993,0.009395747445523739,0.006921007763594389,0.00926200207322836,0.03370635211467743,0.0026780739426612854,0.012087206356227398,0.0012626887764781713,-0.014491417445242405,-0.007984738796949387,-0.02033303491771221,-0.008010058663785458,-0.0027411666233092546,-0.006356299389153719,0.014341359958052635,0.00924749206751585,0.008061794564127922,-0.014423820190131664,-0.0027274927124381065,-0.009464149363338947,0.0032869288697838783,0.028920968994498253,-0.007417581044137478,-0.012927711941301823,-0.006823397241532803,0.0021555088460445404,-0.008643687702715397,-0.0023652170784771442,-0.0060961428098380566,-0.017238536849617958,-0.007533663418143988,0.0022437411826103926,-0.0029654495883733034,0.007918364368379116,-0.004272923804819584,0.022094689309597015,-0.01293826475739479,-0.03929437696933746,-0.05735565349459648,-0.013071688823401928,0.0007404614589177072,-0.000720368989277631,-0.006117763463407755,-0.011282929219305515,0.010729444213211536,-0.014913954772055149,0.00311655318364501,0.006948134861886501,-0.00748022273182869,-0.02309916727244854,-0.0178229883313179,-0.0072602517902851105,0.007839913479983807,0.012868576683104038,0.002075975527986884,0.0007498079212382436,0.005815781187266111,-0.011992518790066242,0.010061261244118214,0.004755143541842699,-0.0014543153811246157,0.014422083273530006,-0.0023919050581753254,0.009424189105629921,-0.01841503195464611,0.008597759529948235,0.023288220167160034,-0.009507520124316216,0.015740947797894478,-0.0004225693119224161,0.02476677857339382,-0.011370633728802204,0.011675688438117504,0.020527847111225128,-0.0073279449716210365,-0.013483609072864056,-0.019474929198622704,-0.004016772843897343,-0.012925073504447937,-0.00565439835190773,0.0104595385491848,-0.007314899004995823,0.010194428265094757,0.0022050561383366585,0.011519340798258781,-0.0059105330146849155,-0.0007297637057490647,-0.016200484707951546,0.015271657146513462,-0.016203250735998154,0.034517351537942886,0.0006107089575380087,-0.013269267976284027,0.01328535471111536,-0.02016814425587654,-0.007773164194077253,-0.007333156652748585,-0.01815428026020527,0.006929537747055292,-0.0034732790663838387,-0.004669690039008856,0.0016878641908988357,-0.03094855323433876,0.0019403311889618635,-0.005923015996813774,-0.0040122526697814465,0.009299001656472683,-0.006708343978971243,0.01585310511291027,0.0010694535449147224,0.0006908577051945031,-0.0015497022541239858,-0.014749257825314999,0.013069666922092438,-0.0003381777205504477,-0.0186776015907526,-0.00869465060532093,-0.005246113985776901,0.004712183494120836,-0.0033125269692391157,0.005922533571720123,0.005009307526051998,-0.002772809471935034,0.0018297180067747831,-0.007289668545126915,-0.025313491001725197,-0.010890730656683445,-0.013207301497459412,-0.015217771753668785,-0.0064299451187253,0.0012019408168271184,0.013148745521903038,-0.022279510274529457,0.008878774009644985,-0.007133841048926115,-0.0007347667124122381,0.007130189798772335,0.0017936835065484047,0.012268022634088993,0.007812416646629572,0.009994118474423885,-0.01274168398231268,-0.000458410766441375,-0.006630516145378351,0.0004267197218723595,0.013977475464344025,-0.003951766062527895,-0.0167144313454628,-0.012754247523844242,0.012914633378386497,0.010781855322420597,0.002908888040110469,-0.007131427992135286,0.017916306853294373,-0.005879903212189674,-0.002502115909010172,-0.0016746085602790117,-0.024386180564761162,-0.008716223761439323,0.003937223460525274,0.004685036838054657,-0.005052074324339628,-0.004745748359709978,-0.004316418897360563,-0.009056701324880123,-0.011055074632167816,0.0087593924254179,-0.016003968194127083,-0.001959120621904731,0.014024545438587666,-0.005205253139138222,-0.0034684527199715376,-0.00704217841848731,0.004913646727800369,0.01903299242258072,-0.007594246882945299,-0.0001278904383070767,-0.00024535658303648233,0.01912636123597622,0.02121288888156414,0.01097018364816904,-0.005211591720581055,-0.004693691153079271,0.0002123745362041518,0.01864037662744522,0.004567956551909447,-0.006998493801802397,0.002807476557791233,-0.0272210780531168,0.008950882591307163,-0.007628897670656443,0.017757385969161987,0.011070613749325275,-0.02169198729097843,0.005343310534954071,0.0013322805752977729,-0.004593148827552795,0.009079867042601109,0.011012605391442776,0.00658367108553648,-0.004797258879989386,-0.006833371240645647,-0.0069283475168049335,-0.009916930459439754,-0.006784595549106598,-0.03476946800947189,0.020896492525935173,0.008564138785004616,-0.0012716330820694566,-0.013008822686970234,-0.000613439769949764,0.0047750589437782764,-0.012346075847744942,0.006973704788833857,-0.013979197479784489,-0.006083691958338022,0.005035505164414644,0.011145804077386856,0.013424682430922985,-0.00019037174934055656,-0.008892635814845562,-0.01950671710073948,-0.010386078618466854,0.01175111997872591,-0.014368708245456219,0.00041413979488424957,-0.014867283403873444,0.0020979661494493484,-0.002743129152804613,0.004309915471822023,-0.012755325064063072,0.013642803765833378,0.008863402530550957,0.0013711462961509824,-0.019572222605347633,0.0036479418631643057,0.1259939968585968,0.01384377758949995,0.015267448499798775,0.014036224223673344,0.0038570465985685587,-0.005283885635435581,0.010237026028335094,-0.011374881491065025,-0.011878897435963154,-0.008971023373305798,-0.009165126830339432,-0.0010226268786936998,-0.007337307557463646,-0.010756309144198895,-0.014150279574096203,0.002133630681782961,-0.015334047377109528,0.00481215538457036,-0.013047880493104458,-0.014511879533529282,-0.0030851999763399363,-0.007749861106276512,-0.006487664300948381,0.013752967119216919,-0.012187069281935692,0.0007167012081481516,-0.0016341822920367122,-0.004467220976948738,0.0042928713373839855,0.022611349821090698,0.0005482397391460836,-0.017850179225206375,-0.014368931762874126,-0.02213916927576065,0.009322037920355797,-0.008927379734814167,0.0012655361788347363,0.003878731979057193,-0.011226431466639042,0.014120969921350479,-0.013007482513785362,-0.027299636974930763,-0.02149207703769207,0.0018350587924942374,0.0014142269501462579,-0.000801382411736995,0.010266175493597984,0.006652788259088993,0.0005369306891225278,-0.006750426720827818,0.0077108764089643955,0.008079683408141136,-0.0018402388086542487,-0.016589464619755745,-0.009489567019045353,-0.006460928358137608,-0.008930034004151821,0.005508729722350836,-0.021854624152183533,0.0021647908724844456,-4.1697108827065676e-05,0.0023772178683429956,-0.015694361180067062,-0.0025681040715426207,0.02343827858567238,-0.007234286982566118,0.011763988994061947,0.006332748103886843,0.01569999009370804,0.0011719107860699296,-0.0026809938717633486,-0.019673682749271393,0.010832150466740131,0.0020819918718189,0.0021434274967759848,0.014149283058941364,-0.018654564395546913,-0.005904508288949728,0.024274280294775963,0.0020302003249526024,0.009324193932116032,-0.0019528145203366876,0.010275795124471188,-0.007945165038108826,0.02523057907819748,-0.0015196279855445027,-0.0033202609047293663,-0.00838176254183054,0.009073046036064625,0.004423896782100201,0.0025238976813852787,0.0009007186163216829,0.012340654619038105,0.013026034459471703,0.0006704675615765154,-0.011622972786426544,0.0025514704175293446,0.0018054911633953452,-0.00021421245764940977,0.0015564989298582077,0.0002535287057980895,-0.007833908312022686,-0.002614386146888137,0.010472987778484821,0.008430087007582188,-0.010319744236767292,-0.007126948330551386,-0.0032228068448603153,-0.005715849809348583,-0.007379905320703983,0.0007485531968995929,-0.020927315577864647,0.0019611797761172056,0.0038484123069792986,-0.006966795306652784,-0.018788157030940056,0.007531090173870325,-0.006524322554469109,0.010099516250193119,-0.004077629651874304,-0.017544057220220566,-0.0056204223074018955,0.0014705952489748597,0.02655109204351902,-0.004098542500287294,0.00679929880425334,-0.009616298601031303,-0.00428798096254468,-0.004214432090520859,0.017463093623518944,0.007254500407725573,0.011614413931965828,-0.015450838021934032,0.01430854294449091,0.011353002861142159,0.0038417966570705175,0.013071335852146149,-0.003091377206146717,-0.0012477281270548701,-0.012130544520914555,-0.0005112078506499529,0.0007805016357451677,0.01115238294005394,-0.011903454549610615,0.01652473211288452,-0.016062499955296516,0.0243363119661808,0.00521033676341176,-0.019244149327278137,0.015055154450237751,-0.0014579187845811248,0.024649038910865784,0.003033657558262348,-0.004459853284060955,-0.0024275374598801136,-0.004720765631645918,-0.008315999060869217,0.01299308892339468,0.003514010924845934,0.00035230195499025285,-0.0016822096658870578,-0.011835559271275997,0.013584377244114876,0.014042497612535954,-0.0021746200509369373,-0.013556176796555519,0.009201740846037865,-0.016880186274647713,0.006788729690015316,0.007318035699427128,0.0079000573605299,-0.0021131120156496763,0.005459972191601992,-0.01956108957529068,-0.003485738066956401,-0.012780935503542423,-0.010953888297080994,-0.0035778111778199673,0.013985262252390385,0.004123058635741472,-0.017365043982863426,0.02569989673793316,-0.0032679142896085978,-0.006953733041882515,-0.020901406183838844,0.003745210822671652,0.004216748755425215,0.007281791884452105,0.01097949780523777,-0.008859830908477306,0.0076435767114162445,-0.002383668441325426,0.003228791058063507,0.000471006816951558,0.021136121824383736,0.006612015888094902,-0.00790025107562542,0.002388188848271966,-0.01046378631144762,0.0019024949287995696,-0.020805569365620613,0.008167678490281105,0.01708216592669487,0.003778784302994609,-0.007486400194466114,0.009304165840148926,0.01634320802986622,-0.015319439582526684,0.012349807657301426,0.008008498698472977,0.004085544031113386,-0.0019550668075680733,-0.0013337925774976611,0.005621806252747774,0.00999923050403595,0.0067540789023041725,0.024973737075924873,-0.013562659732997417,-0.009736709296703339,-0.012089909054338932,-0.016808679327368736,0.008086872287094593,0.008295665495097637,-0.012549092061817646,-0.010748330503702164,3.521411053952761e-05,0.0017467420548200607,0.01626216247677803,0.009219243191182613,-0.006609965115785599,0.010143030434846878,-0.020303402096033096,-0.01044105552136898,-0.013318654149770737,0.00010932621080428362,0.007084518671035767,0.007645950186997652,-0.0032920767553150654,-0.01955648884177208,0.0074850814417004585,0.00894773006439209,0.009001234546303749,0.005829519592225552,-0.0045957546681165695,0.0013910618145018816,-0.012523948214948177,0.013304369524121284,0.01453658938407898,0.017666004598140717,-0.004940214566886425,-0.011730528436601162,-0.015571167692542076,-0.010929387994110584,-0.0006716740899719298,0.02221648395061493,0.021565254777669907,0.01226515881717205,-0.0053292508237063885,0.0007020622142590582,0.0024210221599787474,0.01962619461119175,-0.004420963115990162,-0.015309896320104599,0.0034791347570717335,0.02059043198823929,-0.008116353303194046,-0.0032520205713808537,-0.012169377878308296,0.025940747931599617,-9.851584763964638e-05,0.0036511996295303106,0.0037823636084795,-0.010169846937060356,0.010504196397960186,0.013252376578748226,-0.007866725325584412,-0.0026977320667356253,-0.011583752930164337,-0.006372353993356228,-0.0007445314549840987,-0.0030074622482061386,0.016342146322131157,-0.009066401980817318,0.0021215977612882853,0.008862188085913658,0.015515057370066643,0.009001555852591991,-0.024249698966741562,0.020413951948285103,0.008854007348418236,0.0006535120774060488,0.013391399756073952,-0.01817990653216839,-0.0016513630980625749,-0.011816883459687233,0.007374065928161144,0.02026175521314144,-0.019211476668715477,0.00015504502516705543,-0.007945390418171883,0.001324703567661345,0.025466380640864372,0.006762733682990074,-0.01408602949231863,-0.01516133826225996,-0.0069986796006560326,-0.0004754628462251276,-0.01119284238666296,-0.004222266376018524,-0.014954396523535252,0.0031823322642594576,-0.009523541666567326,-0.011928976513445377,-0.0011272374540567398,-0.009063232690095901,-0.011843233369290829,-0.0030050550121814013,-0.010779651813209057,0.017810650169849396,0.009822757914662361,-0.0130256162956357,-0.002755612600594759,0.010061550885438919,-0.002134740585461259,-0.0004929009592160583,-0.011506262235343456,0.004393350332975388,0.002644677646458149,0.013704448938369751,-0.015646131709218025,-0.005174269899725914,0.017940374091267586,0.006815964821726084,-0.014483116567134857,-0.018775692209601402,-0.017056433483958244,-0.00333380582742393,-0.01628420129418373,-0.02220962941646576,-0.007394126150757074,0.004732364322990179,0.003667865414172411,0.013815898448228836,-0.014784134924411774,0.006790837273001671,-0.005050111562013626,-0.01184664387255907,-0.005963458679616451,0.01068057306110859,0.01837034337222576,6.692128226859495e-05,-0.0020520382095128298,-0.005477442871779203,0.008534909226000309,0.021816853433847427,0.019038107246160507,0.008523069322109222,-0.021777216345071793,-0.01595551334321499,-0.012562203221023083,0.012347427196800709,0.013057525269687176,-0.015681490302085876,0.012324455194175243,-0.0041071330197155476,0.01061281468719244,-0.01118357665836811,-0.001830828026868403,0.0030818136874586344,0.0002257306332467124,0.012498816475272179,0.005094640422612429,0.020110618323087692,0.008550223894417286,0.008692882023751736,0.0034023199696093798,-0.0035538740921765566,0.017047973349690437,-0.008395790122449398,0.0036361422389745712,0.0012567044468596578,-0.012467821128666401,0.015781357884407043,-0.009986070916056633,0.01078745350241661,0.008992418646812439,-0.00894157588481903,-0.009751653298735619,-0.007818657904863358,-0.11352294683456421,0.006673813331872225,0.0006858144770376384,0.012712855823338032,0.017139634117484093,-0.003267174120992422,-0.0037179840728640556,-0.027594735845923424,0.015738407149910927,-0.008096124045550823,0.008535375818610191,-0.006178006995469332,0.0021386174485087395,0.00922358687967062,0.015902427956461906,0.010610240511596203,-0.006293817888945341,0.007873225025832653,-0.009341374039649963,-0.015121137723326683,-0.0025967389810830355,0.0009708734578453004,0.02104487642645836,-0.0034994683228433132,-0.012507845647633076,0.022736024111509323,-0.007137798238545656,0.004183493088930845,-0.005087561905384064,0.005540612153708935,0.011934671550989151,-0.008175094611942768,0.013157593086361885,0.003565874882042408,0.007175907958298922,0.02075435034930706,-0.008561364375054836,0.0018133737612515688,-0.0031988373957574368,0.0026560029946267605,-0.015025373548269272,0.0025075653102248907,-0.020946715027093887,0.002409552223980427,0.0030347283463925123,-0.008436071686446667,0.011734389699995518,0.005770737770944834,0.0027340843807905912,0.009276704862713814,0.014263113029301167,0.005924335680902004,-0.013983492739498615,-0.0073938933201134205,-0.0037190215662121773,-0.007606761995702982,0.00866461731493473,-0.00787393283098936,0.004571785684674978,-0.01736222766339779,0.0011665115598589182,-0.018963271751999855,0.002434736117720604,0.023223616182804108,0.013454395346343517,-0.007077569141983986,0.006989220157265663,0.0016794668044894934,-0.0029226583428680897,0.015770161524415016,-0.007460178807377815,0.02135499194264412,-0.0067621381022036076,0.006347097456455231,0.01143655739724636,-0.009779580868780613,0.0011012412142008543,0.022937849164009094,0.03317839652299881,0.002777715912088752,0.0014309572288766503,-0.004011448472738266,-0.020232975482940674,-0.0036248492542654276,0.009381849318742752,-0.004546706099063158,0.01232175249606371,-0.02003932185471058,0.005393791012465954,0.007975440472364426,-0.02001962997019291,0.00812353566288948,0.004558304324746132,0.012361841276288033,-0.00022309240011963993,-0.005494816228747368,-0.005414157174527645,-0.0007955267792567611,-0.006178250070661306,0.0011265840148553252,0.014568240381777287,-0.015398587100207806,-0.009784664027392864,0.002724339719861746,-0.012673153541982174,-0.0022227196022868156,0.012834923341870308,0.011582594364881516,0.0023665439803153276,0.006087005604058504,-0.0014784777304157615,0.004853080026805401,0.004227772355079651,0.005455693230032921,-0.0038181168492883444,-0.009257722645998001,0.006031699012964964,0.0033167218789458275,-0.0009175615850836039,0.023257719352841377,-0.0028650029562413692,0.002901359461247921,0.002793062711134553,0.01102980226278305,0.0026135335210710764,0.028918616473674774,0.015613989904522896,-0.0029948721639811993,-0.009738076478242874,0.018055813387036324,0.0043314797803759575,0.008178786374628544,-0.011788956820964813,0.011455508880317211,0.01573013886809349,0.00820583663880825,0.01591729186475277,0.002678733319044113,-0.017613554373383522,-0.00441357959061861,-0.010343971662223339,0.003275121096521616,-0.004354435950517654,-0.016168376430869102,-0.016327762976288795,0.010710583068430424,-0.0002415279159322381,-0.005174752790480852,-0.010321610607206821,2.5521783754811622e-05,-0.005093996413052082,0.00427284324541688,-0.00925386231392622,-0.022916292771697044,-0.005452363286167383,-0.005463994108140469,-0.00032996939262375236,-0.0056364452466368675,-0.01507771946489811,-0.0140626709908247,-0.001988076837733388,0.010080339387059212,-0.008691756054759026,0.001160038635134697,-0.0021076020784676075,-0.012562798336148262,-0.002622719155624509,0.0030087551567703485,-0.007625970058143139,-0.002947271103039384,0.018139785155653954,0.02823634259402752,-0.0030986485071480274,-0.0026572253555059433,-0.009556874632835388,-0.0120854452252388,-0.016098687425255775,0.004706657491624355,0.018779207020998,-0.008696485310792923,0.02307201363146305,0.008763439022004604,-0.014935833401978016,-0.010818082839250565,-0.2784213721752167,-0.007361662574112415,-0.009495736099779606,-0.023461056873202324,-0.008934522047638893,0.015963122248649597,0.0016804963815957308,-0.009592200629413128,-0.011385498568415642,0.010840379633009434,0.0007005499792285264,0.0030378401279449463,0.01442185789346695,0.0060276128351688385,0.011916878633201122,0.0019495971500873566,0.010881658643484116,0.010174351744353771,0.002560058841481805,-0.011619336903095245,0.005709640681743622,-0.019679618999361992,0.008580016903579235,-0.020601846277713776,-0.003206663765013218,-0.009325030259788036,0.010211093351244926,0.02160986326634884,-0.0012345046270638704,-0.0058813090436160564,0.02697822079062462,-0.009422902949154377,-0.013682184740900993,-0.0015802914276719093,0.020953504368662834,-0.003903919830918312,-0.00243631680496037,-0.020303402096033096,0.01755078323185444,0.024769868701696396,0.0016339250141754746,0.02251550555229187,0.004645044915378094,-0.010381357744336128,-0.014821520075201988,-0.010959195904433727,0.00934459175914526,-0.010714001022279263,0.018016111105680466,-0.00970667414367199,-0.007309091277420521,-0.012314545921981335,-0.02047012746334076,0.027432451024651527,-0.0009060755837708712,0.07745006680488586,-0.0023823976516723633,0.01124457735568285,0.0096189696341753,-0.0008077527745626867,-0.0035770712420344353,-0.0034886582288891077,0.011778567917644978,-0.008943229913711548,0.003386442083865404,-0.00024284704704768956,0.010145587846636772,0.007330470718443394,0.003942918032407761,0.0022819836158305407,-0.0008272781851701438,0.007588133215904236,0.005243266467005014,-0.014266717247664928,-0.005166773218661547,0.0074570500291883945,-0.0016363218892365694,-0.019104288890957832,-0.005167931783944368,0.008953874930739403,-0.007413430605083704,-0.013545575551688671,-0.017633790150284767,0.026401540264487267,-0.0021100472658872604,-0.010175767354667187,0.009788733907043934,-0.014036711305379868,0.003915506415069103,-0.003761973464861512,-0.004975275602191687,0.002093156334012747,-0.001363328075967729,-0.0029019585344940424,-0.009283140301704407,-0.006503944285213947,-0.011568261310458183,0.02174294926226139,-0.014086995273828506,0.0033965124748647213,0.0035606948658823967,0.003461358603090048,0.010544992983341217,0.010210482403635979,-0.002245498588308692,0.019960559904575348,-0.007419897243380547,-0.007997768931090832,0.00904663186520338,0.02357649616897106,-0.011239221319556236,-0.00011569660273380578,-0.0029487835709005594,0.007448234129697084,0.016541525721549988,-0.0001295312977163121,0.009020346216857433,-0.020686302334070206,0.015325473621487617,-0.0016831347020342946,-0.008773420937359333,0.016255050897598267,-0.0012025240575894713,0.01161193661391735,-0.016618099063634872,0.012996693141758442,-0.004140432924032211,-0.007176905404776335,0.020722240209579468,-0.010730667039752007,0.01690627448260784,-0.0032811376731842756,0.010093660093843937,-0.0027236961759626865,-0.03603730350732803,-0.004680242855101824,0.006091711111366749,-0.012325975112617016,-0.014773516915738583,-0.012536093592643738,0.0029048342257738113,-0.02004828117787838,-0.007857202552258968,-0.012408236041665077,-0.005879549775272608,-0.003138889791443944,-0.015323558822274208,-0.0001826178777264431,0.004041365813463926,-0.015603084117174149,0.008681814186275005,0.01134839653968811,0.0006241817027330399,-0.026418721303343773,0.0036757681518793106,0.0031010936945676804,-0.0018149744719266891,-0.0038577064406126738,-0.010925833135843277,-0.006739520467817783,-0.014096260070800781,-0.005563016515225172,0.016652911901474,-0.0007585270213894546,0.011374784633517265,-0.009055189788341522,0.014467866159975529,0.021866194903850555,-0.011922026984393597,-0.006064226385205984,0.014592982828617096,0.012229286134243011,0.007419169414788485,-0.003800228238105774,0.005821636877954006,0.005980832036584616,0.019860951229929924,0.0005983874434605241,-0.021042626351118088,-0.011280648410320282,-0.0034789254423230886,-0.005904307123273611,0.00940112117677927,-0.01505252718925476,-0.007798091508448124,-0.005041247699409723,-0.020565425977110863,0.002939002588391304,-0.010503344237804413,0.006530262529850006,-0.00948650948703289,0.006920433137565851,-0.013644187711179256,-0.01110368873924017,-0.0007017726311460137,-0.011356927454471588,-0.009044218808412552,0.004168874584138393,0.014494956471025944,0.007382184267044067,-0.01204177737236023,-0.0026305855717509985,0.00237200572155416,-0.011614670976996422,0.0075203352607786655,-0.007654733490198851,-0.018017364665865898,-0.007952709682285786,0.009685106575489044,0.016591427847743034,0.008159216493368149,-0.004515109583735466,0.019129447638988495,-0.1756141632795334,-0.024899190291762352,0.0018353804480284452,0.008671293035149574,-0.01384413056075573,0.01001817174255848,-0.012732546776533127,0.005506077315658331,0.0014535110676661134,-0.00014272250700742006,-0.02563503570854664,0.0071355667896568775,-0.02158156782388687,-0.00474808132275939,0.018071835860610008,0.023083724081516266,0.009568641893565655,0.006390306632965803,-0.005066118203103542,-0.01592129096388817,0.017062868922948837,-0.01115796621888876,-0.015767812728881836,-0.005238134413957596,0.006928991060703993,0.006582673639059067,-0.008210115134716034,-0.0006850744248367846,0.003518740413710475,0.02363714389503002,0.014902275986969471,-0.00873962976038456,-0.00457162456586957,0.008439594879746437,0.004671009257435799,0.006651798263192177,0.007029373198747635,0.010178695432841778,-0.01541563868522644,0.005330503452569246,0.005778331309556961,0.010172613896429539,-0.0029294793494045734,-0.005375274922698736,0.015940893441438675,-0.01708410307765007,0.02029111236333847,0.020185356959700584,0.003809751709923148,0.010334190912544727,0.004035063553601503,-0.013017106801271439,-0.009174071252346039,0.0011511747725307941,0.003145364811643958,-0.004294078331440687,0.01332454290241003,-0.013086714781820774,0.016923105344176292,-0.012309269048273563,-0.012259078212082386,0.0015276713529601693,0.00023750621767248958,-0.00841486919671297,-0.012003683485090733,-0.02218620665371418,-0.006810398772358894,-0.05309946462512016,-0.016830896958708763,0.008899983949959278,0.013663781806826591,-0.008498359471559525,-0.009214417077600956,-0.005358291324228048,-0.019415665417909622,-0.0016335167456418276,-0.01287610549479723,-0.005925686564296484,0.007678573951125145,0.004894197918474674,-0.005250392947345972,0.01937422715127468,0.03884986415505409,0.007704956457018852,0.004224277101457119,-0.010258260183036327,0.012103293091058731,0.0007560174562968314,0.009477147832512856,0.005485904403030872,0.011781315319240093,0.005216819699853659,-0.01289766188710928,-0.00058182911016047,-0.006487181875854731,0.010025066323578358,0.01070936769247055,0.008055237121880054,0.009198716841638088,-0.0050565944984555244,0.01677780970931053,-0.004822997841984034,-0.0006103349733166397,-0.010622531175613403,-0.007425166200846434,-0.0016098107444122434,-0.006618257611989975,0.0011639798758551478,-0.08570022881031036,0.020885812118649483,-0.025955354794859886,0.018434884026646614,-0.0073579950258135796,0.005618041846901178,0.005165067967027426,0.0032188494224101305,-0.0012533745029941201,0.015155804343521595,-0.004030752461403608,-0.0077774110250175,0.0008675797143951058,-0.0021942458115518093,0.005814365576952696,0.0067954701371490955,-0.0116463303565979,-0.004899860825389624,0.012563779018819332,-0.02336389385163784,0.0006979600293561816,-0.004649227485060692,-0.012502971105277538,-0.010896007530391216,0.0012360489927232265,-0.012883569113910198,0.025206802412867546,0.011092202737927437,-0.01052560843527317,-0.006687352433800697,-0.01787686161696911,0.004141188692301512,0.0106991371139884,-0.00821922067552805,-0.02622329816222191,0.006792123895138502,-0.013250929303467274,0.007654957938939333,0.008035637438297272,-0.005465570371598005,-0.013763535767793655,-0.01950150541961193,0.008698672987520695,0.0057535613887012005,-0.019228672608733177,-0.011553805321455002,-0.0003967660013586283,0.0012686088448390365,0.006336930673569441,-0.005957281216979027,-0.002579220337793231,-0.002936155302450061,0.0036823435220867395,0.005852008704096079,0.017855370417237282,-0.00011639236618066207,0.0004218293179292232,0.001062761410139501,0.0018936148844659328,0.0179592277854681,0.006386397872120142,0.009569131769239902,0.00946755986660719,0.0031641540117561817,-0.019553659483790398,0.0029401606880128384,-0.014651062898337841,-0.009318306110799313,0.01822330802679062,0.019901007413864136,0.002202707575634122,0.003464141394942999,0.0073665534146130085,-0.014449591748416424,-0.0014002956449985504,0.01639820821583271,0.010666480287909508,0.00931896548718214,-0.0015187592944130301,-0.023576384410262108,-0.00443253805860877,0.014584994874894619,-0.0053917961195111275,0.01415127795189619,0.011401182971894741,-0.0006382536957971752,0.018119532614946365,0.009133468382060528,0.012955060228705406,-0.0014709169045090675,-0.016649436205625534,0.02026389352977276,0.0006713725160807371,0.015495236963033676,0.003925270866602659,0.00319079402834177,-0.003925030119717121,-0.021138904616236687,-0.00461933808401227,-0.005469720810651779,0.00739274313673377,0.019258851185441017,0.02616351842880249,0.023124778643250465,-0.00566488690674305,0.01773357018828392,0.023644834756851196,0.0047590043395757675,0.017013562843203545,-0.0032865749672055244,-0.018152205273509026,-0.010509730316698551,0.004198023583739996,0.011710388585925102,-0.00446705985814333,0.002852680627256632,-0.002007831586524844,-0.000134904301376082,-0.01944751851260662,0.017555125057697296,0.007372296415269375,0.013482901267707348,-0.01416250690817833,0.009404434822499752,0.002286749193444848,0.005182494409382343,-0.0028514256700873375,0.004553719889372587,-0.0026370203122496605,-0.0011353131849318743,0.011851341463625431,-0.00646215071901679,-0.013426951132714748,0.020288217812776566,0.006485862657427788,0.01353476569056511,-0.015545669943094254,0.006692144554108381,0.0026561636477708817,0.0048660943284630775,-0.018292417749762535,-0.007460114546120167,0.022227099165320396,0.0106017105281353,0.05320962518453598,-0.02265460416674614,-0.01131453923881054,0.012853817082941532,-0.0002959979756269604,0.025417005643248558,-0.00955783948302269,0.0014118781546130776,-0.00904284231364727,-0.008947938680648804,-0.007168934214860201,-0.00964303594082594,-0.004022146109491587,-0.005613087210804224,-0.12938329577445984,-0.0043584736995399,0.020456742495298386,0.0071443296037614346,-0.011277008801698685,-0.02349260449409485,-0.010244361124932766,-0.00665429187938571,-0.010064574889838696,0.005249082110822201,0.005279236473143101,0.017985159531235695,-0.02883007377386093,0.010324330069124699,-0.012035149149596691,0.008913593366742134,0.008274752646684647,-0.0018126015784218907,-0.004603218752890825,0.00580825237557292,0.008159039542078972,0.01880655251443386,0.0002549282507970929,-0.004038217011839151,0.005237426608800888,-0.018459560349583626,-0.00046851334627717733,0.0023338748142123222,-0.0042199338786304,-0.006385834887623787,0.011244351975619793,0.0007573044276796281,0.01756402850151062,-0.008600994013249874,-0.0022277063690125942,-0.0030407358426600695,-0.007221739273518324,0.01820104382932186,-0.02493535354733467,0.01585320197045803,-0.0005586881306953728,0.0033721248619258404,-0.00026433906168676913,-0.000743469747249037,0.005868381354957819,0.006111698690801859,-0.0011203524190932512,0.011258958838880062,-0.0008901173714548349,-0.011496561579406261,-0.008037720806896687,0.016194118186831474,0.011407424695789814,-0.014084485359489918,0.017604801803827286,0.002007188042625785,-0.006658796686679125,-0.009705387987196445,0.015173210762441158,0.006459673400968313,-0.00285873725079,0.019698521122336388,0.012200135737657547,-0.008034748956561089,0.0028521015774458647,-0.00245031644590199,-0.006310049910098314,-0.00373665289953351,0.008135923184454441,-0.0090325390920043,-0.0002607999776955694,0.0046803392469882965,-0.01800999790430069,-0.008924789726734161,0.01823682151734829,-0.007351914420723915,-0.019322993233799934,0.012701595202088356,0.0053284624591469765,-0.0064052678644657135,0.019654009491205215,0.00013570864393841475,0.016256112605333328,0.007728443015366793,0.010437853634357452,0.00808533001691103,0.019011886790394783,0.012183984741568565,0.033292051404714584,0.005902435164898634,-0.018925726413726807,-0.00701944762840867,0.011261066421866417,0.005332435946911573,0.0031362916342914104,0.0005442180554382503,-0.0032328530214726925,-0.010592673905193806,-0.018920287489891052,-0.009756236337125301,-0.005785324610769749,-0.030977396294474602,0.001599933486431837,0.00013377821596805006,0.008112323470413685,-0.0063599590212106705,-0.005695757456123829,0.00597459077835083,0.01210800651460886,-0.006559251341968775,0.0007339463336393237,0.011125277727842331,0.022035440430045128,0.017060229554772377,0.01003420352935791,-0.0034310349728912115,0.00637843506410718,0.011094809509813786,-0.013998170383274555,-0.014564729295670986,0.01242771651595831,-0.0036663247738033533,-0.000654135481454432,0.00626980047672987,-0.0076171220280230045,-0.0020285514183342457,0.006653873715549707,0.012656455859541893,-0.01786595582962036,-0.008405892178416252,0.01965014822781086,-0.0021813763305544853,0.010792931541800499,-0.015798313543200493,-0.015769999474287033,-0.006753129884600639,-0.015076013281941414,0.007592670153826475,0.006454171612858772,0.02763102576136589,-0.008400551043450832,-0.0049078394658863544,-0.024386631324887276,0.006857115309685469,0.001914125750772655,-0.01439663302153349,-0.020056629553437233,0.008954518474638462,0.013706443831324577,0.007875348441302776,0.012146084569394588,-0.009473125450313091,0.009648504666984081,0.015645135194063187,0.01922854408621788,0.0068963672965765,0.008811811916530132,0.013530968688428402,-0.017957940697669983,-0.01021209079772234,0.0022633387707173824,-0.007277818396687508,-0.0031573977321386337,-0.11325757950544357,-0.0026099944952875376,0.01439537201076746,-0.004530924838036299,0.001019970397464931,-0.0020006245467811823,-0.004129558335989714,0.015971921384334564,-0.044551171362400055,0.0030149968806654215,0.007847486063838005,-0.01554462406784296,0.007680688984692097,-0.00788731686770916,-0.017942272126674652,-0.000786610587965697,0.005577197298407555,0.009266538545489311,-0.009329116903245449,-0.04451880231499672,-0.0037785109598189592,0.0028084840159863234,-0.009803786873817444,-0.010790380649268627,0.002866531489416957,0.0017853827448561788,0.007238357327878475,-0.007430804427713156,-0.004662869498133659,0.004536635708063841,0.01837938465178013,0.01211519818753004,0.0014415101613849401,-5.029150634072721e-05,0.021934866905212402,-0.010267108678817749,-0.013645731844007969,-0.015742121264338493,0.008256089873611927,-0.04040089249610901,0.07481249421834946,0.007236475590616465,0.009462444111704826,-0.027326276525855064,0.003720212262123823,0.000653174240142107,-0.002285812282934785,-0.0037178313359618187,0.012064619921147823,0.006163128651678562,-4.221188646624796e-05,-0.004891624208539724,-0.009622621349990368,0.0006778354290872812,0.013634954579174519,-0.020278330892324448,-0.004124345723539591,0.007662141229957342,0.018916331231594086,-0.0036245116498321295,0.01430609729140997,-0.01053135097026825,-0.012238960713148117,-0.016030864790081978,0.002648538677021861,0.014399755746126175,-0.008265534415841103,0.017143085598945618,-0.014470246620476246,-5.842742757522501e-05,-0.004861831199377775,-0.015087821520864964,-0.006019762251526117,0.01629151962697506,0.010227116756141186,-0.003751903073862195,-0.01222227606922388,0.0076263234950602055,0.042506661266088486,-0.01409455481916666,-0.0125817796215415,0.006965314969420433,-0.1917276829481125,0.00950542837381363,-0.01586632803082466,0.0023973588831722736,0.005743181332945824,-0.0027462500147521496,0.013118598610162735,0.011540125124156475,-4.4238830014364794e-05,0.0049981833435595036,0.010282487608492374,0.0003759496030397713,0.01399040874093771,0.018821081146597862,-0.014726671390235424,0.004507406149059534,0.011466688476502895,-0.005345562938600779,0.003956358879804611,-0.0034813869278877974,-0.0006390218622982502,-0.012699902057647705,0.006115961819887161,-0.00699468981474638,-0.00933891348540783,0.0034024324268102646,0.0066421241499483585,-0.002772600157186389,-0.00560781080275774,0.0124791469424963,0.008322587236762047,-0.009324386715888977,0.019184015691280365,-0.01484056655317545,0.004880982916802168,0.009200002998113632,-0.004697439726442099,-0.0016762494342401624,0.005595938302576542,0.0051397476345300674,0.015112820081412792,0.0016515520401299,0.0027893949300050735,0.004518795292824507,0.02610747143626213,0.010790864005684853,-0.00240150885656476,0.0018596394220367074,-0.00877827126532793,0.016919050365686417,-0.006034755613654852,0.004655871074646711,-0.007221192587167025,-0.010618927888572216,-0.010135614313185215,0.0057146274484694,-0.0011658620787784457,8.326552051585168e-05,-0.0037010847590863705,0.007693116553127766,-0.011633782647550106,-0.0017288855742663145,0.008993348106741905,0.006360128056257963,-0.006610793061554432,0.02352437563240528,0.001936598913744092,-0.011150550097227097,-0.01644146628677845,0.0009796085068956017,0.0030192439444363117,-0.0053696841932833195,0.013059624470770359,-0.0033805544953793287,0.016168439760804176,0.0018524626502767205,0.012617220170795918,0.005636119283735752,-0.016038715839385986,0.010487047955393791,-0.007545631844550371,-0.001429348485544324,-0.0017839670181274414,-0.008450678549706936,0.005330666434019804,-0.02991759404540062,0.00345455389469862,0.018851209431886673,-0.009807764552533627,0.027462579309940338,0.007071391679346561,0.0019209625897929072,-0.018841171637177467,-0.005503535736352205,0.02069077454507351,-0.020384222269058228,0.00936795026063919,0.007733526639640331,-0.009904591366648674,-0.004870839882642031,-0.03102888911962509,0.010977471247315407,0.015817424282431602,0.0011372757144272327,0.0072667705826461315,0.00784523319453001,-0.003772204741835594,0.015585226006805897,0.006962628103792667,-0.005917835980653763,-0.004866400267928839,-0.002367018721997738,0.005616626236587763,0.008822798728942871,-0.012629799544811249,-0.011987242847681046,0.0032996777445077896,0.0023828642442822456,0.012849369086325169,0.010437403805553913,0.008191507309675217,0.014551647007465363,-0.00907558761537075,-0.012082315981388092,-0.01734895631670952,-0.025283891707658768,0.011902658268809319,0.01442468911409378,-0.00960622914135456,0.009892510250210762,0.006284326780587435,0.09945326298475266,-0.000902246858458966,0.010209871456027031,0.006395020522177219,-0.014969841577112675,0.006021085660904646,0.005478468257933855,0.006624804809689522,-0.005861262790858746,0.018376680091023445,-0.005344887264072895,-0.008701054379343987,0.017867742106318474,0.02290046401321888,0.004558425396680832,-0.0031763159204274416,0.009653178043663502,0.017748555168509483,0.0004191588668618351,-0.020645441487431526,-0.0037479782477021217,0.01151856780052185,-0.018366899341344833,0.013412505388259888,-0.006302890833467245,0.006716001313179731,-0.00566723570227623,0.021751975640654564,-0.009203510358929634,-0.005479597952216864,-0.0036258467007428408,0.011007815599441528,-0.019736887887120247,0.0033232851419597864,-0.00348482932895422,0.005073791369795799,0.017230041325092316,0.020670218393206596,0.004283766727894545,-0.0009454562095925212,0.002031994052231312,-0.017311764881014824,-0.013582253828644753,-0.012368597090244293,0.010673816315829754,-0.0031707175076007843,0.008417531847953796,-0.004093330819159746,-0.01342865638434887,0.006839676760137081,0.007039966061711311,0.002886531176045537,-0.010179306380450726,0.01376741286367178,0.003229884896427393,-0.002050425624474883,-0.006090544629842043,-0.01241382211446762,-0.004899153020232916,-0.007758493069559336,-0.007976759225130081,-0.01766863465309143,0.0025243479758501053,0.0038350399117916822,0.011882581748068333,0.004422273952513933,-0.03836751729249954,-0.01081705279648304,-0.007251629140228033,-0.007358638569712639,0.007515196222811937,0.021443774923682213,-0.011086410842835903,0.003115957835689187,0.01913968101143837,0.023567553609609604,0.0044838543981313705,0.002975921845063567,-0.01662723533809185,-0.006301764864474535,0.011563225649297237,-0.007714479696005583,0.007416438311338425,-0.035197507590055466,0.009823915548622608,-0.017413947731256485,0.011747097596526146,-0.0038893171586096287,0.021576901897788048,0.01757732592523098,0.013345262035727501,-0.006837489083409309,0.029992317780852318,-0.011094197630882263,0.010682325810194016,0.002443913836032152,-0.0005208277725614607,-0.01606852374970913,0.010624848306179047,0.0047839065082371235,0.01419053040444851,-0.01350423227995634,0.012274585664272308,0.012537653557956219,0.007614258676767349,-0.0039986432529985905,0.010640677064657211,-0.0038547625299543142,-0.006087520159780979,0.027305202558636665,0.006098201964050531,-0.00494043156504631,0.004934415221214294,-0.01824975572526455,0.001602957840077579,0.026787754148244858,0.005400836933404207,0.008201074786484241,0.022710701450705528,0.005333361215889454,0.007449979893863201,-0.00023634797253180295,-0.011554860509932041,0.00011505313159432262,0.006364085711538792,0.0009316215291619301,0.012276645749807358,-0.002286005299538374,0.007153740152716637,-0.00578177347779274,-0.003366011893376708,0.016108853742480278,-0.007560239173471928,-0.012466534040868282,5.5177883041324094e-05,0.013790159486234188,-0.012926618568599224,1.878943839983549e-05,0.0008286013035103679,-0.0036813300102949142,-0.0005811856244690716,-0.0008696871809661388,-0.008247340098023415,0.02868564799427986,-0.014315041713416576,-0.017415814101696014,0.006972618401050568,-0.024270612746477127,-0.009373226203024387,0.0051077669486403465,0.0038382895290851593,-0.01722528040409088,0.015512949787080288,0.01026356965303421,0.00711700227111578,-0.010315561667084694,0.01249308604747057,0.014615736901760101,-0.002677438547834754,0.005468305200338364,-0.005088237579911947,-0.018737059086561203,-0.003193721640855074,0.0038784947246313095,0.0009255004115402699,0.006019891239702702,0.0115288645029068,-0.018515832722187042,-0.005315995309501886,0.0148364482447505,0.009229088202118874,-0.002652656752616167,0.005572419613599777,0.007090028841048479,-0.00805481243878603,0.027019791305065155,-0.005165357608348131,0.01384897343814373,-0.01675380766391754,0.014895391650497913,0.001922378083691001,-0.007131235208362341,0.010457383468747139,-0.0060896435752511024,-0.0035761059261858463,-0.017283009365200996,0.013179706409573555,0.01639494299888611,0.0069476836360991,-0.010041441768407822,-0.004489645827561617,-0.01367124542593956,-0.0003028188075404614,0.012466919608414173,-0.010653103701770306,0.008282281458377838,0.003187681082636118,-0.01343492977321148,-0.010245668701827526,-0.011471674777567387,-0.01613684557378292,-0.0010712954681366682,-0.0027505853213369846,-0.001911632250994444,-0.0011440966045483947,-0.02027985267341137,-0.003082658164203167,-0.0005120121641084552,-0.004386079031974077,-0.010168688371777534,0.0036431557964533567,0.006260099820792675,-0.010663633234798908,-0.002148623578250408,-0.002349805785343051,0.0030768970027565956,-0.0034179803915321827,-0.008466539904475212,-0.011844230815768242,-0.005494784563779831,0.0010436181910336018,0.011641600169241428,-0.011137792840600014,7.610687316628173e-05,0.005389544181525707,-0.023192087188363075,-0.005416119936853647,-0.009617231786251068,0.008793344721198082,-0.024386076256632805,0.020657410845160484,5.134117236593738e-05,-0.007362756412476301,-0.009800750762224197,0.006533399689942598,-0.010050579905509949,0.006684471387416124,0.011441572569310665,0.006047689355909824,0.016310229897499084,-0.005246692802757025,0.007157488260418177,0.0017344196094200015,-0.00866750068962574,0.0006803951691836119,0.00713065592572093,-0.0014674743870273232,0.0203915573656559,-0.005685457959771156,-0.007061901036649942,-0.016780640929937363,0.001550675486214459,-0.008510038256645203,-0.011533658020198345,-0.008761588484048843,0.022064397111535072,-0.0017128309700638056,0.0062705883756279945,0.0048079160042107105,0.018406344577670097,0.010051971301436424,0.003991404082626104,0.012091951444745064,-0.005227489396929741,-0.0035770712420344353,-0.009186764247715473,-0.0038295702543109655,-0.00698986416682601,0.012210141867399216,0.005487545393407345,-0.0013136116322129965,0.0018605402437970042,-0.011810770258307457,-0.001065592747181654,0.0004330579249653965,0.024547435343265533,-0.0043790326453745365,-0.0002492174389772117,-0.0189106035977602,-0.010918785817921162,0.020448731258511543,0.007792806718498468,-0.002034664386883378,0.008813790045678616,-0.01989891566336155,0.001182962441816926,0.000261572131421417,-0.0074978540651500225,0.0019776527769863605,-0.011139015667140484,-0.02664639614522457,0.0028707943856716156,0.007007550913840532,-0.017508666962385178,-0.014156038872897625,-0.02033647708594799,0.016214512288570404,0.006000136490911245,-0.016533177345991135,0.018597586080431938,0.005563668441027403,-0.00725555419921875,0.01448176521807909,0.016186457127332687,-0.016622057184576988,0.007171966601163149,0.009879093617200851,0.014025414362549782,0.015332052484154701,0.018447238951921463,0.01657157577574253,-0.01883309707045555,0.0012578627793118358,-0.01160209160298109,-0.0029103304259479046,-0.024813447147607803,-0.008269749581813812,0.019136399030685425,0.12509235739707947,0.00992282573133707,-0.010059620253741741,-0.006295362021774054,-0.009466594085097313,-0.005341983400285244,-0.006175258196890354,-0.00834791548550129,0.0037003285251557827,-0.009935236535966396,-0.022054295986890793,-0.021636681631207466,0.00747463246807456,0.0023884624242782593,0.0020293877460062504,0.000621370563749224,-0.010186834260821342,0.0025970444548875093,0.004555682651698589,0.010875705629587173,-0.00799268577247858,-0.010559020563960075,-0.018151158466935158,0.006607222370803356,0.00013699558621738106,0.0032064514234662056,-0.01213186327368021,0.017665095627307892,-0.001385656651109457,-0.013753159902989864,-0.0032455134205520153,0.004236889537423849,0.011882774531841278,-0.014331771992146969,0.007972095161676407,0.0015528311487287283,0.0077825915068387985,0.0031973575241863728,0.007028214633464813,-0.014710456132888794,0.019549252465367317,-0.013456358574330807,0.006737617775797844,-0.015732519328594208,0.0006138741155155003,0.0037009399384260178,0.011282256804406643,0.010245632380247116,0.002517430577427149,0.007911423221230507,0.00890109408646822,-0.010392270050942898,-0.017399711534380913,-0.02358563430607319,-0.006632172502577305,0.010217915289103985,-0.022281570360064507,0.007806669920682907,0.013242524117231369,-0.0033365730196237564,0.026809824630618095,-0.013774974271655083,-0.00872904434800148,-0.010284706950187683,-0.014805947430431843,0.015970248728990555,0.017862962558865547,0.015086662955582142,0.0027441910933703184,0.010856385342776775,-0.004200211260467768,-0.0081545514985919,0.0031795732211321592,-0.026753583922982216,0.014192008413374424,-0.012117899954319,-0.0035813823342323303,0.015963943675160408,-0.0860016718506813,0.03140305355191231,0.007273109629750252,-0.00939896609634161,0.008446688763797283,-0.00541621632874012,-0.0522768460214138,-0.0012892642989754677,-0.009854674339294434,-0.0076980385929346085,-0.015288103371858597,-0.03279374539852142,-0.014441356062889099,-0.005670452956110239,-0.0029624251183122396,-0.012520995922386646,-0.0102844825014472,-0.017415877431631088,-0.015840580686926842,-0.013365293852984905,-0.009166606701910496,-0.005349005106836557,-0.005249958485364914,0.019897757098078728,-0.007069654297083616,-0.009444724768400192,0.004441514145582914,-0.01018715649843216,0.009931439533829689,0.002962167840451002,-0.013154460117220879,0.014917655847966671,-0.015001467429101467,0.009532036259770393,-0.0044509246945381165,0.028517216444015503,0.00990370661020279,-0.010221325792372227,-0.010877507738769054,0.0023901837412267923,0.02150103636085987,-0.014040149748325348,-0.0007246803143061697,0.00785189401358366,0.0014458857476711273,-0.0006708737928420305,0.004349204711616039,-0.01244916021823883,-0.01190697681158781,-0.1309737116098404,-0.0030378401279449463,0.005152037832885981,-0.025020644068717957,0.013737556524574757,0.01354216504842043,-0.010803540237247944,-0.020594704896211624,-0.010123742744326591,-0.005482333246618509,0.007814539596438408,0.0062471660785377026,0.011471273377537727,0.014933951199054718,0.010366315953433514,-0.017068468034267426,0.0075530968606472015,0.0021459211129695177,-0.005174430552870035,0.004797837696969509,-0.0006980726611800492,-0.01761162281036377,-0.011748763732612133,0.007687899749726057,-0.015306426212191582,0.007811580318957567,-0.004673641175031662,0.019404791295528412,0.006644575856626034,-0.009581189602613449,0.01846865750849247,-0.00799687672406435,-0.008734514936804771,0.025797318667173386,0.004079817328602076,0.01512935757637024,-0.0006804736331105232,-0.0038689833600074053,0.006711303722113371,-0.014750850386917591,0.016202479600906372,0.01031462848186493,-0.005430308170616627,0.01708185113966465,0.008559875190258026,-0.005445751361548901,-0.0028198380023241043,-0.0038498397916555405,-0.006423091981559992,0.013393329456448555,0.008289198391139507,0.019474737346172333,0.013462373986840248,-0.009793463163077831,-0.013543033972382545,0.03380116820335388,0.057620640844106674,0.0037551848217844963,0.01428164541721344,0.011203941889107227,-0.00013776373816654086,-0.007206891197711229,0.011069182306528091,-0.0032131224870681763,0.009809983894228935,0.006570447236299515,-0.002480398863554001,0.022422587499022484,0.011351908557116985,-0.01595130003988743,-0.019222430884838104,0.00509705301374197,-0.006570335011929274,0.0017189440550282598,0.027080731466412544,-0.011916235089302063,0.0015000663697719574,-0.0020198484417051077,-0.02209283970296383,0.006771082524210215,0.0002977755793835968,-0.019696606323122978,0.008564154617488384,-0.0007474914309568703,0.011921319179236889,0.009810338728129864,0.014718177728354931,0.0014345606323331594,0.008807356469333172,-0.006630355026572943,-0.003958745859563351,-0.009559383615851402,-0.005430855322629213,-0.014630086719989777,-0.011925501748919487,0.0004732106754090637,0.018642853945493698,-0.013681734912097454,0.010839325375854969,-0.014961443841457367,0.0016361128073185682,0.0032435106113553047,-0.002405848354101181,-0.018609875813126564,0.0033618290908634663,0.011865722015500069,-0.012829582206904888,0.008958829566836357,-0.011033131740987301,0.007112349383533001,-0.007317069917917252,-0.003843147773295641,0.015338101424276829,0.0060599129647016525,0.013022753410041332,0.022979997098445892,-0.010455581359565258,0.003293846268206835,0.011678189970552921,0.03189416974782944,-0.0003863417077809572,0.006824394688010216,-0.008517374284565449,0.012291766703128815,-0.008964218199253082,0.007173221092671156,0.019597060978412628,0.0208904929459095,-0.008607679978013039,0.02034304104745388,0.010004634968936443,0.011900341138243675,-0.00043498832383193076,0.0033996535930782557,-0.002569137839600444,0.009322158992290497,-0.002651530783623457,-0.008777949027717113,-0.005856899078935385,-0.013607734814286232,0.0010277243563905358,-0.011572104878723621,-0.023325929418206215,0.008436039090156555,0.0016878400929272175,-0.0035754949785768986,0.010810618288815022,0.020025212317705154,-0.009496903046965599,0.01064186729490757,0.0021814408246427774,-0.0061418297700583935,-0.006570986472070217,0.01253622304648161,0.01944899745285511,-0.010414046235382557,0.00017785617092158645,0.006716644857078791,0.011308281682431698,0.014264336787164211,-0.0031749242916703224,-0.020774956792593002,-0.0003114172432105988,0.011388715356588364,-0.009031891822814941,-0.006522138603031635,0.018276477232575417,0.0024473723024129868,0.002980136778205633,-0.007986669428646564,0.010007386095821857,0.009231405332684517,-0.018392913043498993,-0.020028775557875633,0.012274328619241714,-0.008668269030749798,0.0041609592735767365,-0.0037708855234086514,-0.009803260676562786,-0.004945358261466026,-0.01740073226392269,0.0035423238296061754,-0.007416149135679007,0.023602621629834175,0.005355633329600096,-0.0019859694875776768,0.01988109014928341,7.979076144692954e-06,-0.006595607381314039,0.0053070830181241035,0.008229612372815609,0.016438249498605728,0.006289506796747446,0.00754022691398859,0.011281898245215416,0.00024167270748876035,0.006314409431070089,-0.0031186926644295454,-0.02108895592391491,-0.013352083042263985,0.020173614844679832,0.008024762384593487,0.013543741777539253,-0.015686606988310814,-0.008190031163394451,0.015606686472892761,-0.008021931163966656,-0.015871604904532433,0.0037902863696217537,0.0008586193434894085,0.003796238452196121,-0.010971165262162685,0.007283883169293404,-0.016522156074643135,0.0055426545441150665,-0.018035799264907837,-0.009387576021254063,-0.00015417633403558284,-0.009344720281660557,-0.005082639399915934,0.007296253461390734,-0.009880026802420616,-0.002254636026918888,0.02115420438349247,-0.00485372357070446,0.004400492645800114,-0.00884152390062809,-0.006040804088115692,0.011755109764635563,0.008026177994906902,-0.006253858096897602,-0.0029635189566761255,0.007403810508549213,0.0043754614889621735,0.026068542152643204,-0.024823419749736786,-0.004859900567680597,0.0077138361521065235,0.0007009119726717472,-0.018028592690825462,-0.011082421988248825,-0.007141128182411194,-0.01778709888458252,0.009043511003255844,0.0008742235950194299,0.019595323130488396,-0.00226938771083951,-0.0021313303150236607,0.0028745909221470356,0.013393265195190907,0.0035802884958684444,-0.0015817874809727073,0.006639556493610144,0.006195977795869112,-0.007812898606061935,-0.008897827938199043,-0.012519138865172863,0.014377216808497906,0.00478403503075242,-0.004690281115472317,0.003118644468486309,0.027247516438364983,-0.002435001777485013,0.033513087779283524,0.01822897233068943,0.007350771687924862,0.0011077403323724866,0.013501819223165512,-0.015879904851317406,0.013183299452066422,0.011308056302368641,-0.0003690966113936156,-5.669004895025864e-05,0.006077144294977188,-0.0071005732752382755,0.005103584378957748,0.012177292257547379,-0.0015176330925896764,0.00743842963129282,0.006680489517748356,0.004452131222933531,0.004653377924114466,-0.008840574882924557,-0.0031223606783896685,-0.013772077858448029,-0.005994860082864761,0.0052159992046654224,0.00597047246992588,-0.004418735392391682,-0.009556038305163383,-0.005633131135255098,0.02587483637034893,-0.002589789219200611,-0.0176318921148777,-0.009988966397941113,-0.015307571738958359,-0.009621800854802132,-0.002565787872299552,-0.01531350426375866,0.014097933657467365,-0.0033172364346683025,0.001826854539103806,0.0018190363189205527,-0.008359553292393684,-0.0038599425461143255,-0.004618598148226738,-0.0021358828525990248,-0.0039221663028001785,-0.0034684045240283012,-0.004433149006217718,0.006080731749534607,-0.0017949383473023772,-0.008630593307316303,0.001273048692382872,-0.019467659294605255,-6.12587173236534e-05,-0.018115075305104256,-0.006602621171623468,-0.007384441327303648,-0.007939839735627174,0.0019286199240013957,0.0008089773473329842,-0.01783713512122631,0.010118434205651283,-0.014237920753657818,0.01597065106034279,0.016588177531957626,-0.01785440556704998,0.01155418436974287,-0.005966603755950928,-0.014077438972890377,-0.013903025537729263,-0.002557036466896534,-0.021007491275668144,-0.005378428380936384,0.012218442745506763,0.004273728467524052,0.011610778979957104,-0.004312143661081791,0.01642666570842266,-0.023566925898194313,0.013862889260053635,0.015911821275949478,0.004173909313976765,-0.024028481915593147,-0.01222963910549879,-0.005391822662204504,0.011719332076609135,-0.007083456497639418,-0.0073945121839642525,0.010108668357133865,0.013066895306110382,-0.0004766210913658142,-0.006762267090380192,-0.0007032324792817235,0.0023309518583118916,0.012527922168374062,-0.006683377083390951,0.012418627738952637,-0.008594752289354801,-0.0089180339127779,-0.0018390804762020707,-0.01272482518106699,0.015199174173176289,-0.012042034417390823,-0.010652774013578892,0.001955002313479781,0.009363831952214241,-0.009031509980559349,-0.0028586569242179394,-0.0013132980093359947,0.009787592105567455,0.008148052729666233,0.004363750107586384,0.009258558973670006,-0.024081429466605186,0.01084060501307249,0.02108844183385372,-0.01939285360276699,0.011464710347354412,-0.010239985771477222,-0.009829654358327389,0.02925250120460987,-0.006770503241568804,-0.0068392264656722546,0.0012964068446308374,-0.016846660524606705,0.0068872300907969475,-0.003937834873795509,-8.339421765413135e-05,0.008675314486026764,-0.005402928218245506,-0.009232563897967339,0.011987275443971157,0.006109446752816439,-0.006341531407088041,0.007804907858371735,-0.007662084884941578,0.006093183066695929,-0.018207769840955734,-0.006304789334535599,0.000968299456872046,0.011293482035398483,0.0006706284475512803,0.00998291838914156,-0.016655774787068367,0.004729790613055229,0.008077752776443958,-0.0064179119653999805,-0.006763167679309845,0.0055464874021708965,-0.006630998104810715,-0.006346454378217459,0.0029069576412439346,0.004286420997232199,-0.00612212298437953,0.009613017551600933,-0.007194488774985075,-0.014121548272669315,-0.013963254168629646,0.008268116973340511,0.018683167174458504,0.00021566831856034696,0.010583395138382912,0.0023251124657690525,0.005577534902840853,-0.005223962478339672,-0.010808792896568775,-0.00891019869595766,0.0025711446069180965,-0.009238084778189659,0.00847254041582346,0.002356433542445302,-0.020508840680122375,0.008203793317079544,-0.013110458850860596,-0.00429300032556057,0.00894743949174881,-0.0010654800571501255,0.007953747175633907,0.0008857498760335147,0.008226757869124413,0.006239090580493212,-0.003030576976016164,-0.011644785292446613,-0.016018863767385483,0.0014197607524693012,0.012671319767832756,-0.014869586564600468,-0.011633380316197872,-0.0008804009412415326,0.005208792630583048,-0.009140313602983952,-0.004907278809696436,-0.01574484072625637,0.007207204587757587,-0.025614989921450615,0.010377657599747181,0.005622417200356722,0.020156607031822205,-8.534072549082339e-05,-0.013232074677944183,0.0025512452702969313,0.0074208625592291355,0.003769534407183528,0.006363023538142443,0.001976124243810773,-0.009836303070187569,0.014816982671618462,-0.02623211219906807,-0.013312103226780891,0.018329545855522156,0.011043942533433437,0.004413313698023558,-0.0026370524428784847,-0.006824623793363571,-0.01342408824712038,0.01530361082404852,0.02297188900411129,-0.015759512782096863,-0.0038370348047465086,0.008708260953426361,0.0386798270046711,0.006922588218003511,-0.014513103291392326,0.006315784528851509,0.0011656669666990638,-0.00011241488391533494,-0.0043263561092317104,0.006935876328498125,0.01871299184858799,-0.0018523683538660407,0.01645565964281559,0.0006411654176190495,-0.017343293875455856,0.01558641716837883,0.003914637491106987,-0.003911966923624277,0.010716164484620094,0.010333998128771782,0.009289140813052654,0.002327702473849058,-0.0016474217409268022,0.0085306940600276,-0.006147765554487705,-0.0027541646268218756,0.012298844754695892,-0.011853464879095554,0.0022197917569428682,0.009226707741618156,0.02173178642988205,-0.017738966271281242,-0.010917370207607746,-0.0029402251821011305,0.0004863214853685349,-0.0067732385359704494,-0.009347519837319851,-0.0026199843268841505,0.00044122201506979764,0.007049706764519215,-0.005566982086747885,-0.009083359502255917,0.005341717973351479,0.0016353566898033023,0.0075265211053192616,-0.025540797039866447,-0.00833797361701727,-0.00534829730167985,-0.004227929282933474,0.016433872282505035,0.006095499265938997,0.0034416201524436474,0.006703711114823818,-0.013493518345057964,-0.00048759233322925866,0.02160598710179329,-0.018758028745651245,-0.013188640587031841,0.00872473418712616,0.01274280995130539,-0.002263290574774146,-0.0006550966063514352,-0.01119509432464838,-0.010811157524585724,-0.007531395647674799,0.0025357375852763653,0.01623639091849327,0.012533069588243961,-0.11452934145927429,-0.014385758899152279,-0.0036055126693099737,0.002186845988035202,0.013855954632163048,-0.0006583944195881486,0.0048728990368545055,0.009528513066470623,0.003839930286630988,0.01954481191933155,0.001959699671715498,-0.00801488570868969,0.01553120743483305,0.010433783754706383,0.00287243933416903,0.0030284454114735126,0.0027071910444647074,0.005127111449837685,0.007968137040734291,0.004281257279217243,-0.011975499801337719,-0.017328623682260513,0.008220185525715351,0.007401622831821442,-0.013764807023108006,0.007864666171371937,-0.004687312990427017,-0.004217983223497868,-0.01190197467803955,0.005709093064069748,0.012869670987129211,-0.013801033608615398,-0.011998728848993778,0.20357556641101837,-0.0030479426495730877,0.012771195732057095,-0.0171239972114563,0.005747669842094183,0.00899829063564539,-0.014829105697572231,0.00494075333699584,-0.008008965291082859,-0.0036376866046339273,-0.033662255853414536,0.0065314690582454205,-0.009848415851593018,0.013626010157167912,0.012002847157418728,-0.013834439218044281,0.02108149044215679,0.016931405290961266,-0.0017394707538187504,-0.00963470246642828,-0.005704395938664675,0.01754046231508255,-0.015337469056248665,0.015215389430522919,-0.005915905814617872,-0.025276893749833107,-0.005014732480049133,-0.00463339826092124,-0.020541712641716003,-0.001968644093722105,0.000676644966006279,0.01785305328667164,-0.011794249527156353,0.016294624656438828,-0.004089083522558212,0.006442975252866745,-0.02364637888967991,-0.010055324994027615,0.008496284484863281,0.005891228560358286,0.010857462882995605,-0.0347641259431839,-0.014917171560227871,0.017434941604733467,-0.01820305548608303,-0.02300403080880642,-0.01460286695510149,-0.026439635083079338,-0.005786696448922157,0.005840812344104052,-0.002880639396607876,0.005296160001307726,-0.004211021587252617,-0.002037527970969677,-0.010035361163318157,0.004914330784231424,0.004394669085741043,0.005622674711048603,0.0011111185885965824,0.009060111828148365,-0.01080778706818819,-0.014376429840922356,-0.008422542363405228,0.0036981890443712473,-0.026923397555947304,0.009801522828638554,-0.0014322763308882713,-0.013493984937667847,0.012008155696094036,0.012425931170582771,0.009741486981511116,0.02373787946999073,0.0018142102053388953,-0.0050240508280694485,0.01613137498497963,0.005036276765167713,0.0027613716665655375,0.005145667586475611,-0.005073678679764271,0.00631151394918561,0.015935149043798447,0.005443435162305832,-0.0074535515159368515,0.012360554188489914,0.009225227870047092,0.010121893137693405,0.0003564523358363658,0.0020175480749458075,0.0005545940366573632,-0.018256383016705513,-0.0015494207618758082,-0.004463328048586845,0.010256974026560783,0.005540004465728998,-0.005248623434454203,0.005901942495256662,0.010503585450351238,-0.008990907110273838,0.008495476096868515,-0.029623478651046753,-0.0010746014304459095,0.010479615069925785,0.007128741126507521,-0.004881907254457474,-0.012746831402182579,-0.005546809174120426,-0.004563066177070141,0.0002746024983935058,-0.012642459943890572,-0.003734111087396741,0.01777506433427334,0.0049340128898620605,-0.0012290994636714458,-0.00021181550982873887,0.0020156176760792732,0.0010072377044707537,0.003468742361292243,-0.003944575320929289,0.014315459877252579,-0.005033606663346291,0.004686838481575251,-0.012386228889226913,0.0018407534807920456,0.004675609990954399,-0.0087699294090271,-0.005062884651124477,-0.0077690305188298225,0.00480366125702858,-0.012847527861595154,-0.007804791443049908,-0.0020366229582577944,0.010552520863711834,0.0009618164622224867,-0.02200361341238022,-0.02055400423705578,0.007025834172964096,0.005628401413559914,-0.003323606913909316,-0.00350605184212327,0.006432036403566599,0.004809271544218063,0.010274733416736126,0.04477909207344055,-0.009266168810427189,-0.014458194375038147,0.003407451556995511,-0.003966630436480045,0.00690626073628664,-0.005162558518350124,-0.017314080148935318,-0.0033658831380307674,-0.019236072897911072,-0.010986302979290485,-0.009487057104706764,-0.0126802958548069,0.009735309518873692,0.04154672846198082,-0.018142199143767357,0.002596642356365919,-0.0076661063358187675,0.013936100527644157,0.058171678334474564,-0.025674721226096153,-0.006219496950507164,-0.014702396467328072,0.007355244364589453,-0.01217672135680914,-0.01009633019566536,0.008379188366234303,-0.00898730382323265,-0.0017007015412673354,0.003610322717577219,0.0026148527394980192,0.0058074044063687325,-0.016003387048840523,-0.011510750278830528,0.0013994108885526657,-0.005675825756043196,-0.010906624607741833,0.003757855389267206,0.008256155997514725,0.0037957236636430025,0.0004637596430256963,0.0059378482401371,-0.006037457846105099,-0.018181998282670975,0.0013030506670475006,0.007541135419160128,0.009224391542375088,0.010982869192957878,-0.0036199912428855896,-0.002958113793283701,0.01651797443628311,-0.03149764612317085,0.004628603812307119,0.00334406946785748,-0.007923029363155365,0.015490380115807056,0.020828863605856895,0.016824204474687576,-0.0038670848589390516,0.014724436216056347,0.000400498160161078,0.0663076639175415,0.00567030580714345,-0.013410317711532116,0.008589716628193855,-0.008427352644503117,-0.01424303650856018,0.0008962303982116282,-0.009365360252559185,0.008820024318993092,0.013941312208771706,-0.007390265353024006,0.015612092800438404,0.008377837017178535,-0.006962129846215248,0.01604386232793331,0.004204136785119772,0.0069089229218661785,-0.0185052789747715,-0.013314954936504364,0.007275469601154327,0.014722811058163643,0.008437100797891617,0.011726523749530315,0.016620544716715813,0.015615695156157017,0.0120353102684021,0.006396838463842869,-0.008448812179267406,-0.00602632574737072,0.010790380649268627,0.002144247991964221,-0.014843912795186043,0.013109751045703888,-0.0005983744049444795,-0.01191713660955429,-0.0060539147816598415,0.007560625206679106,0.018343864008784294,-0.02141418308019638,-0.0038201757706701756,-0.0008210405358113348,0.0037896588910371065,0.00903385877609253,0.02255813404917717,0.0149000883102417,0.010207773186266422,0.01298686396330595,0.01658656820654869,-0.009689725004136562,-0.000968685548286885,-0.0354095958173275,-0.0020211192313581705,0.0172839667648077,0.017595110461115837,-0.007312276400625706,-0.009096597321331501,-0.012832960113883018,0.006029736716300249,0.01993134617805481,-0.007445869967341423,-0.013995345681905746,-0.021392418071627617,0.013174227438867092,0.0006699688965454698,0.0026909918524324894,0.0032831323333084583,0.012930993922054768,0.0012651460710912943,0.000811227539088577,0.01763002574443817,-0.00523826340213418,0.016636181622743607,-0.011958190239965916,-0.00934743881225586,0.011710581369698048,-0.009352635592222214,0.001517037977464497,0.022132251411676407,-0.0027835392393171787,-0.021134112030267715,0.000661684141959995,0.0020901961252093315,0.008411427959799767,-0.02320259064435959,-0.023216569796204567,-0.02040291577577591,-0.0019324647728353739,-0.012253865599632263,-0.012067129835486412,-0.012556578032672405,-0.006384226027876139,0.008578809909522533,-0.0006862648879177868,0.018786733970046043,0.008309703320264816,-0.004579378291964531,0.008779493160545826,-0.012430795468389988,0.010612075217068195,0.006497509777545929,0.00468828622251749,0.020637301728129387,0.014828919433057308,0.008801830001175404,-0.0012163587380200624,0.011090272106230259,0.00605464493855834,-0.00599315483123064,0.003595965448766947,0.0026772695127874613,0.007111930754035711,-0.0021474009845405817,-0.15517501533031464,-0.007093977648764849,0.016207048669457436,-0.003689244855195284,0.02290702797472477,-0.024147450923919678,0.02058466523885727,-0.003728344105184078,0.0020039579831063747,0.0036031962372362614,-0.00701624620705843,0.001598936039954424,-0.015112241730093956,-0.026839423924684525,-0.0005213304539211094,0.04432762786746025,0.0021426393650472164,0.008228357881307602,0.0006260357331484556,-0.0051366910338401794,0.0046644131653010845,-0.0015309208538383245,0.007084615062922239,-0.010650690644979477,-0.01891385205090046,-0.017962105572223663,-0.019904641434550285,-0.003021359210833907,0.00939719658344984,0.014427713118493557,0.0003639488131739199,0.01590440608561039,-0.007913827896118164,-0.008794532157480717,-0.004160219803452492,-0.00011183575406903401,-0.023288607597351074,0.001976816216483712,0.022937526926398277,-0.009748597629368305,-0.014059019275009632,-0.022420817986130714,0.014181907288730145,0.0013818360166624188,0.0023023937828838825,-0.007540484424680471,0.01842080056667328,0.006028867792338133,-0.022552955895662308,-0.005644746124744415,-0.0043883309699594975,-0.004599744454026222,-0.008561484515666962,0.014006786048412323,-0.011542826890945435,-0.009602931328117847,-0.036284975707530975,0.0013754897518083453,0.012572064064443111,0.006309454329311848,-0.0002941721468232572,-0.004653667565435171,-0.013862421736121178,0.004336177371442318,0.010433993302285671,0.009525666013360023,-0.006532643456012011,-0.0015942708123475313,0.014698229730129242,0.013635436072945595,0.01483591366559267,0.004928945563733578,0.011660551652312279,0.00346562173217535,-0.009555619210004807,0.01836557686328888,0.011766644194722176,0.005703310016542673,-0.005696287844330072,0.008640498854219913,0.00856035016477108,-0.03719845414161682,0.016891704872250557,0.009445746429264545,-0.0034338664263486862,-0.005024726502597332,-0.016796855255961418,-0.008475210517644882,-0.017073003575205803,0.004128266125917435,0.016665266826748848,0.00954902358353138,0.010982382111251354,-0.008389675989747047,-0.012186558917164803,0.008364107459783554,0.017737936228513718,0.01394137553870678,0.013139929622411728,-0.008969285525381565,-0.01151264924556017,-0.007080208044499159,-0.02486042119562626,0.00451834499835968,0.01454064343124628,-0.0027549047954380512,-0.01847361959517002,0.012725340202450752,0.02681497111916542,0.0022874209098517895,0.0060871499590575695,-0.012228837236762047,-0.01910441741347313,-0.02300979010760784,0.004791234154254198,-0.00982105266302824,-0.007742567453533411,0.01883193850517273,0.0016032794956117868,-0.0007860033656470478,-0.00030844920547679067,0.0010288181947544217,-0.01645890437066555,0.014252045191824436,-0.01001357939094305,0.002469572238624096,-0.025139495730400085,-0.007612746674567461,-0.05701448768377304,0.008700916543602943,0.01902882568538189,-0.02189522795379162,0.015759384259581566,0.010229690931737423,-0.013251837342977524,-0.013460122980177402,-0.01524634100496769,0.0020383321680128574,0.014956198632717133,-0.007906491868197918,-0.013498730957508087,0.006993595976382494,0.003018873743712902,0.001712734461762011,0.03202492371201515,0.026156842708587646,0.008240841329097748,-0.017780285328626633,0.006188404746353626,-0.014345478266477585,0.0025132661685347557,0.011938242241740227,-0.00015267223352566361,0.0147481644526124,-0.00812479481101036,-0.0010659064864739776,-0.0005582457524724305,0.006272712256759405,-0.004541509784758091,0.0014816629700362682,-0.02871515043079853,0.0016121916705742478,-0.02394980750977993,0.0008420820813626051,-0.007255136035382748,-0.006515704095363617,-0.005095303524285555,-0.005030743312090635,-0.011658716946840286,0.028127659112215042,0.00975873228162527,0.021014409139752388,-0.0160182137042284,0.008259791880846024,-0.00808415561914444,-0.011482791975140572,-0.0018780268728733063,-0.0016436574514955282,0.01837550289928913,0.0003763035056181252,0.009928029961884022,-0.008596843108534813,-0.0039632199332118034,0.01536337286233902,0.0038513196632266045,0.01520631741732359,-0.012446328997612,0.01358643639832735,-0.01477467454969883,0.0018546526553109288,-0.013842265121638775,-0.0008109700866043568,0.015721803531050682,0.006470515858381987,-0.01047314889729023,-0.017738599330186844,-0.002085148822516203,-0.00151948316488415,0.000500236579682678,-0.011062928475439548,-0.012429083697497845,-0.008604375645518303,-0.0033165609929710627,0.0162813700735569,-0.00872577540576458,0.006237449590116739,0.0014139856211841106,0.00227738288231194,0.007259607780724764,-0.0024163410998880863,-0.000929530244320631,0.01526214275509119,0.0005013305344618857,0.012352321296930313,0.0024202982895076275,-0.004930940456688404,0.005372138228267431,0.013471262529492378,0.011361593380570412,0.020780909806489944,-0.016667872667312622,-0.01875338703393936,-0.0006402565049938858,-0.0038189534097909927,-0.0173107348382473,-0.0007631341577507555,-0.004413474816828966,0.006579649168998003,-0.0007289272034540772,-0.016239607706665993,0.007476409897208214,5.302224599290639e-05,-0.01624462567269802,-0.014696476981043816,-0.0008294378640130162,6.569868855876848e-05,-0.006026261951774359,-0.0035658427514135838,0.00035259153810329735,-0.003949449863284826,0.009364716708660126,-0.010776331648230553,0.002928385278210044,-0.009490063413977623,-0.01819232851266861,0.004032875876873732,-0.0032316383440047503,0.00964342150837183,-0.0010484643280506134,-0.016542362049221992,-0.013282490894198418,-0.02188814990222454,0.014662325382232666,0.003973450977355242,0.01259040366858244,0.003396448213607073,0.0023380222264677286,-0.01695997640490532,0.012070347554981709,0.007248966954648495,0.011380953714251518,-0.009349804371595383,0.005258500576019287,0.01802116073668003,0.00570098590105772,-0.011989140883088112,0.011402743868529797,0.010607988573610783,0.008799505420029163,-0.009475105442106724,0.008064079098403454,-0.012264966033399105,-0.006731090601533651,0.00045869231689721346,-0.014379839412868023,-0.007578159682452679,-0.019541822373867035,0.02880922518670559,-0.01217967364937067,-0.0017422698438167572,0.009241893887519836,0.011424331925809383,-0.0059761349111795425,-0.10590112954378128,0.01093854196369648,-0.019668808206915855,-0.008417797274887562,-0.012183469720184803,-0.015398330055177212,0.022412968799471855,-0.014847170561552048,0.012399098835885525,-0.011321166530251503,-0.020581383258104324,-0.012875880114734173,0.009312482550740242,-0.01491408422589302,0.010381936095654964,0.014163745567202568,-0.00536081288009882,0.0030865189619362354,-0.017042148858308792,0.009154188446700573,0.003824438899755478,0.004048094153404236,-0.005840908735990524,-0.004764570388942957,-0.0011096063535660505,-0.01651327684521675,0.004218435846269131,0.0076619721949100494,0.016768736764788628,-0.010754378512501717,-0.007011130917817354,-0.0018741177627816796,0.004677861928939819,-0.0013004607753828168,0.02279837615787983,0.015664083883166313,-0.003047492355108261,-0.006805235054343939,-0.023204054683446884,0.011979939416050911,-0.01936367340385914,0.020488401874899864,0.0002779807255137712,0.01603945530951023,0.011033518239855766,-0.0034474434796720743,0.003860779106616974,0.0030094629619270563,-0.0025448587257415056,0.016781283542513847,0.0010827252408489585,-0.02335255965590477,0.000616254925262183,-0.0035649340134114027,0.0007393514970317483,-0.008183765225112438,0.0014471083413809538,0.0038755787536501884,0.007099337410181761,-0.012667966075241566,0.006208354607224464,-0.011235825717449188,-0.005788819864392281,-0.013990281149744987,-0.005277065094560385,-0.019661838188767433,-0.011538130231201649,0.011401553638279438,0.0067108855582773685,0.001396434847265482,0.0769028514623642,-0.0029904483817517757,0.002209946746006608,0.009979894384741783,-0.0010606379946693778,-0.016086678951978683,0.007984510622918606,0.018508948385715485,0.0032983184792101383,-0.004930043593049049,0.013569834642112255,1.877335125755053e-05,0.0041457414627075195,-0.0065275197848677635,0.01902691088616848,0.0049742781557142735,-0.008188189007341862,-0.004906102083623409,-0.0191107876598835,0.016605230048298836,-0.017471250146627426,0.010408093221485615,-0.008595138788223267,0.00039457817911170423,0.0075583732686936855,0.01484600454568863,0.011490130797028542,0.0035124020650982857,-0.006972779054194689,0.0128085408359766,0.006472124718129635,-0.011789342388510704,0.006717384327203035,-0.0022378091234713793,0.00325773935765028,0.0053901877254247665,0.008246632292866707,0.0030436997767537832,0.0072782342322170734,0.0012802877463400364,-0.00802643597126007,0.004147414583712816,0.008670682087540627,0.004049904178828001,0.0038673868402838707,0.014705437235534191,0.0026979250833392143,0.001775945769622922,-0.01869085803627968,0.0037806022446602583,0.012721864506602287,0.015738211572170258,-0.008133381605148315,-0.007445990107953548,-0.006062779109925032,0.005171599797904491,-0.007623749785125256,-0.001971603836864233,-0.03202363848686218,0.0014124091248959303,0.00964097585529089,-0.0062558529898524284,0.12542743980884552,-0.023395422846078873,-0.02142343297600746,0.00010404972999822348,0.0040498957969248295,0.009305443614721298,-0.005175766069442034,-0.006316371727734804,0.01862599514424801,0.01787419244647026,0.03209351748228073,-0.013965249061584473,-0.01298594195395708,0.003942033741623163,0.007697572000324726,-0.0037004253827035427,0.001353675965219736,0.004194419831037521,0.038188375532627106,-0.006305979564785957,0.008670156821608543,-0.011301315389573574,0.022354990243911743,0.011309697292745113,-0.006025111768394709,-0.02238098718225956,-0.014605054631829262,0.009788730181753635,-0.02146783284842968,-0.026633543893694878,0.008195299655199051,5.627179052680731e-05,-0.006054638884961605,0.018990008160471916,0.0018300878582522273,-0.006439500488340855,0.0015690467553213239,-0.004935315810143948,-0.005042776465415955,-0.008323850110173225,0.01732305809855461,0.004760194569826126,0.009951967746019363,0.002688618842512369,-0.02490813285112381,0.013938416726887226,-0.008612480014562607,0.017687037587165833,0.0007003569626249373,0.003144141985103488,0.00028641021344810724,0.006280304864048958,0.01704099029302597,-0.031904399394989014,-0.01954682171344757,0.006692659109830856,-0.0029927969444543123,-0.019856123253703117,0.01037242915481329,0.007297733798623085,-0.00034432284883223474,9.271252201870084e-05,3.400759305804968e-05,-0.008098633028566837,-0.017516130581498146,0.0009811046766117215,-0.007083006668835878,-0.013434672728180885,0.006502609234303236,0.00046227165148593485,-0.006619544234126806,-0.011502401903271675,-0.01764489896595478,-0.018358498811721802,-0.016132373362779617,0.01945388875901699,-0.004716904833912849,0.016170112416148186,0.002639401238411665,-0.008305462077260017,-0.030113548040390015,0.014484983868896961,0.049616213887929916,0.0026693870313465595,0.015345823019742966,0.0026869860012084246,0.019824400544166565,0.00838514044880867,0.0023412152659147978,-0.0035702185705304146,-0.007228761445730925,0.009889356791973114,-0.01150357536971569,0.006204118020832539,-0.007316265255212784,0.005138332024216652,-0.004389585927128792,-0.006546832155436277,-0.004268612712621689,0.022032320499420166,-0.014779822900891304,0.011949374340474606,0.0014258417068049312,0.0048449402675032616,0.02138534002006054,-0.0369078628718853,-0.0007908937404863536,-0.009307898581027985,0.009610539302229881,0.010517065413296223,-0.005397812929004431,-0.0021158468443900347,-0.003497409401461482,-0.0037914770655333996,-0.019967637956142426,0.002439747331663966,-0.020455583930015564,-0.006008759140968323,-0.008751148357987404,-0.018866462633013725,0.008806422352790833,-0.0035796293523162603,-0.003078668611124158,-0.004720652941614389,-0.010492903180420399],\"index\":0}],\"model\":\"vicuna-7b-v1.5\",\"usage\":{\"prompt_tokens\":13,\"total_tokens\":13}}" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Try text completion with" + ], + "metadata": { + "id": "-U2SZWTghxzc" + } + }, + { + "cell_type": "code", + "source": [ + "!curl http://127.0.0.1:8000/v1/completions \\\n", + " -H \"Content-Type: application/json\" \\\n", + " -d '{ \\\n", + " \"model\": \"vicuna-7b-v1.5\", \\\n", + " \"prompt\": \"Once upon a time\", \\\n", + " \"max_tokens\": 20, \\\n", + " \"temperature\": 0.5 \\\n", + " }'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "85T5NO7Wh03R", + "outputId": "1a2c9568-2aa3-4a89-ecd8-8af496be1a41" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\"id\":\"cmpl-kB3gg4KtgcGdif9V4eNbh6\",\"object\":\"text_completion\",\"created\":1705782008,\"model\":\"vicuna-7b-v1.5\",\"choices\":[{\"index\":0,\"text\":\", there was a little girl named Alice. Alice lived in a small village nestled in a valley\",\"logprobs\":null,\"finish_reason\":\"length\"}],\"usage\":{\"prompt_tokens\":5,\"total_tokens\":24,\"completion_tokens\":19}}" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Try create_embeddings to analyze the prompts!" + ], + "metadata": { + "id": "EDxLbQDKVLiQ" + } + }, + { + "cell_type": "code", + "source": [ + "import json\n", + "import numpy as np\n", + "import requests\n", + "from scipy.spatial.distance import cosine\n", + "\n", + "\n", + "def get_embedding_from_api(word, model='vicuna-7b-v1.5'):\n", + " url = 'http://127.0.0.1:8000/v1/embeddings'\n", + " headers = {'Content-Type': 'application/json'}\n", + " data = json.dumps({\n", + " 'model': model,\n", + " 'input': word\n", + " })\n", + "\n", + " response = requests.post(url, headers=headers, data=data)\n", + " if response.status_code == 200:\n", + " embedding = np.array(response.json()['data'][0]['embedding'])\n", + " return embedding\n", + " else:\n", + " print(f\"Error: {response.status_code} - {response.text}\")\n", + " return None\n", + "\n", + "\n", + "def cosine_similarity(vec1, vec2):\n", + " return 1 - cosine(vec1, vec2)\n", + "\n", + "\n", + "def print_cosine_similarity(embeddings, texts):\n", + " for i in range(len(texts)):\n", + " for j in range(i + 1, len(texts)):\n", + " sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]])\n", + " print(f\"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}\")\n", + "\n", + "\n", + "texts = [\n", + " 'The quick brown fox',\n", + " 'The quick brown dog',\n", + " 'The fast brown fox',\n", + " 'A completely different sentence'\n", + "]\n", + "\n", + "embeddings = {}\n", + "for text in texts:\n", + " embeddings[text] = get_embedding_from_api(text)\n", + "\n", + "print_cosine_similarity(embeddings, texts)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bbrFoxgaplhK", + "outputId": "48e23158-1468-445d-a4cd-b5bd67bd3bde" + }, + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cosine similarity between 'The quick brown fox' and 'The quick brown dog': 0.90\n", + "Cosine similarity between 'The quick brown fox' and 'The fast brown fox': 0.86\n", + "Cosine similarity between 'The quick brown fox' and 'A completely different sentence': 0.58\n", + "Cosine similarity between 'The quick brown dog' and 'The fast brown fox': 0.84\n", + "Cosine similarity between 'The quick brown dog' and 'A completely different sentence': 0.66\n", + "Cosine similarity between 'The fast brown fox' and 'A completely different sentence': 0.62\n" + ] + } + ] + } + ] +} diff --git a/playground/test_embedding/test_sentence_similarity.py b/playground/test_embedding/test_sentence_similarity.py index 0b9a54081..d7a8f6e5f 100644 --- a/playground/test_embedding/test_sentence_similarity.py +++ b/playground/test_embedding/test_sentence_similarity.py @@ -7,7 +7,7 @@ from scipy.spatial.distance import cosine -def get_embedding_from_api(word, model="vicuna-7b-v1.1"): +def get_embedding_from_api(word, model="vicuna-7b-v1.5"): if "ada" in model: resp = openai.Embedding.create( model=model, diff --git a/pyproject.toml b/pyproject.toml index f54ab30de..3770b350d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "fschat" -version = "0.2.33" +version = "0.2.36" description = "An open platform for training, serving, and evaluating large language model based chatbots." readme = "README.md" requires-python = ">=3.8" @@ -13,15 +13,15 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", ] dependencies = [ - "accelerate>=0.21", "einops", "fastapi", "gradio", "httpx", "markdown2[all]", "mysqlclient", "nh3", "numpy", - "peft", "prompt_toolkit>=3.0.0", "pydantic<2,>=1", "redis", "requests", "rich>=10.0.0", "sentencepiece", + "accelerate>=0.21", "aiohttp", "einops", "fastapi", "gradio", "httpx", "markdown2[all]", "mysqlclient", "nh3", "numpy", + "peft", "prompt_toolkit>=3.0.0", "pydantic", "redis", "requests", "rich>=10.0.0", "sentencepiece", "shortuuid", "SQLAlchemy", "slowapi", "tiktoken", "tokenizers>=0.12.1", "torch", - "transformers>=4.31.0", "uvicorn", "wandb", + "transformers>=4.31.0", "uvicorn", "wandb" ] [project.optional-dependencies] model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"] -webui = ["gradio"] +webui = ["gradio>=4.10"] train = ["einops", "flash-attn>=2.0", "wandb"] llm_judge = ["openai<1", "anthropic>=0.3", "ray"] dev = ["black==23.3.0", "pylint==2.8.2"] diff --git a/scripts/build-api.sh b/scripts/build-api.sh new file mode 100644 index 000000000..8198108e0 --- /dev/null +++ b/scripts/build-api.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# A rather convenient script for spinning up models behind screens + + +# Variables +PROJECT_DIR="$(pwd)" +CONDA_ENV_NAME="fastchat" # + +MODEL_PATH="HuggingFaceH4/zephyr-7b-beta" #beta is better than the alpha version, base model w/o quantization +MODEL_PATH="lmsys/vicuna-7b-v1.5" + +API_HOST="0.0.0.0" +API_PORT_NUMBER=8000 + + +# init the screens +check_and_create_screen() { + local SCREENNAME="$1" + if screen -list | grep -q "$SCREENNAME"; then + echo "Screen session '$SCREENNAME' exists. Doing nothing." + else + echo "Screen session '$SCREENNAME' not found. Creating..." + screen -d -m -S "$SCREENNAME" + echo "created!" + fi +} + +# convenience function for sending commands to named screens +send_cmd() { + local SCREENNAME="$1" + local CMD="$2" + screen -DRRS $SCREENNAME -X stuff '$2 \r' +} + +# hardcoded names, for baby api +SCREENNAMES=( + "controller" + "api" + # Worker screens include the devices they are bound to, if 'd0' is only worker it has full GPU access + "worker-d0" + "worker-d1" +) + +for screen in "${SCREENNAMES[@]}"; do + check_and_create_screen "$screen" + sleep 0.1 + # also activate the conda compute environment for these + screen -DRRS "$screen" -X stuff "conda deactivate \r" + screen -DRRS "$screen" -X stuff "conda activate $CONDA_ENV_NAME \r" + +done + + +# Send Commmands on a per Screen Basis +screen -DRRS controller -X stuff "python3 -m fastchat.serve.controller \r" + +screen -DRRS worker-d0 -X stuff "CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --conv-template one_shot --limit-worker-concurrency 1 \r" +screen -DRRS worker-d1 -X stuff "CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --port 21003 --worker-address http://localhost:21003 --conv-template one_shot --limit-worker-concurrency 1 \r" + +screen -DRRS api -X stuff "python3 -m fastchat.serve.openai_api_server --host $API_HOST --port $API_PORT_NUMBER \r" diff --git a/tests/launch_openai_api_test_server.py b/tests/launch_openai_api_test_server.py index f555a3882..823e8734e 100644 --- a/tests/launch_openai_api_test_server.py +++ b/tests/launch_openai_api_test_server.py @@ -2,6 +2,7 @@ Launch an OpenAI API server with multiple model workers. """ import os +import argparse def launch_process(cmd): @@ -9,27 +10,44 @@ def launch_process(cmd): if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--multimodal", action="store_true", default=False) + args = parser.parse_args() + launch_process("python3 -m fastchat.serve.controller") launch_process("python3 -m fastchat.serve.openai_api_server") - models = [ - ("lmsys/vicuna-7b-v1.5", "model_worker"), - ("lmsys/fastchat-t5-3b-v1.0", "model_worker"), - ("THUDM/chatglm-6b", "model_worker"), - ("mosaicml/mpt-7b-chat", "model_worker"), - ("meta-llama/Llama-2-7b-chat-hf", "vllm_worker"), - ] + if args.multimodal: + models = [ + ("liuhaotian/llava-v1.5-7b", "sglang_worker"), + ] + else: + models = [ + ("lmsys/vicuna-7b-v1.5", "model_worker"), + ("lmsys/fastchat-t5-3b-v1.0", "model_worker"), + ("THUDM/chatglm-6b", "model_worker"), + ("mosaicml/mpt-7b-chat", "model_worker"), + ("meta-llama/Llama-2-7b-chat-hf", "vllm_worker"), + ] for i, (model_path, worker_name) in enumerate(models): cmd = ( f"CUDA_VISIBLE_DEVICES={i} python3 -m fastchat.serve.{worker_name} " - f"--model-path {model_path} --port {30000+i} " - f"--worker-address http://localhost:{30000+i} " + f"--model-path {model_path} --port {40000+i} " + f"--worker-address http://localhost:{40000+i} " ) if worker_name == "vllm_worker": cmd += "--tokenizer hf-internal-testing/llama-tokenizer" launch_process(cmd) + if "llava" in model_path.lower(): + cmd += f"--tokenizer-path llava-hf/llava-1.5-7b-hf" + + if worker_name == "vllm_worker": + cmd += "--tokenizer hf-internal-testing/llama-tokenizer" + + launch_process(cmd) + while True: pass diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index 064069833..4493dce2c 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -4,24 +4,25 @@ Launch: python3 launch_openai_api_test_server.py """ +import warnings import openai - from fastchat.utils import run_cmd + openai.api_key = "EMPTY" # Not support yet -openai.api_base = "http://localhost:8000/v1" +openai.base_url = "http://localhost:8000/v1/" def test_list_models(): - model_list = openai.Model.list() - names = [x["id"] for x in model_list["data"]] + model_list = openai.models.list() + names = [x.id for x in model_list.data] return names def test_completion(model, logprob): prompt = "Once upon a time" - completion = openai.Completion.create( + completion = openai.completions.create( model=model, prompt=prompt, logprobs=logprob, @@ -38,7 +39,7 @@ def test_completion(model, logprob): def test_completion_stream(model): prompt = "Once upon a time" - res = openai.Completion.create( + res = openai.completions.create( model=model, prompt=prompt, max_tokens=64, @@ -47,19 +48,19 @@ def test_completion_stream(model): ) print(prompt, end="") for chunk in res: - content = chunk["choices"][0]["text"] + content = chunk.choices[0].text print(content, end="", flush=True) print() def test_embedding(model): - embedding = openai.Embedding.create(model=model, input="Hello world!") - print(f"embedding len: {len(embedding['data'][0]['embedding'])}") - print(f"embedding value[:5]: {embedding['data'][0]['embedding'][:5]}") + embedding = openai.embeddings.create(model=model, input="Hello world!") + print(f"embedding len: {len(embedding.data[0].embedding)}") + print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}") def test_chat_completion(model): - completion = openai.ChatCompletion.create( + completion = openai.chat.completions.create( model=model, messages=[{"role": "user", "content": "Hello! What is your name?"}], temperature=0, @@ -69,11 +70,16 @@ def test_chat_completion(model): def test_chat_completion_stream(model): messages = [{"role": "user", "content": "Hello! What is your name?"}] - res = openai.ChatCompletion.create( + res = openai.chat.completions.create( model=model, messages=messages, stream=True, temperature=0 ) for chunk in res: - content = chunk["choices"][0]["delta"].get("content", "") + try: + content = chunk.choices[0].delta.content + if content is None: + content = "" + except Exception as e: + content = chunk.choices[0].delta.get("content", "") print(content, end="", flush=True) print() @@ -135,7 +141,7 @@ def test_openai_curl(): test_chat_completion_stream(model) try: test_embedding(model) - except openai.error.APIError as e: + except openai.APIError as e: print(f"Embedding error: {e}") print("===== Test curl =====") diff --git a/tests/test_openai_vision_api.py b/tests/test_openai_vision_api.py new file mode 100644 index 000000000..2f089c418 --- /dev/null +++ b/tests/test_openai_vision_api.py @@ -0,0 +1,162 @@ +""" +Test the OpenAI compatible server + +Launch: +python3 launch_openai_api_test_server.py --multimodal +""" + +import openai + +from fastchat.utils import run_cmd + +openai.api_key = "EMPTY" # Not support yet +openai.base_url = "http://localhost:8000/v1/" + + +def encode_image(image): + import base64 + from io import BytesIO + import requests + + from PIL import Image + + if image.startswith("http://") or image.startswith("https://"): + response = requests.get(image) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + image = Image.open(image).convert("RGB") + + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + + return img_b64_str + + +def test_list_models(): + model_list = openai.models.list() + names = [x.id for x in model_list.data] + return names + + +def test_chat_completion(model): + image_url = "https://picsum.photos/seed/picsum/1024/1024" + base64_image_url = f"data:image/jpeg;base64,{encode_image(image_url)}" + + # No Image + completion = openai.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about alpacas."}, + ], + } + ], + temperature=0, + ) + print(completion.choices[0].message.content) + print("=" * 25) + + # Image using url link + completion = openai.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + temperature=0, + ) + print(completion.choices[0].message.content) + print("=" * 25) + + # Image using base64 image url + completion = openai.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": base64_image_url}}, + ], + } + ], + temperature=0, + ) + print(completion.choices[0].message.content) + print("=" * 25) + + +def test_chat_completion_stream(model): + image_url = "https://picsum.photos/seed/picsum/1024/1024" + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ] + res = openai.chat.completions.create( + model=model, messages=messages, stream=True, temperature=0 + ) + for chunk in res: + try: + content = chunk.choices[0].delta.content + if content is None: + content = "" + except Exception as e: + content = chunk.choices[0].delta.get("content", "") + print(content, end="", flush=True) + print() + + +def test_openai_curl(): + run_cmd( + """curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llava-v1.5-7b", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What’s in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://picsum.photos/seed/picsum/1024/1024" + } + } + ] + } + ], + "max_tokens": 300 + }' + """ + ) + + print() + + +if __name__ == "__main__": + models = test_list_models() + print(f"models: {models}") + + for model in models: + print(f"===== Test {model} ======") + test_chat_completion(model) + test_chat_completion_stream(model) + test_openai_curl()