From eb24fc6a3f02cec7085f65e156571e21b5c82886 Mon Sep 17 00:00:00 2001
From: CodingWithTim
Date: Wed, 6 Nov 2024 20:58:41 +0000
Subject: [PATCH] add cerebras api

---
 fastchat/serve/monitor/classify/config.yaml | 6 +++---
 fastchat/serve/monitor/classify/label.py    | 5 ++++-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/fastchat/serve/monitor/classify/config.yaml b/fastchat/serve/monitor/classify/config.yaml
index 315f0dccc..21035d898 100644
--- a/fastchat/serve/monitor/classify/config.yaml
+++ b/fastchat/serve/monitor/classify/config.yaml
@@ -13,10 +13,10 @@ task_name:
   - creative_writing_v0.1
 
 model_name: null
-name: llama-3-70b-instruct
+name: llama-3.1-70b
 endpoints:
-  - api_base: null
-    api_key: null
+  - api_base: https://api.cerebras.ai/v1
+    api_key: CEREBRAS_API_KEY
 parallel: 50
 temperature: 0.0
 max_token: 512
diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py
index 2d0471a1f..76ab4b112 100644
--- a/fastchat/serve/monitor/classify/label.py
+++ b/fastchat/serve/monitor/classify/label.py
@@ -108,6 +108,7 @@ def get_answer(
 
     for category in categories:
         conv = category.pre_process(question["prompt"])
+        start_time = time.time()
         output = chat_completion_openai(
             model=model_name,
             messages=conv,
@@ -116,7 +117,9 @@ def get_answer(
             api_dict=api_dict,
         )
         # Dump answers
-        category_tag[category.name_tag] = category.post_process(output)
+        category_tag[category.name_tag] = category.post_process(
+            output
+        ) | {"completion_time": round(time.time() - start_time, 5)}
 
         if testing:
             output_log[category.name_tag] = output
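
For context on the label.py change above: each per-category classification call is now wall-clock timed, and the elapsed time is attached to that category's tags with the dict-union operator `|` (Python 3.9+). Below is a minimal standalone sketch of that pattern, using a hypothetical `fake_post_process` in place of `category.post_process` and a `time.sleep` in place of the real `chat_completion_openai` round trip.

```python
import time


def fake_post_process(output: str) -> dict:
    # Hypothetical stand-in for category.post_process(output); the real
    # method parses the judge model's response into category-specific tags.
    return {"score": 4}


start_time = time.time()
time.sleep(0.01)  # stands in for the chat_completion_openai(...) request
output = "...model response..."

# Same pattern as in the patch: merge the parsed tags with the elapsed
# time (rounded to 5 decimal places) via dict union.
category_tag_entry = fake_post_process(output) | {
    "completion_time": round(time.time() - start_time, 5)
}
print(category_tag_entry)  # e.g. {'score': 4, 'completion_time': 0.01012}
```

Because `|` on dicts requires Python 3.9+, running the labeling script on an older interpreter would need the equivalent `{**tags, "completion_time": ...}` spelling instead.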