Skip to content

Commit

Permalink
fix: minimax streaming function_call message (#4271)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weaxs authored May 11, 2024
1 parent a80fe20 commit 8cc4927
Showing 1 changed file with 31 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
Minimax Chat Completion Pro API, supports function calling
however, we do not have enough time and energy to implement it, but the parameters are reserved
"""
def generate(self, model: str, api_key: str, group_id: str,
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: list[MinimaxMessage], model_parameters: dict,
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
"""
if not api_key or not group_id:
raise InvalidAPIKeyError('Invalid API key or group ID')

url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'

extra_kwargs = {}
Expand All @@ -42,7 +42,7 @@ def generate(self, model: str, api_key: str, group_id: str,

if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
extra_kwargs['top_p'] = model_parameters['top_p']

if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
extra_kwargs['plugins'] = [
'plugin_web_search'
Expand All @@ -61,7 +61,7 @@ def generate(self, model: str, api_key: str, group_id: str,
# check if there is a system message
if len(prompt_messages) == 0:
raise BadRequestError('At least one message is required')

if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
if prompt_messages[0].content:
bot_setting['content'] = prompt_messages[0].content
Expand All @@ -70,7 +70,7 @@ def generate(self, model: str, api_key: str, group_id: str,
# check if there is a user message
if len(prompt_messages) == 0:
raise BadRequestError('At least one user message is required')

messages = [message.to_dict() for message in prompt_messages]

headers = {
Expand All @@ -89,21 +89,21 @@ def generate(self, model: str, api_key: str, group_id: str,

if tools:
body['functions'] = tools
body['function_call'] = { 'type': 'auto' }
body['function_call'] = {'type': 'auto'}

try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
except Exception as e:
raise InternalServerError(e)

if response.status_code != 200:
raise InternalServerError(response.text)

if stream:
return self._handle_stream_chat_generate_response(response)
return self._handle_chat_generate_response(response)

def _handle_error(self, code: int, msg: str):
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
raise InternalServerError(msg)
Expand All @@ -127,7 +127,7 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
code = response['base_resp']['status_code']
msg = response['base_resp']['status_msg']
self._handle_error(code, msg)

message = MinimaxMessage(
content=response['reply'],
role=MinimaxMessage.Role.ASSISTANT.value
Expand All @@ -144,7 +144,6 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
"""
handle stream chat generate response
"""
function_call_storage = None
for line in response.iter_lines():
if not line:
continue
Expand All @@ -158,54 +157,41 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
msg = data['base_resp']['status_msg']
self._handle_error(code, msg)

# final chunk
if data['reply'] or 'usage' in data and data['usage']:
total_tokens = data['usage']['total_tokens']
message = MinimaxMessage(
minimax_message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.usage = {
minimax_message.usage = {
'prompt_tokens': 0,
'completion_tokens': total_tokens,
'total_tokens': total_tokens
}
message.stop_reason = data['choices'][0]['finish_reason']

if function_call_storage:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = function_call_storage
yield function_call_message

yield message
minimax_message.stop_reason = data['choices'][0]['finish_reason']

choices = data.get('choices', [])
if len(choices) > 0:
for choice in choices:
message = choice['messages'][0]
# append function_call message
if 'function_call' in message:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = message['function_call']
yield function_call_message

yield minimax_message
return

# partial chunk
choices = data.get('choices', [])
if len(choices) == 0:
continue

for choice in choices:
message = choice['messages'][0]

if 'function_call' in message:
if not function_call_storage:
function_call_storage = message['function_call']
if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
function_call_storage['arguments'] = ''
continue
else:
function_call_storage['arguments'] += message['function_call']['arguments']
continue
else:
if function_call_storage:
message['function_call'] = function_call_storage
function_call_storage = None

minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)

if 'function_call' in message:
minimax_message.function_call = message['function_call']

# append text message
if 'text' in message:
minimax_message.content = message['text']

yield minimax_message
minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
yield minimax_message

0 comments on commit 8cc4927

Please sign in to comment.