From 3060aad3ff01d9caf3945c42032041d4cf18649f Mon Sep 17 00:00:00 2001 From: Roberto Montoya Date: Fri, 25 Oct 2024 15:18:21 -0500 Subject: [PATCH] CD-153 - ran format.sh --- fastchat/conversation.py | 2 +- fastchat/model/model_adapter.py | 1 + fastchat/serve/api_provider.py | 15 +++++++++------ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 4cb784643..5cf9c7e89 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -622,7 +622,7 @@ def to_jab_api_messages(self): if msg is not None: ret.append({"role": "assistant", "content": msg}) return ret - + def save_new_images(self, has_csam_images=False, use_remote_storage=False): import hashlib from fastchat.constants import LOGDIR diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index 916997ab1..29763c9a1 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -2509,6 +2509,7 @@ def match(self, model_path: str): def get_default_conv_template(self, model_path: str) -> Conversation: return get_conv_template("api_based_default") + # Note: the registration order matters. # The one registered earlier has a higher matching priority. register_model_adapter(PeftModelAdapter) diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index d165e54ea..9cdff1f08 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -252,7 +252,7 @@ def get_api_provider_stream_iter( messages=messages, api_base=model_api_dict["api_base"], api_key=model_api_dict["api_key"], - ) + ) else: raise NotImplementedError() @@ -1270,20 +1270,21 @@ def metagen_api_stream_iter( "error_code": 1, } + def jab_api_stream_iter( model_name, messages, api_base, - api_key, + api_key, ): import requests - headers = {'Content-Type': 'application/json', 'x-api-key': api_key} + headers = {"Content-Type": "application/json", "x-api-key": api_key} text_messages = [] for message in messages: text_messages.append(message) - + payload = { "model": model_name, "messages": text_messages, @@ -1295,7 +1296,9 @@ def jab_api_stream_iter( response = requests.post(api_base, json=payload, headers=headers) if response.status_code != 200: - logger.error(f"Unexpected response ({response.status_code}): {response.text}") + logger.error( + f"Unexpected response ({response.status_code}): {response.text}" + ) yield { "text": f"**API REQUEST FAILED** Reason: {response.status_code}.", "error_code": 1, @@ -1317,4 +1320,4 @@ def jab_api_stream_iter( yield { "text": f"**API REQUEST ERROR** Reason: Unknown.", "error_code": 1, - } \ No newline at end of file + }