diff --git a/README.md b/README.md
index e2465f46c..6510b8ab7 100644
--- a/README.md
+++ b/README.md
@@ -237,6 +237,33 @@ This is the user interface that users will interact with.
 By following these steps, you will be able to serve your models using the web UI. You can open your browser and chat with a model now.
 If the models do not show up, try to reboot the gradio web server.
 
+## Launch Chatbot Arena (side-by-side battle UI)
+
+Currently, Chatbot Arena is powered by FastChat. Here is how you can launch an instance of Chatbot Arena locally.
+
+FastChat supports popular API-based models such as OpenAI, Anthropic, Gemini, Mistral, and more. To add a custom API, please refer to the model support [doc](./docs/model_support.md). Below we take OpenAI models as an example.
+
+Create a JSON configuration file `api_endpoint.json` with the API endpoints of the models you want to serve, for example:
+```
+{
+  "gpt-4o-2024-05-13": {
+    "model_name": "gpt-4o-2024-05-13",
+    "api_base": "https://api.openai.com/v1",
+    "api_type": "openai",
+    "api_key": [Insert API Key],
+    "anony_only": false
+  }
+}
+```
+For Anthropic models, specify `"api_type": "anthropic_message"` with your Anthropic key. Similarly, for Gemini models, specify `"api_type": "gemini"`. More details can be found in [api_provider.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/api_provider.py).
+
+To serve your own model using local GPUs, follow the instructions in [Serving with Web GUI](#serving-with-web-gui).
+
+Now you're ready to launch the server:
+```
+python3 -m fastchat.serve.gradio_web_server_multi --register-api-endpoint-file api_endpoint.json
+```
+
 #### (Optional): Advanced Features, Scalability, Third Party UI
 - You can register multiple model workers to a single controller, which can be used for serving a single model with higher throughput or serving multiple models at the same time. When doing so, please allocate different GPUs and ports for different model workers.
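The same schema covers non-OpenAI providers; only `api_type` (and, where a provider needs one, `api_base`) changes. Below is a minimal sketch of an Anthropic entry: the model name and key are placeholders rather than values taken from this patch, and the exact fields each provider reads are defined in api_provider.py.
```
{
  "claude-example-model": {
    "model_name": "claude-example-model",
    "api_type": "anthropic_message",
    "api_key": [Insert Anthropic API Key],
    "anony_only": false
  }
}
```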
diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py
index 80ee2aaca..625c69c44 100644
--- a/fastchat/serve/gradio_block_arena_anony.py
+++ b/fastchat/serve/gradio_block_arena_anony.py
@@ -480,6 +480,12 @@ def build_side_by_side_ui_anony(models):
                         elem_id="chatbot",
                         height=650,
                         show_copy_button=True,
+                        latex_delimiters=[
+                            {"left": "$", "right": "$", "display": False},
+                            {"left": "$$", "right": "$$", "display": True},
+                            {"left": r"\(", "right": r"\)", "display": False},
+                            {"left": r"\[", "right": r"\]", "display": True},
+                        ],
                     )
 
     with gr.Row():
diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py
index 38fa2e9a5..2f7b39adb 100644
--- a/fastchat/serve/gradio_block_arena_named.py
+++ b/fastchat/serve/gradio_block_arena_named.py
@@ -358,6 +358,12 @@ def build_side_by_side_ui_named(models):
                         elem_id=f"chatbot",
                         height=650,
                         show_copy_button=True,
+                        latex_delimiters=[
+                            {"left": "$", "right": "$", "display": False},
+                            {"left": "$$", "right": "$$", "display": True},
+                            {"left": r"\(", "right": r"\)", "display": False},
+                            {"left": r"\[", "right": r"\]", "display": True},
+                        ],
                     )
 
     with gr.Row():
diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py
index 07f2d3a5b..b3d812220 100644
--- a/fastchat/serve/gradio_block_arena_vision.py
+++ b/fastchat/serve/gradio_block_arena_vision.py
@@ -356,6 +356,12 @@ def build_single_vision_language_model_ui(
         label="Scroll down and start chatting",
         height=650,
         show_copy_button=True,
+        latex_delimiters=[
+            {"left": "$", "right": "$", "display": False},
+            {"left": "$$", "right": "$$", "display": True},
+            {"left": r"\(", "right": r"\)", "display": False},
+            {"left": r"\[", "right": r"\]", "display": True},
+        ],
     )
 
     with gr.Row():
diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py
index 2dade176c..d4d4d484e 100644
--- a/fastchat/serve/gradio_block_arena_vision_anony.py
+++ b/fastchat/serve/gradio_block_arena_vision_anony.py
@@ -432,6 +432,12 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
                         elem_id="chatbot",
                         height=650,
                         show_copy_button=True,
+                        latex_delimiters=[
+                            {"left": "$", "right": "$", "display": False},
+                            {"left": "$$", "right": "$$", "display": True},
+                            {"left": r"\(", "right": r"\)", "display": False},
+                            {"left": r"\[", "right": r"\]", "display": True},
+                        ],
                     )
 
     with gr.Row():
diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py
index 3048ac935..7c653acf3 100644
--- a/fastchat/serve/gradio_block_arena_vision_named.py
+++ b/fastchat/serve/gradio_block_arena_vision_named.py
@@ -372,6 +372,12 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
                         elem_id=f"chatbot",
                         height=650,
                         show_copy_button=True,
+                        latex_delimiters=[
+                            {"left": "$", "right": "$", "display": False},
+                            {"left": "$$", "right": "$$", "display": True},
+                            {"left": r"\(", "right": r"\)", "display": False},
+                            {"left": r"\[", "right": r"\]", "display": True},
+                        ],
                     )
 
     with gr.Row():
diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
index 98399e575..4f0521da0 100644
--- a/fastchat/serve/gradio_web_server.py
+++ b/fastchat/serve/gradio_web_server.py
@@ -873,6 +873,12 @@ def build_single_model_ui(models, add_promotion_links=False):
         label="Scroll down and start chatting",
         height=650,
         show_copy_button=True,
+        latex_delimiters=[
+            {"left": "$", "right": "$", "display": False},
+            {"left": "$$", "right": "$$", "display": True},
+            {"left": r"\(", "right": r"\)", "display": False},
+            {"left": r"\[", "right": r"\]", "display": True},
+        ],
     )
     with gr.Row():
         textbox = gr.Textbox(
diff --git a/fastchat/serve/monitor/clean_chat_data.py b/fastchat/serve/monitor/clean_chat_data.py
index 2bda0e2c3..ec6da4a65 100644
--- a/fastchat/serve/monitor/clean_chat_data.py
+++ b/fastchat/serve/monitor/clean_chat_data.py
@@ -5,13 +5,16 @@
 python3 clean_chat_data.py
 """
 import argparse
-import datetime
 import json
 import os
+import hashlib
 from pytz import timezone
-import time
-
+from functools import partial
+from math import ceil
+from datetime import datetime, timedelta
 from tqdm import tqdm
+import time
+import multiprocessing as mp
 
 from fastchat.serve.monitor.basic_stats import NUM_SERVERS
 from fastchat.serve.monitor.clean_battle_data import (
@@ -26,12 +29,20 @@
 )
 
 
-def get_log_files(max_num_files=None):
-    dates = []
-    for month in range(4, 12):
-        for day in range(1, 33):
-            dates.append(f"2023-{month:02d}-{day:02d}")
+def date_range(start="2023-04-01"):
+    start_date = datetime.strptime(start, "%Y-%m-%d").date()
+    end_date = datetime.now().date()
+    delta = end_date - start_date
+    dates = [
+        (start_date + timedelta(days=d)).strftime("%Y-%m-%d")
+        for d in range(delta.days + 2)
+    ]
+    return dates
+
+
+def get_log_files(max_num_files=None):
+    dates = date_range()
 
     filenames = []
     for d in dates:
         for i in range(NUM_SERVERS):
@@ -44,90 +55,141 @@ def get_log_files(max_num_files=None):
     return filenames
 
 
-def clean_chat_data(log_files, action_type):
+def get_action_type_data(filename, action_type):
+    for _ in range(5):
+        try:
+            lines = open(filename).readlines()
+            break
+        except FileNotFoundError:
+            time.sleep(2)
+
+    rows = []
+    for l in lines:
+        row = json.loads(l)
+        if row["type"] == action_type:
+            rows.append(row)
+    return rows
+
+
+def process_data(row, action_type):
+    try:
+        if action_type in ["chat", "upvote", "downvote"]:
+            state = row["state"]
+            model = row["model"]
+        elif action_type == "leftvote":
+            state = row["states"][0]
+            model = row["states"][0]["model_name"]
+        elif action_type == "rightvote":
+            state = row["states"][1]
+            model = row["states"][1]["model_name"]
+        conversation_id = state["conv_id"]
+    except KeyError:
+        return {
+            "ct_invalid_conv_id": 1,
+        }
+
+    if conversation_id is None:
+        return {
+            "ct_invalid_conv_id": 1,
+        }
+
+    conversation = to_openai_format(state["messages"][state["offset"] :])
+    if not isinstance(model, str):
+        return {
+            "ct_invalid": 1,
+        }
+    model = replace_model_name(model, row["tstamp"])
+
+    try:
+        lang_code = detect_language(state["messages"][state["offset"]][1])
+    except IndexError:
+        return {
+            "ct_invalid": 1,
+        }
+
+    if not all(isinstance(x["content"], str) for x in conversation):
+        return {
+            "ct_invalid": 1,
+        }
+
+    messages = "".join([x["content"] for x in conversation]).lower()
+    if NETWORK_ERROR_MSG in messages:
+        return {
+            "ct_network_error": 1,
+        }
+    user_id = hashlib.md5(row["ip"].encode()).hexdigest()
+
+    # Prepare the result data
+    result = dict(
+        conversation_id=conversation_id,
+        model=model,
+        conversation=conversation,
+        turn=len(conversation) // 2,
+        language=lang_code,
+        user_id=user_id,
+        tstamp=row["tstamp"],
+    )
+
+    return {
+        "result": result,
+        "model": model,
+    }
+
+
+def clean_chat_data(log_files, action_type, num_parallel):
+    with mp.Pool(num_parallel) as pool:
+        # Use partial to pass action_type to get_action_type_data
+        func = partial(get_action_type_data, action_type=action_type)
+        file_data = list(
+            tqdm(
+                pool.imap(
+                    func, log_files, chunksize=ceil(len(log_files) / len(pool._pool))
+                ),
+                total=len(log_files),
+                desc="Processing Log Files",
+            )
+        )
+    # filter out Nones as some files may not contain any data belong to action_type
     raw_data = []
-    for filename in tqdm(log_files, desc="read files"):
-        for retry in range(5):
-            try:
-                lines = open(filename).readlines()
-                break
-            except FileNotFoundError:
-                time.sleep(2)
-
-        for l in lines:
-            row = json.loads(l)
-            if row["type"] == action_type:
-                raw_data.append(row)
+    for data in file_data:
+        raw_data.extend(data)
+    raw_data = [r for r in raw_data if not (r is None)]
+
+    # Use the multiprocessing Pool
+    with mp.Pool(num_parallel) as pool:
+        func = partial(process_data, action_type=action_type)
+        results = list(
+            tqdm(
+                pool.imap(
+                    func, raw_data, chunksize=ceil(len(raw_data) / len(pool._pool))
+                ),
+                total=len(raw_data),
+                desc="Processing Raw Data",
+            )
+        )
 
-    all_models = set()
-    all_ips = dict()
-    chats = []
+    # Aggregate results from child processes
     ct_invalid_conv_id = 0
     ct_invalid = 0
     ct_network_error = 0
-    for row in raw_data:
-        try:
-            if action_type in ["chat", "upvote", "downvote"]:
-                state = row["state"]
-                model = row["model"]
-            elif action_type == "leftvote":
-                state = row["states"][0]
-                model = row["states"][0]["model_name"]
-            elif action_type == "rightvote":
-                state = row["states"][1]
-                model = row["states"][1]["model_name"]
-            conversation_id = state["conv_id"]
-        except KeyError:
-            ct_invalid_conv_id += 1
-            continue
-
-        if conversation_id is None:
-            ct_invalid_conv_id += 1
-            continue
-
-        conversation = to_openai_format(state["messages"][state["offset"] :])
-        if not isinstance(model, str):
-            ct_invalid += 1
-            continue
-        model = replace_model_name(model, row["tstamp"])
-
-        try:
-            lang_code = detect_language(state["messages"][state["offset"]][1])
-        except IndexError:
-            ct_invalid += 1
+    all_models = set()
+    chats = []
+    for data in tqdm(results):
+        if "ct_invalid_conv_id" in data:
+            ct_invalid_conv_id += data["ct_invalid_conv_id"]
             continue
-
-        if not all(isinstance(x["content"], str) for x in conversation):
-            ct_invalid += 1
+        if "ct_invalid" in data:
+            ct_invalid += data["ct_invalid"]
             continue
-
-        messages = "".join([x["content"] for x in conversation]).lower()
-        if NETWORK_ERROR_MSG in messages:
-            ct_network_error += 1
+        if "ct_network_error" in data:
+            ct_network_error += data["ct_network_error"]
             continue
-
-        ip = row["ip"]
-        if ip not in all_ips:
-            all_ips[ip] = len(all_ips)
-        user_id = all_ips[ip]
-
-        chats.append(
-            dict(
-                conversation_id=conversation_id,
-                model=model,
-                conversation=conversation,
-                turn=len(conversation) // 2,
-                language=lang_code,
-                user_id=user_id,
-                tstamp=row["tstamp"],
-            )
-        )
-
-        all_models.update([model])
+        all_models.update([data["model"]])
+        chats.append(data["result"])
 
     chats.sort(key=lambda x: x["tstamp"])
     last_updated_tstamp = chats[-1]["tstamp"]
-    last_updated_datetime = datetime.datetime.fromtimestamp(
+    last_updated_datetime = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y-%m-%d %H:%M:%S %Z")
 
@@ -156,12 +218,13 @@ def clean_chat_data(log_files, action_type):
     parser = argparse.ArgumentParser()
     parser.add_argument("--action-type", type=str, default="chat")
    parser.add_argument("--max-num-files", type=int)
+    parser.add_argument("--num-parallel", type=int, default=16)
     args = parser.parse_args()
 
     log_files = get_log_files(args.max_num_files)
-    chats = clean_chat_data(log_files, args.action_type)
+    chats = clean_chat_data(log_files, args.action_type, args.num_parallel)
 
     last_updated_tstamp = chats[-1]["tstamp"]
-    cutoff_date = datetime.datetime.fromtimestamp(
+    cutoff_date = datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
     ).strftime("%Y%m%d")
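The refactor above splits log reading (`get_action_type_data`) and per-record cleaning (`process_data`) out of `clean_chat_data` and runs each step in a `multiprocessing.Pool`, with `--num-parallel` controlling the number of worker processes (default 16); it also derives `user_id` from an MD5 hash of the request IP rather than a per-run counter. A usage sketch, assuming the log files are laid out where `get_log_files` expects them:
```
python3 clean_chat_data.py --action-type chat --num-parallel 16
```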