Skip to content

Commit

Permalink
Fix arena (#2522)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Oct 6, 2023
1 parent c3ad73a commit 5573aae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
19 changes: 11 additions & 8 deletions fastchat/serve/gradio_block_arena_anony.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,22 +161,24 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
"gpt-3.5-turbo": 2,
"claude-2": 2,
"claude-instant-1": 2,
"deluxe-chat-v1": 4,
    # tier 1
"palm-2": 1.5,
"llama-2-70b-chat": 1.5,
"llama-2-13b-chat": 1.5,
"codellama-34b-instruct": 1.5,
"vicuna-33b": 1.5,
"vicuna-13b": 1.5,
"mpt-30b-chat": 1.5,
"wizardlm-70b": 1.5,
"wizardlm-13b": 1.5,
# tier 2
"codellama-13b-instruct": 1.0,
"vicuna-7b": 1.0,
"llama-2-7b-chat": 1.0,
"chatglm2-6b": 1.0,
"mistral-7b-instruct": 1.0,
# deprecated
"codellama-13b-instruct": 1.0,
"mpt-30b-chat": 1.5,
"guanaco-33b": 1.0,
"fastchat-t5-3b": 0.5,
"alpaca-13b": 0.5,
Expand All @@ -193,9 +195,6 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re

SAMPLING_BOOST_MODELS = []

model_pairs = []
model_pairs_weights = []


def add_text(
state0, state1, model_selector0, model_selector1, text, request: gr.Request
Expand All @@ -208,7 +207,8 @@ def add_text(
# Init states if necessary
if states[0] is None:
assert states[1] is None
global model_pairs, model_pairs_weights
model_pairs = []
model_pairs_weights = []

# Pick two models
if len(model_pairs) == 0:
Expand All @@ -226,9 +226,12 @@ def add_text(

model_pairs_weights = model_pairs_weights / np.sum(model_pairs_weights)
# for p, w in zip(model_pairs, model_pairs_weights):
# print(p, w)
# print(p, w)

if len(model_pairs) >= 1:
# if len(model_pairs) != len(model_pairs_weights):
# print("model pairs", model_pairs, model_pairs_weights)
# print("#model pairs", len(model_pairs), len(model_pairs_weights))
idx = np.random.choice(len(model_pairs), p=model_pairs_weights)
model_left, model_right = model_pairs[idx]
else:
Expand Down Expand Up @@ -326,7 +329,7 @@ def bot_response_multi(
):
logger.info(f"bot_response_multi (anony). ip: {request.client.host}")

if state0.skip_next:
if state0 is None or state0.skip_next:
# This generate call is skipped due to invalid inputs
yield (
state0,
Expand Down
2 changes: 1 addition & 1 deletion fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
try:
for i, data in enumerate(stream_iter):
if data["error_code"] == 0:
if i % 5 != 0: # reduce gradio's overhead
if i % 8 != 0: # reduce gradio's overhead
continue
output = data["text"].strip()
conv.update_last_message(output + "▌")
Expand Down

0 comments on commit 5573aae

Please sign in to comment.