Skip to content

Commit

Permalink
Add sandbox code over fastchat
Browse files Browse the repository at this point in the history
  • Loading branch information
digitsisyph committed Nov 16, 2024
1 parent b2aabdb commit fe0e4bb
Show file tree
Hide file tree
Showing 4 changed files with 381 additions and 9 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ FastChat's core features include:
- [Fine-tuning](#fine-tuning)
- [Citation](#citation)

----

For Software Areana, please follow the following extra steps:
1. Set your E2B API Key: `export E2B_API_KEY=<YOUR_API_KEY>`
2. Custom Component Build: Follow https://www.gradio.app/guides/custom-components-in-five-minutes to set up environment. Go into `custom_components/sandboxcomponent` and run `gradio cc build`.
3. Use `pip install custom_components/sandboxcomponent/dist/gradio_sandboxcomponent-xxx-py3-none-any.whl` to install the custom components.

----

## Install

### Method 1: With pip
Expand Down
95 changes: 88 additions & 7 deletions fastchat/serve/gradio_block_arena_named.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time

import gradio as gr
from gradio_sandboxcomponent import SandboxComponent
import numpy as np

from fastchat.constants import (
Expand All @@ -30,6 +31,7 @@
get_model_description_md,
)
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.serve.sandbox.code_runner import DEFAULT_SANDBOX_INSTRUCTION, create_chatbot_sandbox_state, on_click_run_code, update_sandbox_config
from fastchat.utils import (
build_logger,
moderation_filter,
Expand Down Expand Up @@ -152,7 +154,10 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re


def add_text(
state0, state1, model_selector0, model_selector1, text, request: gr.Request
state0, state1,
model_selector0, model_selector1,
sandbox_state0, sandbox_state1,
text, request: gr.Request
):
ip = get_ip(request)
logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
Expand Down Expand Up @@ -204,6 +209,10 @@ def add_text(
* 6
)

# add snadbox instructions if enabled
if sandbox_state0['enable_sandbox']:
text = f"> {sandbox_state0['sandbox_instruction']}\n\n" + text

text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
states[i].conv.append_message(states[i].conv.roles[0], text)
Expand All @@ -227,6 +236,8 @@ def bot_response_multi(
temperature,
top_p,
max_new_tokens,
sandbox_state0,
sandbox_state1,
request: gr.Request,
):
logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
Expand All @@ -251,6 +262,7 @@ def bot_response_multi(
top_p,
max_new_tokens,
request,
sandbox_state=sandbox_state0,
)
)

Expand Down Expand Up @@ -327,7 +339,7 @@ def build_side_by_side_ui_named(models):

states = [gr.State() for _ in range(num_sides)]
model_selectors = [None] * num_sides
chatbots = [None] * num_sides
chatbots: list[gr.Chatbot | None] = [None] * num_sides

notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")

Expand Down Expand Up @@ -366,6 +378,38 @@ def build_side_by_side_ui_named(models):
],
)

# sandbox states and components
sandbox_states: list[gr.State | None] = [None for _ in range(num_sides)]
sandboxes_components: list[tuple[
gr.Markdown, # sandbox_output
SandboxComponent, # sandbox_ui
gr.Code, # sandbox_code
] | None] = [None for _ in range(num_sides)]

with gr.Group():
with gr.Row():
for chatbotIdx in range(num_sides):
with gr.Column(scale=1):
sandbox_state = gr.State(create_chatbot_sandbox_state())
# Add containers for the sandbox output
sandbox_title = gr.Markdown(value=f"### Model {chatbotIdx + 1} Sandbox", visible=True)
with gr.Tab(label="Output"):
sandbox_output = gr.Markdown(value="", visible=False)
sandbox_ui = SandboxComponent(
value=("", ""),
show_label=True,
visible=False,
)
with gr.Tab(label="Code"):
sandbox_code = gr.Code(value="", interactive=False, visible=False)

sandbox_states[chatbotIdx] = sandbox_state
sandboxes_components[chatbotIdx] = (
sandbox_output,
sandbox_ui,
sandbox_code,
)

with gr.Row():
leftvote_btn = gr.Button(
value="👈 A is better", visible=False, interactive=False
Expand All @@ -378,6 +422,30 @@ def build_side_by_side_ui_named(models):
value="👎 Both are bad", visible=False, interactive=False
)


# chatbox sandbox global config
with gr.Group():
with gr.Row():
enable_sandbox_checkbox = gr.Checkbox(value=False, label="Enable Sandbox", interactive=True)
sandbox_env_choice = gr.Dropdown(choices=["React", "Auto"], label="Sandbox Environment", interactive=True)
with gr.Group():
with gr.Accordion("Sandbox Instructions", open=False):
sandbox_instruction_textarea = gr.TextArea(
value=DEFAULT_SANDBOX_INSTRUCTION
)

