diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index d10314ba039e63..1f5f64019a1663 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -37,3 +37,4 @@ - siliconflow - perfxcloud - zhinao +- fireworks diff --git a/api/core/model_runtime/model_providers/fireworks/__init__.py b/api/core/model_runtime/model_providers/fireworks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..582605cc422cce --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..86eeba66f9290a --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py new file mode 100644 index 00000000000000..378ced3a4019ba --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/_common.py @@ -0,0 +1,52 @@ +from collections.abc import Mapping + +import openai + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonFireworks: + def _to_credential_kwargs(self, credentials: Mapping) -> dict: + """ + Transform credentials to kwargs for model instance + + :param credentials: + :return: + """ + credentials_kwargs = { + "api_key": credentials["fireworks_api_key"], + "base_url": "https://api.fireworks.ai/inference/v1", + "max_retries": 1, + } + + return credentials_kwargs + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], + InvokeBadRequestError: [ + openai.BadRequestError, + openai.NotFoundError, + openai.UnprocessableEntityError, + openai.APIError, + ], + } diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.py b/api/core/model_runtime/model_providers/fireworks/fireworks.py new file mode 100644 index 00000000000000..15f25badab994f --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/fireworks.py @@ -0,0 +1,27 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class FireworksProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml new file mode 100644 index 00000000000000..f886fa23b5bd82 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml @@ -0,0 +1,29 @@ +provider: fireworks +label: + zh_Hans: Fireworks AI + en_US: Fireworks AI +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FCFDFF" +help: + title: + en_US: Get your API Key from Fireworks AI + zh_Hans: 从 Fireworks AI 获取 API Key + url: + en_US: https://fireworks.ai/account/api-keys +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: fireworks_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/fireworks/llm/__init__.py b/api/core/model_runtime/model_providers/fireworks/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml b/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml new file mode 100644 index 00000000000000..9f7c1af68cef72 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml @@ -0,0 +1,16 @@ +- llama-v3p1-405b-instruct +- llama-v3p1-70b-instruct +- llama-v3p1-8b-instruct +- llama-v3-70b-instruct +- mixtral-8x22b-instruct +- mixtral-8x7b-instruct +- firefunction-v2 +- firefunction-v1 +- gemma2-9b-it +- llama-v3-70b-instruct-hf +- llama-v3-8b-instruct +- llama-v3-8b-instruct-hf +- mixtral-8x7b-instruct-hf +- mythomax-l2-13b +- phi-3-vision-128k-instruct +- yi-large diff --git a/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml new file mode 100644 index 00000000000000..f6bac12832d646 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/firefunction-v1 +label: + zh_Hans: Firefunction V1 + en_US: Firefunction V1 +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.5' + output: '0.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml new file mode 100644 index 00000000000000..2979cb46d572a3 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/firefunction-v2 +label: + zh_Hans: Firefunction V2 + en_US: Firefunction V2 +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml b/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml new file mode 100644 index 00000000000000..ee41a7e2fdc3d5 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml @@ -0,0 +1,45 @@ +model: accounts/fireworks/models/gemma2-9b-it +label: + zh_Hans: Gemma2 9B Instruct + en_US: Gemma2 9B Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml new file mode 100644 index 00000000000000..2ae89b88165d12 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-70b-instruct-hf +label: + zh_Hans: Llama3 70B Instruct(HF version) + en_US: Llama3 70B Instruct(HF version) +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml new file mode 100644 index 00000000000000..7c24b08ca5cca1 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-70b-instruct +label: + zh_Hans: Llama3 70B Instruct + en_US: Llama3 70B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml new file mode 100644 index 00000000000000..83507ef3e5276e --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-8b-instruct-hf +label: + zh_Hans: Llama3 8B Instruct(HF version) + en_US: Llama3 8B Instruct(HF version) +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml new file mode 100644 index 00000000000000..d8ac9537b80e7f --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-8b-instruct +label: + zh_Hans: Llama3 8B Instruct + en_US: Llama3 8B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml new file mode 100644 index 00000000000000..c4ddb3e9246d4a --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p1-405b-instruct +label: + zh_Hans: Llama3.1 405B Instruct + en_US: Llama3.1 405B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '3' + output: '3' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml new file mode 100644 index 00000000000000..62f84f87fa5609 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p1-70b-instruct +label: + zh_Hans: Llama3.1 70B Instruct + en_US: Llama3.1 70B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml new file mode 100644 index 00000000000000..9bb99c91b65b0b --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p1-8b-instruct +label: + zh_Hans: Llama3.1 8B Instruct + en_US: Llama3.1 8B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llm.py b/api/core/model_runtime/model_providers/fireworks/llm/llm.py new file mode 100644 index 00000000000000..2dcf1adba64518 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llm.py @@ -0,0 +1,610 @@ +import logging +from collections.abc import Generator +from typing import Optional, Union, cast + +from openai import OpenAI, Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall + +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.fireworks._common import _CommonFireworks + +logger = logging.getLogger(__name__) + +FIREWORKS_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + + +{{instructions}} + +""" # noqa: E501 + + +class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel): + """ + Model class for Fireworks large language model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + + return self._chat_generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: + stop = stop or [] + self._transform_chat_json_prompts( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters["response_format"], + ) + model_parameters.pop("response_format") + + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: + """ + Transform json prompts + """ + if stop is None: + stop = [] + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = SystemPromptMessage( + content=FIREWORKS_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) + ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) + else: + prompt_messages.insert( + 0, + SystemPromptMessage( + content=FIREWORKS_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + return self._num_tokens_from_messages(model, prompt_messages, tools) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + + client.chat.completions.create( + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False + ) + except Exception as e: + raise CredentialsValidateFailedError(str(e)) + + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + + extra_model_kwargs = {} + + if tools: + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] + + if stop: + extra_model_kwargs["stop"] = stop + + if user: + extra_model_kwargs["user"] = user + + # chat model + response = client.chat.completions.create( + messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], + model=model, + stream=stream, + **model_parameters, + **extra_model_kwargs, + ) + + if stream: + return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) + return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: + """ + Handle llm chat response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: llm response + """ + assistant_message = response.choices[0].message + # assistant_message_tool_calls = assistant_message.tool_calls + assistant_message_function_call = assistant_message.function_call + + # extract tool calls from response + # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + function_call = self._extract_response_function_call(assistant_message_function_call) + tool_calls = [function_call] if function_call else [] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) + + # calculate num tokens + if response.usage: + # transform usage + prompt_tokens = response.usage.prompt_tokens + completion_tokens = response.usage.completion_tokens + else: + # calculate num tokens + prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) + completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + response = LLMResult( + model=response.model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + system_fingerprint=response.system_fingerprint, + ) + + return response + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + """ + Handle llm chat stream response + + :param model: model name + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: llm response chunk generator + """ + full_assistant_content = "" + delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None + prompt_tokens = 0 + completion_tokens = 0 + final_tool_calls = [] + final_chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + ), + ) + + for chunk in response: + if len(chunk.choices) == 0: + if chunk.usage: + # calculate num tokens + prompt_tokens = chunk.usage.prompt_tokens + completion_tokens = chunk.usage.completion_tokens + continue + + delta = chunk.choices[0] + has_finish_reason = delta.finish_reason is not None + + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): + continue + + # assistant_message_tool_calls = delta.delta.tool_calls + assistant_message_function_call = delta.delta.function_call + + # extract tool calls from response + if delta_assistant_message_function_call_storage is not None: + # handle process of stream function call + if assistant_message_function_call: + # message has not ended ever + delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments + continue + else: + # message has ended + assistant_message_function_call = delta_assistant_message_function_call_storage + delta_assistant_message_function_call_storage = None + else: + if assistant_message_function_call: + # start of stream function call + delta_assistant_message_function_call_storage = assistant_message_function_call + if delta_assistant_message_function_call_storage.arguments is None: + delta_assistant_message_function_call_storage.arguments = "" + if not has_finish_reason: + continue + + # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + function_call = self._extract_response_function_call(assistant_message_function_call) + tool_calls = [function_call] if function_call else [] + if tool_calls: + final_tool_calls.extend(tool_calls) + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) + + full_assistant_content += delta.delta.content or "" + + if has_finish_reason: + final_chunk = LLMResultChunk( + model=chunk.model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=delta.index, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + ), + ) + else: + yield LLMResultChunk( + model=chunk.model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=delta.index, + message=assistant_prompt_message, + ), + ) + + if not prompt_tokens: + prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) + + if not completion_tokens: + full_assistant_prompt_message = AssistantPromptMessage( + content=full_assistant_content, tool_calls=final_tool_calls + ) + completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + final_chunk.delta.usage = usage + + yield final_chunk + + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.id, type=response_tool_call.type, function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call.name, arguments=response_function_call.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call.name, type="function", function=function + ) + + return tool_call + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for Fireworks API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + sub_message_dict = {"type": "text", "text": message_content.data} + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, + } + sub_messages.append(sub_message_dict) + + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + # message_dict["tool_calls"] = [tool_call.dict() for tool_call in + # message.tool_calls] + function_call = message.tool_calls[0] + message_dict["function_call"] = { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + # message_dict = { + # "role": "tool", + # "content": message.content, + # "tool_call_id": message.tool_call_id + # } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} + else: + raise ValueError(f"Got unknown type {message}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: dict = None, + ) -> int: + """ + Approximate num tokens with GPT2 tokenizer. + """ + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: The current token calculation method for the image type is not implemented, + # which need to download the image and then get the resolution for calculation, + # and will increase the request delay + if isinstance(value, list): + text = "" + for item in value: + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += self._get_num_tokens_by_gpt2(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += self._get_num_tokens_by_gpt2(f_key) + num_tokens += self._get_num_tokens_by_gpt2(f_value) + else: + num_tokens += self._get_num_tokens_by_gpt2(t_key) + num_tokens += self._get_num_tokens_by_gpt2(t_value) + else: + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + if key == "name": + num_tokens += tokens_per_name + + # every reply is primed with assistant + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: + """ + Calculate num tokens for tool calling with tiktoken package. + + :param tools: tools for tool calling + :return: number of tokens + """ + num_tokens = 0 + for tool in tools: + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") + + # calculate num tokens for function object + num_tokens += self._get_num_tokens_by_gpt2("name") + num_tokens += self._get_num_tokens_by_gpt2(tool.name) + num_tokens += self._get_num_tokens_by_gpt2("description") + num_tokens += self._get_num_tokens_by_gpt2(tool.description) + parameters = tool.parameters + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") + num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): + num_tokens += self._get_num_tokens_by_gpt2(key) + for field_key, field_value in value.items(): + num_tokens += self._get_num_tokens_by_gpt2(field_key) + if field_key == "enum": + for enum_field in field_value: + num_tokens += 3 + num_tokens += self._get_num_tokens_by_gpt2(enum_field) + else: + num_tokens += self._get_num_tokens_by_gpt2(field_key) + num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: + num_tokens += 3 + num_tokens += self._get_num_tokens_by_gpt2(required_field) + + return num_tokens diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml new file mode 100644 index 00000000000000..87d977e26cf1b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mixtral-8x22b-instruct +label: + zh_Hans: Mixtral MoE 8x22B Instruct + en_US: Mixtral MoE 8x22B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 65536 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '1.2' + output: '1.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml new file mode 100644 index 00000000000000..e3d5a90858c5ae --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mixtral-8x7b-instruct-hf +label: + zh_Hans: Mixtral MoE 8x7B Instruct(HF version) + en_US: Mixtral MoE 8x7B Instruct(HF version) +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.5' + output: '0.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml new file mode 100644 index 00000000000000..45f632ceff2cfc --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mixtral-8x7b-instruct +label: + zh_Hans: Mixtral MoE 8x7B Instruct + en_US: Mixtral MoE 8x7B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.5' + output: '0.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml new file mode 100644 index 00000000000000..9c3486ba10751b --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mythomax-l2-13b +label: + zh_Hans: MythoMax L2 13b + en_US: MythoMax L2 13b +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml new file mode 100644 index 00000000000000..e399f2edb1b1bd --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/phi-3-vision-128k-instruct +label: + zh_Hans: Phi3.5 Vision Instruct + en_US: Phi3.5 Vision Instruct +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml b/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml new file mode 100644 index 00000000000000..bb4b6f994ec12a --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml @@ -0,0 +1,45 @@ +model: accounts/yi-01-ai/models/yi-large +label: + zh_Hans: Yi-Large + en_US: Yi-Large +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '3' + output: '3' + unit: '0.000001' + currency: USD diff --git a/api/pyproject.toml b/api/pyproject.toml index 8c10f1dad97952..9b49d632f28cc2 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -100,6 +100,7 @@ exclude = [ [tool.pytest_env] OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa" +FIREWORKS_API_KEY = "fw_aaaaaaaaaaaaaaaaaaaa" AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com" AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94" ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz" diff --git a/api/tests/integration_tests/model_runtime/fireworks/__init__.py b/api/tests/integration_tests/model_runtime/fireworks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_llm.py b/api/tests/integration_tests/model_runtime/fireworks/test_llm.py new file mode 100644 index 00000000000000..699ca293a2fca8 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fireworks/test_llm.py @@ -0,0 +1,186 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = FireworksLargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = FireworksLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"}) + + model.validate_credentials( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = FireworksLargeLanguageModel() + + result = model.invoke( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = FireworksLargeLanguageModel() + + result = model.invoke( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = FireworksLargeLanguageModel() + + result = model.invoke( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = FireworksLargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", + credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_provider.py b/api/tests/integration_tests/model_runtime/fireworks/test_provider.py new file mode 100644 index 00000000000000..a68cf1a1a8fbda --- /dev/null +++ b/api/tests/integration_tests/model_runtime/fireworks/test_provider.py @@ -0,0 +1,17 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_provider_credentials(setup_openai_mock): + provider = FireworksProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + + provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}) diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh index aba13292ab8315..4c1c6bf4f3ab19 100755 --- a/dev/pytest/pytest_model_runtime.sh +++ b/dev/pytest/pytest_model_runtime.sh @@ -6,5 +6,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \ api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \ - api/tests/integration_tests/model_runtime/upstage - + api/tests/integration_tests/model_runtime/upstage \ + api/tests/integration_tests/model_runtime/fireworks