From 6a40877c262b1861663b9dac0075d373fbf2d7eb Mon Sep 17 00:00:00 2001 From: Sophie Xie Date: Thu, 12 Dec 2024 23:44:50 -0800 Subject: [PATCH] add price plots --- fastchat/serve/gradio_web_server_multi.py | 29 +++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index ec3f803eb..3d3ffc09c 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -7,6 +7,7 @@ import pickle import time from typing import List +import plotly.express as px import gradio as gr @@ -91,10 +92,34 @@ def build_visualizer(): with gr.Tab("Price Analysis", id=1): price_markdown = """ - ## *Price Control Data Visualizations* - Coming soon: Visualizations showing models' arena scores compared to their cost-effectiveness and output token prices. + ## *Price Analysis Visualizations* + Below is a scatterplot depicting a model’s arena score against its cost effectiveness. Start exploring and discover some interesting trends in the data! """ gr.Markdown(price_markdown) + model_keys = ['chatgpt-4o-latest', 'gemini-1.5-pro-exp-0827','gpt-4o-mini-2024-07-18','claude-3-5-sonnet-20240620','gemini-1.5-flash-exp-0827','llama-3.1-405b-instruct','gemini-1.5-pro-api-0514','mistral-large-2407','reka-core-20240722','gemini-1.5-flash-api-0514', 'deepseek-coder-v2-0724','yi-large','llama-3-70b-instruct','qwen2-72b-instruct','claude-3-haiku-20240307','llama-3.1-8b-instruct','mistral-large-2402','command-r','mixtral-8x22b-instruct-v0.1','gpt-3.5-turbo-0613'] + output_tokens_per_USD = [66.66666667000001,200.0,1666.666667,66.66666667000001,3333.333333,333.3333333,200.0,166.6666667,166.6666667,3333.333333,3333.333333,333.3333333,1265.8227849999998,1111.111111,800.0,11111.11111,166.6666667,666.6666667,166.6666667,500.0] + score=[1316.1559008799543,1300.8583398843484,1273.6004783067303,1270.113546648134,1270.530573909608,1266.244657076764,1259.2844314017723,1249.8268751367714,1229.2148108171098,1226.8769924152105,1214.5634252743123,1212.4668382698005,1206.3236747009742,1186.7832147344182,1178.5484948812955,1167.8793593807711,1157.271872307139,1148.6665817312062,1147.0325504217642,1117.0289441863001] + fig = px.scatter(x=output_tokens_per_USD, y=score, title="Quality vs. Cost Effectiveness", labels={ + "output_tokens_per_USD": "# of output tokens per USD (in thousands)", + "score": "Arena Score"}, log_x=True, text=model_keys) + fig.update_traces( + textposition="bottom center", + textfont=dict(size=16), + texttemplate='%{text}', + marker=dict(size=8), + hovertemplate=( + 'Model: %{text}
' # Show the model name + 'Output Tokens Per USD: %{x}
' # Show the x value (Output Price) + 'Arena Score: %{y}
' # Show the y value (Arena Score) + ) + ) + fig.update_xaxes(range=[1,4.5]) + fig.update_yaxes(range=[1100,1320]) + fig.update_layout(autosize=True, height=850, width=None, xaxis_title="# of output tokens per USD (in thousands)", yaxis_title= "Arena Score") + + + gr.Plot(fig, elem_id="plotly-graph") + def load_demo(context: Context, request: gr.Request):