-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: change Yi model SDK to OpenAI (#2910)
- Loading branch information
Showing
1 changed file
with
93 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,119 @@ | ||
from collections.abc import Generator | ||
from typing import Optional, Union | ||
from urllib.parse import urlparse | ||
|
||
import tiktoken | ||
|
||
from core.model_runtime.entities.llm_entities import LLMResult | ||
from core.model_runtime.entities.message_entities import ( | ||
PromptMessage, | ||
PromptMessageTool, | ||
SystemPromptMessage, | ||
) | ||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel | ||
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel | ||
|
||
|
||
class YiLargeLanguageModel(OpenAILargeLanguageModel):
    """Model runtime for 01.AI (Yi) models, served through the OpenAI SDK.

    Yi exposes an OpenAI-compatible API, so this class reuses the OpenAI
    runtime wholesale and only rewrites the credentials dict (API key and
    base URL) into the form the OpenAI SDK expects before each call.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke the model after mapping Yi credentials onto the OpenAI runtime.

        :param model: model name
        :param credentials: Yi credentials ('api_key', optional 'endpoint_url');
            mutated in place by _add_custom_parameters
        :param prompt_messages: conversation messages
        :param model_parameters: generation parameters
        :param tools: tools for tool calling
        :param stop: stop sequences
        :param stream: whether to stream the response
        :param user: unique end-user id (not forwarded; the OpenAI runtime default applies)
        :return: full LLMResult, or a Generator of chunks when streaming
        """
        self._add_custom_parameters(credentials)

        # yi-vl-plus does not support a system prompt yet, so strip any
        # SystemPromptMessage before delegating to the OpenAI runtime.
        if model == "yi-vl-plus":
            prompt_messages = [message for message in prompt_messages
                               if not isinstance(message, SystemPromptMessage)]

        return super()._invoke(model, credentials, prompt_messages,
                               model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate Yi credentials by rewriting them into OpenAI-SDK form and
        delegating the actual check to the OpenAI runtime.

        :param model: model name
        :param credentials: Yi credentials; mutated in place
        :raises CredentialsValidateFailedError: propagated from the OpenAI runtime
        """
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # refactored from openai model runtime; Yi models use the cl100k_base
    # encoding for token counting regardless of the model name
    def _num_tokens_from_string(self, model: str, text: str,
                                tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for text completion model with tiktoken package.

        :param model: model name (unused; cl100k_base is always applied)
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # refactored from openai model runtime; Yi models use the cl100k_base
    # encoding for token counting regardless of the model name
    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Calculate num tokens for chat messages with tiktoken package.

        Follows the ChatML accounting scheme (3 tokens of framing per message,
        1 extra per name field, 3 to prime the assistant reply).
        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb

        :param model: model name (unused; cl100k_base is always applied)
        :param messages: prompt messages to count
        :param tools: tools for tool calling
        :return: estimated number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                #  which need to download the image and then get the resolution for calculation,
                #  and will increase the request delay
                if isinstance(value, list):
                    # Multimodal content: only the text parts are countable here.
                    text = ''
                    for item in value:
                        # .get() skips malformed items instead of raising KeyError
                        if isinstance(item, dict) and item.get('type') == 'text':
                            text += item['text']

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(f_value))
                            else:
                                # BUGFIX: t_key was already counted above; the
                                # original re-encoded it here, double-counting
                                # every non-"function" key.
                                num_tokens += len(encoding.encode(str(t_value)))
                else:
                    num_tokens += len(encoding.encode(str(value)))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        """
        Rewrite Yi credentials in place into the keys the OpenAI SDK expects.

        Sets 'mode', copies 'api_key' to 'openai_api_key', and derives
        'openai_api_base' from 'endpoint_url' (scheme + host only — the
        OpenAI SDK appends the '/v1' path itself), defaulting to the
        official Lingyiwanwu endpoint when none is configured.

        :param credentials: credentials dict to mutate
        """
        credentials['mode'] = 'chat'

        credentials['openai_api_key'] = credentials['api_key']
        if not credentials.get('endpoint_url'):
            credentials['endpoint_url'] = 'https://api.lingyiwanwu.com/v1'
            credentials['openai_api_base'] = 'https://api.lingyiwanwu.com'
        else:
            parsed_url = urlparse(credentials['endpoint_url'])
            credentials['openai_api_base'] = f"{parsed_url.scheme}://{parsed_url.netloc}"