Skip to content

Commit

Permalink
chore: change Yi model SDK to OpenAI (#2910)
Browse files Browse the repository at this point in the history
  • Loading branch information
soulteary authored Mar 20, 2024
1 parent 180775a commit 5a1c29f
Showing 1 changed file with 93 additions and 4 deletions.
97 changes: 93 additions & 4 deletions api/core/model_runtime/model_providers/yi/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,119 @@
from collections.abc import Generator
from typing import Optional, Union
from urllib.parse import urlparse

import tiktoken

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel


class YiLargeLanguageModel(OpenAILargeLanguageModel):
    """LLM runtime for Yi models, built on top of the OpenAI model runtime.

    Yi credentials (``api_key``, optional ``endpoint_url``) are mapped onto the
    ``openai_*`` credential keys before delegating to the parent class, and
    token counting is overridden to always use the ``cl100k_base`` encoding.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke the model after adapting the credentials for the parent runtime.

        :param model: model name
        :param credentials: Yi credentials; updated in place with the ``openai_*`` keys
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: whether to stream the response
        :param user: user identifier, forwarded unchanged to the parent implementation
        :return: full response or a generator of response chunks
        """
        self._add_custom_parameters(credentials)

        # yi-vl-plus does not support system prompts yet, so they are dropped.
        # A new list is built; the caller's list is left untouched.
        if model == "yi-vl-plus":
            prompt_messages = [
                message for message in prompt_messages
                if not isinstance(message, SystemPromptMessage)
            ]

        return super()._invoke(model, credentials, prompt_messages, model_parameters,
                               tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate the credentials through the parent runtime.

        :param model: model name
        :param credentials: Yi credentials; updated in place with the ``openai_*`` keys
        """
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # refactored from the openai model runtime; always uses cl100k_base to count tokens
    def _num_tokens_from_string(self, model: str, text: str,
                                tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for text completion model with tiktoken package.

        :param model: model name (not used; the encoding is always ``cl100k_base``)
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # refactored from the openai model runtime; always uses cl100k_base to count tokens
    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for chat messages with tiktoken package.

        Follows the message-counting scheme for gpt-3.5-turbo and gpt-4 described in
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb

        :param model: model name (not used; the encoding is always ``cl100k_base``)
        :param messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                if key == "tool_calls":
                    # Must be handled before the generic list flattening below:
                    # that step keeps only 'text' items and would reduce the list
                    # of tool calls to an empty string, counting nothing.
                    for tool_call in value or []:
                        for t_key, t_value in tool_call.items():
                            # Each key is counted exactly once.
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(str(f_value)))
                            else:
                                num_tokens += len(encoding.encode(str(t_value)))
                else:
                    # TODO: The current token calculation method for the image type is not implemented,
                    #  which need to download the image and then get the resolution for calculation,
                    #  and will increase the request delay
                    if isinstance(value, list):
                        # Multi-part content: only the text parts are counted.
                        value = ''.join(
                            item['text'] for item in value
                            if isinstance(item, dict) and item['type'] == 'text'
                        )

                    # Cast str(value) in case the message value is not a string
                    # This occurs with function messages
                    num_tokens += len(encoding.encode(str(value)))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        """
        Map Yi credentials onto the keys read by the OpenAI runtime, in place.

        Sets ``mode`` to ``'chat'``, copies ``api_key`` to ``openai_api_key`` and
        derives ``openai_api_base`` from ``endpoint_url``. A missing, empty or
        ``None`` ``endpoint_url`` is replaced by the default Yi endpoint.

        :param credentials: credentials dict; must contain ``api_key``
        :raises KeyError: if ``api_key`` is absent
        """
        credentials['mode'] = 'chat'

        credentials['openai_api_key'] = credentials['api_key']
        if not credentials.get('endpoint_url'):
            credentials['endpoint_url'] = 'https://api.lingyiwanwu.com/v1'
            credentials['openai_api_base'] = 'https://api.lingyiwanwu.com'
        else:
            # Only scheme and host are kept for openai_api_base; any path on
            # endpoint_url (such as '/v1') is discarded.
            # NOTE(review): presumably the parent runtime appends the API path itself - confirm.
            parsed_url = urlparse(credentials['endpoint_url'])
            credentials['openai_api_base'] = f"{parsed_url.scheme}://{parsed_url.netloc}"

0 comments on commit 5a1c29f

Please sign in to comment.