# update sandbox global config
enable_sandbox_checkbox.change(
fn=update_sandbox_config,
inputs=[
enable_sandbox_checkbox,
sandbox_env_choice,
sandbox_instruction_textarea,
*sandbox_states
],
outputs=[*sandbox_states]
)

with gr.Row():
textbox = gr.Textbox(
show_label=False,
Expand Down Expand Up @@ -452,7 +520,7 @@ def build_side_by_side_ui_named(models):
regenerate, states, states + chatbots + [textbox] + btn_list
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
states + [temperature, top_p, max_output_tokens] + sandbox_states,
states + chatbots + btn_list,
).then(
flash_buttons, [], btn_list
Expand Down Expand Up @@ -488,25 +556,38 @@ def build_side_by_side_ui_named(models):

textbox.submit(
add_text,
states + model_selectors + [textbox],
states + model_selectors + sandbox_states + [textbox],
states + chatbots + [textbox] + btn_list,
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
states + [temperature, top_p, max_output_tokens] + sandbox_states,
states + chatbots + btn_list,
).then(
flash_buttons, [], btn_list
)
send_btn.click(
add_text,
states + model_selectors + [textbox],
states + model_selectors + sandbox_states + [textbox],
states + chatbots + [textbox] + btn_list,
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
states + [temperature, top_p, max_output_tokens] + sandbox_states,
states + chatbots + btn_list,
).then(
flash_buttons, [], btn_list
)

for chatbotIdx in range(num_sides):
chatbot = chatbots[chatbotIdx]
state = states[chatbotIdx]
sandbox_state = sandbox_states[chatbotIdx]
sandbox_components = sandboxes_components[chatbotIdx]

# trigger sandbox run
chatbot.select(
fn=on_click_run_code,
inputs=[state, sandbox_state, *sandbox_components],
outputs=[*sandbox_components],
)

return states + model_selectors
37 changes: 35 additions & 2 deletions fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import time
import uuid
from typing import List
from gradio_sandboxcomponent import SandboxComponent

import gradio as gr
import requests
Expand All @@ -29,13 +30,15 @@
SESSION_EXPIRATION_TIME,
SURVEY_LINK,
)
from fastchat.conversation import Conversation
from fastchat.model.model_adapter import (
get_conversation_template,
)
from fastchat.model.model_registry import get_model_info, model_info
from fastchat.serve.api_provider import get_api_provider_stream_iter
from fastchat.serve.gradio_global_state import Context
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.serve.sandbox.code_runner import RUN_CODE_BUTTON_HTML, ChatbotSandboxState
from fastchat.utils import (
build_logger,
get_window_url_params_js,
Expand Down Expand Up @@ -427,7 +430,11 @@ def bot_response(
request: gr.Request,
apply_rate_limit=True,
use_recommended_config=False,
sandbox_state: ChatbotSandboxState | None = None,
):
'''
The main function for generating responses from the model.
'''
ip = get_ip(request)
logger.info(f"bot_response. ip: {ip}")
start_tstamp = time.time()
Expand All @@ -450,7 +457,9 @@ def bot_response(
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return

conv, model_name = state.conv, state.model_name
conv: Conversation = state.conv
model_name: str = state.model_name

model_api_dict = (
api_endpoint_info[model_name] if model_name in api_endpoint_info else None
)
Expand Down Expand Up @@ -550,6 +559,14 @@ def bot_response(
return
output = data["text"].strip()
conv.update_last_message(output)

# Add a "Run in Sandbox" button to the last message if code is detected
if sandbox_state is not None and sandbox_state["enable_sandbox"]:
last_message = conv.messages[-1]
if "```" in last_message[1]:
if not last_message[1].endswith(RUN_CODE_BUTTON_HTML):
last_message[1] += "\n\n" + RUN_CODE_BUTTON_HTML

yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
except requests.exceptions.RequestException as e:
conv.update_last_message(
Expand Down Expand Up @@ -880,6 +897,17 @@ def build_single_model_ui(models, add_promotion_links=False):
{"left": r"\[", "right": r"\]", "display": True},
],
)

# Add containers for the sandbox output and JavaScript
# with gr.Column():
# sandbox_output = gr.Markdown(value="", visible=False)
# sandbox = SandboxComponent(
# label="Sandbox",
# value=("", ""),
# show_label=True,
# visible=False,
# )

with gr.Row():
textbox = gr.Textbox(
show_label=False,
Expand Down Expand Up @@ -969,10 +997,15 @@ def build_single_model_ui(models, add_promotion_links=False):
[state, chatbot] + btn_list,
)

# trigger sandbox run
# chatbot.select(on_click_run_code,
# inputs=[state, sandbox_output, sandbox],
# outputs=[sandbox_output, sandbox])

return [state, model_selector]


def build_demo(models):
def build_demo(models) -> gr.Blocks:
with gr.Blocks(
title="Chatbot Arena (formerly LMSYS): Free AI Chat to Compare & Test Best AI Chatbots",
theme=gr.themes.Default(),
Expand Down
Loading

0 comments on commit fe0e4bb

Please sign in to comment.