diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index d10314ba039e63..1f5f64019a1663 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -37,3 +37,4 @@
- siliconflow
- perfxcloud
- zhinao
+- fireworks
diff --git a/api/core/model_runtime/model_providers/fireworks/__init__.py b/api/core/model_runtime/model_providers/fireworks/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg
new file mode 100644
index 00000000000000..582605cc422cce
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg
@@ -0,0 +1,3 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg
new file mode 100644
index 00000000000000..86eeba66f9290a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg
@@ -0,0 +1,5 @@
+
diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py
new file mode 100644
index 00000000000000..378ced3a4019ba
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/_common.py
@@ -0,0 +1,52 @@
+from collections.abc import Mapping
+
+import openai
+
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+
+
+class _CommonFireworks:
+ def _to_credential_kwargs(self, credentials: Mapping) -> dict:
+ """
+ Transform credentials to kwargs for model instance
+
+ :param credentials:
+ :return:
+ """
+ credentials_kwargs = {
+ "api_key": credentials["fireworks_api_key"],
+ "base_url": "https://api.fireworks.ai/inference/v1",
+ "max_retries": 1,
+ }
+
+ return credentials_kwargs
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ The key is the error type thrown to the caller
+ The value is the error type thrown by the model,
+ which needs to be converted into a unified error type for the caller.
+
+ :return: Invoke error mapping
+ """
+ return {
+ InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+ InvokeServerUnavailableError: [openai.InternalServerError],
+ InvokeRateLimitError: [openai.RateLimitError],
+ InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
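+            # openai.APIError is the SDK's base error class, so any error not
+            # matched above surfaces as InvokeBadRequestError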
+ InvokeBadRequestError: [
+ openai.BadRequestError,
+ openai.NotFoundError,
+ openai.UnprocessableEntityError,
+ openai.APIError,
+ ],
+ }
diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.py b/api/core/model_runtime/model_providers/fireworks/fireworks.py
new file mode 100644
index 00000000000000..15f25badab994f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/fireworks.py
@@ -0,0 +1,27 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class FireworksProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
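+            # use a small, inexpensive chat model to keep credential validation cheap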
+ model_instance = self.get_model_instance(ModelType.LLM)
+ model_instance.validate_credentials(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct", credentials=credentials
+ )
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
+ raise ex
diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml
new file mode 100644
index 00000000000000..f886fa23b5bd82
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml
@@ -0,0 +1,29 @@
+provider: fireworks
+label:
+ zh_Hans: Fireworks AI
+ en_US: Fireworks AI
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#FCFDFF"
+help:
+ title:
+ en_US: Get your API Key from Fireworks AI
+ zh_Hans: 从 Fireworks AI 获取 API Key
+ url:
+ en_US: https://fireworks.ai/account/api-keys
+supported_model_types:
+ - llm
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: fireworks_api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/__init__.py b/api/core/model_runtime/model_providers/fireworks/llm/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml b/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml
new file mode 100644
index 00000000000000..9f7c1af68cef72
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml
@@ -0,0 +1,16 @@
+- llama-v3p1-405b-instruct
+- llama-v3p1-70b-instruct
+- llama-v3p1-8b-instruct
+- llama-v3-70b-instruct
+- mixtral-8x22b-instruct
+- mixtral-8x7b-instruct
+- firefunction-v2
+- firefunction-v1
+- gemma2-9b-it
+- llama-v3-70b-instruct-hf
+- llama-v3-8b-instruct
+- llama-v3-8b-instruct-hf
+- mixtral-8x7b-instruct-hf
+- mythomax-l2-13b
+- phi-3-vision-128k-instruct
+- yi-large
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml
new file mode 100644
index 00000000000000..f6bac12832d646
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/firefunction-v1
+label:
+ zh_Hans: Firefunction V1
+ en_US: Firefunction V1
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 32768
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
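+  # with unit '0.000001', a price of '0.5' reads as $0.50 per 1M tokens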
+ input: '0.5'
+ output: '0.5'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml
new file mode 100644
index 00000000000000..2979cb46d572a3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/firefunction-v2
+label:
+ zh_Hans: Firefunction V2
+ en_US: Firefunction V2
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.9'
+ output: '0.9'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml b/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml
new file mode 100644
index 00000000000000..ee41a7e2fdc3d5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml
@@ -0,0 +1,45 @@
+model: accounts/fireworks/models/gemma2-9b-it
+label:
+ zh_Hans: Gemma2 9B Instruct
+ en_US: Gemma2 9B Instruct
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml
new file mode 100644
index 00000000000000..2ae89b88165d12
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-70b-instruct-hf
+label:
+  zh_Hans: Llama3 70B Instruct (HF version)
+  en_US: Llama3 70B Instruct (HF version)
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.9'
+ output: '0.9'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml
new file mode 100644
index 00000000000000..7c24b08ca5cca1
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-70b-instruct
+label:
+ zh_Hans: Llama3 70B Instruct
+ en_US: Llama3 70B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.9'
+ output: '0.9'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml
new file mode 100644
index 00000000000000..83507ef3e5276e
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-8b-instruct-hf
+label:
+  zh_Hans: Llama3 8B Instruct (HF version)
+  en_US: Llama3 8B Instruct (HF version)
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml
new file mode 100644
index 00000000000000..d8ac9537b80e7f
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3-8b-instruct
+label:
+ zh_Hans: Llama3 8B Instruct
+ en_US: Llama3 8B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml
new file mode 100644
index 00000000000000..c4ddb3e9246d4a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3p1-405b-instruct
+label:
+ zh_Hans: Llama3.1 405B Instruct
+ en_US: Llama3.1 405B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 131072
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '3'
+ output: '3'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml
new file mode 100644
index 00000000000000..62f84f87fa5609
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3p1-70b-instruct
+label:
+ zh_Hans: Llama3.1 70B Instruct
+ en_US: Llama3.1 70B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 131072
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml
new file mode 100644
index 00000000000000..9bb99c91b65b0b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/llama-v3p1-8b-instruct
+label:
+ zh_Hans: Llama3.1 8B Instruct
+ en_US: Llama3.1 8B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 131072
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llm.py b/api/core/model_runtime/model_providers/fireworks/llm/llm.py
new file mode 100644
index 00000000000000..2dcf1adba64518
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/llm.py
@@ -0,0 +1,610 @@
+import logging
+from collections.abc import Generator
+from typing import Optional, Union, cast
+
+from openai import OpenAI, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
+from openai.types.chat.chat_completion_message import FunctionCall
+
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ ImagePromptMessageContent,
+ PromptMessage,
+ PromptMessageContentType,
+ PromptMessageTool,
+ SystemPromptMessage,
+ TextPromptMessageContent,
+ ToolPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
+
+logger = logging.getLogger(__name__)
+
+FIREWORKS_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+You can find the structure of the {{block}} object in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+""" # noqa: E501
+
+
+class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
+ """
+ Model class for Fireworks large language model.
+ """
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> Union[LLMResult, Generator]:
+ """
+ Invoke large language model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :return: full response or stream response chunk generator result
+ """
+
+ return self._chat_generate(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ )
+
+ def _code_block_mode_wrapper(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ callbacks: Optional[list[Callback]] = None,
+ ) -> Union[LLMResult, Generator]:
+ """
+ Code block mode wrapper for invoking large language model
+ """
+ if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
+ stop = stop or []
+ self._transform_chat_json_prompts(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ response_format=model_parameters["response_format"],
+ )
+ model_parameters.pop("response_format")
+
+ return self._invoke(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ )
+
+ def _transform_chat_json_prompts(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None,
+ stop: list[str] | None = None,
+ stream: bool = True,
+ user: str | None = None,
+ response_format: str = "JSON",
+ ) -> None:
+ """
+ Transform json prompts
+ """
+ if stop is None:
+ stop = []
+ if "```\n" not in stop:
+ stop.append("```\n")
+ if "\n```" not in stop:
+ stop.append("\n```")
+
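+        # rewrite (or insert) a system prompt built from FIREWORKS_BLOCK_MODE_PROMPT
+        # and prime the assistant with an opening code fence for the model to complete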
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ prompt_messages[0] = SystemPromptMessage(
+ content=FIREWORKS_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
+ "{{block}}", response_format
+ )
+ )
+ prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
+ else:
+ prompt_messages.insert(
+ 0,
+ SystemPromptMessage(
+ content=FIREWORKS_BLOCK_MODE_PROMPT.replace(
+ "{{instructions}}", f"Please output a valid {response_format} object."
+ ).replace("{{block}}", response_format)
+ ),
+ )
+ prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
+
+ def get_num_tokens(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None,
+ ) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return:
+ """
+ return self._num_tokens_from_messages(model, prompt_messages, tools)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ credentials_kwargs = self._to_credential_kwargs(credentials)
+ client = OpenAI(**credentials_kwargs)
+
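+            # a minimal one-shot "ping" completion verifies both the API key
+            # and access to the requested model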
+ client.chat.completions.create(
+ messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False
+ )
+ except Exception as e:
+ raise CredentialsValidateFailedError(str(e))
+
+ def _chat_generate(
+ self,
+ model: str,
+ credentials: dict,
+ prompt_messages: list[PromptMessage],
+ model_parameters: dict,
+ tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ ) -> Union[LLMResult, Generator]:
+ credentials_kwargs = self._to_credential_kwargs(credentials)
+ client = OpenAI(**credentials_kwargs)
+
+ extra_model_kwargs = {}
+
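+        # tool definitions are forwarded via the legacy OpenAI `functions`
+        # parameter, matching the function_call handling in the response parsers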
+ if tools:
+ extra_model_kwargs["functions"] = [
+ {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools
+ ]
+
+ if stop:
+ extra_model_kwargs["stop"] = stop
+
+ if user:
+ extra_model_kwargs["user"] = user
+
+ # chat model
+ response = client.chat.completions.create(
+ messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+ model=model,
+ stream=stream,
+ **model_parameters,
+ **extra_model_kwargs,
+ )
+
+ if stream:
+ return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+ return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+ def _handle_chat_generate_response(
+ self,
+ model: str,
+ credentials: dict,
+ response: ChatCompletion,
+ prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None,
+ ) -> LLMResult:
+ """
+ Handle llm chat response
+
+ :param model: model name
+ :param credentials: credentials
+ :param response: response
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return: llm response
+ """
+ assistant_message = response.choices[0].message
+ # assistant_message_tool_calls = assistant_message.tool_calls
+ assistant_message_function_call = assistant_message.function_call
+
+ # extract tool calls from response
+ # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+ function_call = self._extract_response_function_call(assistant_message_function_call)
+ tool_calls = [function_call] if function_call else []
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
+
+ # calculate num tokens
+ if response.usage:
+ # transform usage
+ prompt_tokens = response.usage.prompt_tokens
+ completion_tokens = response.usage.completion_tokens
+ else:
+ # calculate num tokens
+ prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+ completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+ # transform response
+ response = LLMResult(
+ model=response.model,
+ prompt_messages=prompt_messages,
+ message=assistant_prompt_message,
+ usage=usage,
+ system_fingerprint=response.system_fingerprint,
+ )
+
+ return response
+
+ def _handle_chat_generate_stream_response(
+ self,
+ model: str,
+ credentials: dict,
+ response: Stream[ChatCompletionChunk],
+ prompt_messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None,
+ ) -> Generator:
+ """
+ Handle llm chat stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: response
+ :param prompt_messages: prompt messages
+ :param tools: tools for tool calling
+ :return: llm response chunk generator
+ """
+ full_assistant_content = ""
+ delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
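+        # buffers a function call that arrives in pieces across stream chunks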
+ prompt_tokens = 0
+ completion_tokens = 0
+ final_tool_calls = []
+ final_chunk = LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=""),
+ ),
+ )
+
+ for chunk in response:
+ if len(chunk.choices) == 0:
+ if chunk.usage:
+ # calculate num tokens
+ prompt_tokens = chunk.usage.prompt_tokens
+ completion_tokens = chunk.usage.completion_tokens
+ continue
+
+ delta = chunk.choices[0]
+ has_finish_reason = delta.finish_reason is not None
+
+ if (
+ not has_finish_reason
+ and (delta.delta.content is None or delta.delta.content == "")
+ and delta.delta.function_call is None
+ ):
+ continue
+
+ # assistant_message_tool_calls = delta.delta.tool_calls
+ assistant_message_function_call = delta.delta.function_call
+
+ # extract tool calls from response
+ if delta_assistant_message_function_call_storage is not None:
+                # handle a function call that streams across multiple chunks
+                if assistant_message_function_call:
+                    # the function call has not finished; keep buffering its arguments
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # the function call has finished streaming; flush the buffered call
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+ else:
+ if assistant_message_function_call:
+ # start of stream function call
+ delta_assistant_message_function_call_storage = assistant_message_function_call
+ if delta_assistant_message_function_call_storage.arguments is None:
+ delta_assistant_message_function_call_storage.arguments = ""
+ if not has_finish_reason:
+ continue
+
+ # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+ function_call = self._extract_response_function_call(assistant_message_function_call)
+ tool_calls = [function_call] if function_call else []
+ if tool_calls:
+ final_tool_calls.extend(tool_calls)
+
+ # transform assistant message to prompt message
+ assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
+
+ full_assistant_content += delta.delta.content or ""
+
+ if has_finish_reason:
+ final_chunk = LLMResultChunk(
+ model=chunk.model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=delta.index,
+ message=assistant_prompt_message,
+ finish_reason=delta.finish_reason,
+ ),
+ )
+ else:
+ yield LLMResultChunk(
+ model=chunk.model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=chunk.system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=delta.index,
+ message=assistant_prompt_message,
+ ),
+ )
+
+ if not prompt_tokens:
+ prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+
+ if not completion_tokens:
+ full_assistant_prompt_message = AssistantPromptMessage(
+ content=full_assistant_content, tool_calls=final_tool_calls
+ )
+ completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
+
+ # transform usage
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+ final_chunk.delta.usage = usage
+
+ yield final_chunk
+
+ def _extract_response_tool_calls(
+ self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]
+ ) -> list[AssistantPromptMessage.ToolCall]:
+ """
+ Extract tool calls from response
+
+ :param response_tool_calls: response tool calls
+ :return: list of tool calls
+ """
+ tool_calls = []
+ if response_tool_calls:
+ for response_tool_call in response_tool_calls:
+ function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
+ )
+
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=response_tool_call.id, type=response_tool_call.type, function=function
+ )
+ tool_calls.append(tool_call)
+
+ return tool_calls
+
+    def _extract_response_function_call(
+        self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall | None
+    ) -> Optional[AssistantPromptMessage.ToolCall]:
+ """
+ Extract function call from response
+
+ :param response_function_call: response function call
+ :return: tool call
+ """
+ tool_call = None
+ if response_function_call:
+ function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+ name=response_function_call.name, arguments=response_function_call.arguments
+ )
+
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=response_function_call.name, type="function", function=function
+ )
+
+ return tool_call
+
+ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+ """
+ Convert PromptMessage to dict for Fireworks API
+ """
+ if isinstance(message, UserPromptMessage):
+ message = cast(UserPromptMessage, message)
+ if isinstance(message.content, str):
+ message_dict = {"role": "user", "content": message.content}
+ else:
+ sub_messages = []
+ for message_content in message.content:
+ if message_content.type == PromptMessageContentType.TEXT:
+ message_content = cast(TextPromptMessageContent, message_content)
+ sub_message_dict = {"type": "text", "text": message_content.data}
+ sub_messages.append(sub_message_dict)
+ elif message_content.type == PromptMessageContentType.IMAGE:
+ message_content = cast(ImagePromptMessageContent, message_content)
+ sub_message_dict = {
+ "type": "image_url",
+ "image_url": {"url": message_content.data, "detail": message_content.detail.value},
+ }
+ sub_messages.append(sub_message_dict)
+
+ message_dict = {"role": "user", "content": sub_messages}
+ elif isinstance(message, AssistantPromptMessage):
+ message = cast(AssistantPromptMessage, message)
+ message_dict = {"role": "assistant", "content": message.content}
+ if message.tool_calls:
+ # message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+ # message.tool_calls]
+ function_call = message.tool_calls[0]
+ message_dict["function_call"] = {
+ "name": function_call.function.name,
+ "arguments": function_call.function.arguments,
+ }
+ elif isinstance(message, SystemPromptMessage):
+ message = cast(SystemPromptMessage, message)
+ message_dict = {"role": "system", "content": message.content}
+ elif isinstance(message, ToolPromptMessage):
+ message = cast(ToolPromptMessage, message)
+ # message_dict = {
+ # "role": "tool",
+ # "content": message.content,
+ # "tool_call_id": message.tool_call_id
+ # }
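+            # the legacy function-call format is used for tool results: the role
+            # is "function" and the tool_call_id doubles as the function name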
+ message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
+ else:
+ raise ValueError(f"Got unknown type {message}")
+
+ if message.name:
+ message_dict["name"] = message.name
+
+ return message_dict
+
+ def _num_tokens_from_messages(
+ self,
+ model: str,
+ messages: list[PromptMessage],
+ tools: Optional[list[PromptMessageTool]] = None,
+        credentials: Optional[dict] = None,
+ ) -> int:
+ """
+ Approximate num tokens with GPT2 tokenizer.
+ """
+
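+        # the per-message/per-name constants follow OpenAI's chat token-counting
+        # heuristic; GPT-2 counts only approximate the Fireworks models' tokenizers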
+ tokens_per_message = 3
+ tokens_per_name = 1
+
+ num_tokens = 0
+ messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+ for message in messages_dict:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ # Cast str(value) in case the message value is not a string
+ # This occurs with function messages
+ # TODO: The current token calculation method for the image type is not implemented,
+ # which need to download the image and then get the resolution for calculation,
+ # and will increase the request delay
+ if isinstance(value, list):
+ text = ""
+ for item in value:
+ if isinstance(item, dict) and item["type"] == "text":
+ text += item["text"]
+
+ value = text
+
+ if key == "tool_calls":
+ for tool_call in value:
+ for t_key, t_value in tool_call.items():
+ num_tokens += self._get_num_tokens_by_gpt2(t_key)
+ if t_key == "function":
+ for f_key, f_value in t_value.items():
+ num_tokens += self._get_num_tokens_by_gpt2(f_key)
+ num_tokens += self._get_num_tokens_by_gpt2(f_value)
+ else:
+ num_tokens += self._get_num_tokens_by_gpt2(t_key)
+ num_tokens += self._get_num_tokens_by_gpt2(t_value)
+ else:
+ num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+ if key == "name":
+ num_tokens += tokens_per_name
+
+ # every reply is primed with assistant
+ num_tokens += 3
+
+ if tools:
+ num_tokens += self._num_tokens_for_tools(tools)
+
+ return num_tokens
+
+ def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
+ """
+ Calculate num tokens for tool calling with tiktoken package.
+
+ :param tools: tools for tool calling
+ :return: number of tokens
+ """
+ num_tokens = 0
+ for tool in tools:
+ num_tokens += self._get_num_tokens_by_gpt2("type")
+ num_tokens += self._get_num_tokens_by_gpt2("function")
+ num_tokens += self._get_num_tokens_by_gpt2("function")
+
+ # calculate num tokens for function object
+ num_tokens += self._get_num_tokens_by_gpt2("name")
+ num_tokens += self._get_num_tokens_by_gpt2(tool.name)
+ num_tokens += self._get_num_tokens_by_gpt2("description")
+ num_tokens += self._get_num_tokens_by_gpt2(tool.description)
+ parameters = tool.parameters
+ num_tokens += self._get_num_tokens_by_gpt2("parameters")
+ if "title" in parameters:
+ num_tokens += self._get_num_tokens_by_gpt2("title")
+ num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
+ num_tokens += self._get_num_tokens_by_gpt2("type")
+ num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
+ if "properties" in parameters:
+ num_tokens += self._get_num_tokens_by_gpt2("properties")
+ for key, value in parameters.get("properties").items():
+ num_tokens += self._get_num_tokens_by_gpt2(key)
+ for field_key, field_value in value.items():
+ num_tokens += self._get_num_tokens_by_gpt2(field_key)
+ if field_key == "enum":
+ for enum_field in field_value:
+ num_tokens += 3
+ num_tokens += self._get_num_tokens_by_gpt2(enum_field)
+ else:
+ num_tokens += self._get_num_tokens_by_gpt2(field_key)
+ num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
+ if "required" in parameters:
+ num_tokens += self._get_num_tokens_by_gpt2("required")
+ for required_field in parameters["required"]:
+ num_tokens += 3
+ num_tokens += self._get_num_tokens_by_gpt2(required_field)
+
+ return num_tokens
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml
new file mode 100644
index 00000000000000..87d977e26cf1b2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mixtral-8x22b-instruct
+label:
+ zh_Hans: Mixtral MoE 8x22B Instruct
+ en_US: Mixtral MoE 8x22B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 65536
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '1.2'
+ output: '1.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml
new file mode 100644
index 00000000000000..e3d5a90858c5ae
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mixtral-8x7b-instruct-hf
+label:
+  zh_Hans: Mixtral MoE 8x7B Instruct (HF version)
+  en_US: Mixtral MoE 8x7B Instruct (HF version)
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 32768
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.5'
+ output: '0.5'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml
new file mode 100644
index 00000000000000..45f632ceff2cfc
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mixtral-8x7b-instruct
+label:
+ zh_Hans: Mixtral MoE 8x7B Instruct
+ en_US: Mixtral MoE 8x7B Instruct
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 32768
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.5'
+ output: '0.5'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml
new file mode 100644
index 00000000000000..9c3486ba10751b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/mythomax-l2-13b
+label:
+ zh_Hans: MythoMax L2 13b
+ en_US: MythoMax L2 13b
+model_type: llm
+features:
+ - agent-thought
+ - tool-call
+model_properties:
+ mode: chat
+ context_size: 4096
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml
new file mode 100644
index 00000000000000..e399f2edb1b1bd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml
@@ -0,0 +1,46 @@
+model: accounts/fireworks/models/phi-3-vision-128k-instruct
+label:
+  zh_Hans: Phi 3 Vision 128K Instruct
+  en_US: Phi 3 Vision 128K Instruct
+model_type: llm
+features:
+ - agent-thought
+ - vision
+model_properties:
+ mode: chat
+ context_size: 8192
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '0.2'
+ output: '0.2'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml b/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml
new file mode 100644
index 00000000000000..bb4b6f994ec12a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml
@@ -0,0 +1,45 @@
+model: accounts/yi-01-ai/models/yi-large
+label:
+ zh_Hans: Yi-Large
+ en_US: Yi-Large
+model_type: llm
+features:
+ - agent-thought
+model_properties:
+ mode: chat
+ context_size: 32768
+parameter_rules:
+ - name: temperature
+ use_template: temperature
+ - name: top_p
+ use_template: top_p
+ - name: top_k
+ label:
+ zh_Hans: 取样数量
+ en_US: Top k
+ type: int
+ help:
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+ en_US: Only sample from the top K options for each subsequent token.
+ - name: max_tokens
+ use_template: max_tokens
+ - name: context_length_exceeded_behavior
+ default: None
+ label:
+ zh_Hans: 上下文长度超出行为
+ en_US: Context Length Exceeded Behavior
+ help:
+      zh_Hans: 当上下文长度超出限制时的处理方式。
+      en_US: What to do when the context length is exceeded.
+ type: string
+ options:
+ - None
+ - truncate
+ - error
+ - name: response_format
+ use_template: response_format
+pricing:
+ input: '3'
+ output: '3'
+ unit: '0.000001'
+ currency: USD
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 1f483fc49f0326..93482b032d5369 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -100,6 +100,7 @@ exclude = [
[tool.pytest_env]
OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"
UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa"
+FIREWORKS_API_KEY = "fw_aaaaaaaaaaaaaaaaaaaa"
AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com"
AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94"
ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz"
diff --git a/api/tests/integration_tests/model_runtime/fireworks/__init__.py b/api/tests/integration_tests/model_runtime/fireworks/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_llm.py b/api/tests/integration_tests/model_runtime/fireworks/test_llm.py
new file mode 100644
index 00000000000000..699ca293a2fca8
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/fireworks/test_llm.py
@@ -0,0 +1,186 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessageTool,
+ SystemPromptMessage,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel
+
+"""FOR MOCK FIXTURES, DO NOT REMOVE"""
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+def test_predefined_models():
+ model = FireworksLargeLanguageModel()
+ model_schemas = model.predefined_models()
+
+ assert len(model_schemas) >= 1
+ assert isinstance(model_schemas[0], AIModelEntity)
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_validate_credentials_for_chat_model(setup_openai_mock):
+ model = FireworksLargeLanguageModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+        # the model name is gpt-3.5-turbo because the mock only patches OpenAI chat models
+ model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"})
+
+ model.validate_credentials(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+ credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+ )
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_chat_model(setup_openai_mock):
+ model = FireworksLargeLanguageModel()
+
+ result = model.invoke(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+ credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+ prompt_messages=[
+ SystemPromptMessage(
+ content="You are a helpful AI assistant.",
+ ),
+ UserPromptMessage(content="Hello World!"),
+ ],
+ model_parameters={
+ "temperature": 0.0,
+ "top_p": 1.0,
+ "presence_penalty": 0.0,
+ "frequency_penalty": 0.0,
+ "max_tokens": 10,
+ },
+ stop=["How"],
+ stream=False,
+ user="foo",
+ )
+
+ assert isinstance(result, LLMResult)
+ assert len(result.message.content) > 0
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_chat_model_with_tools(setup_openai_mock):
+ model = FireworksLargeLanguageModel()
+
+ result = model.invoke(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+ credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+ prompt_messages=[
+ SystemPromptMessage(
+ content="You are a helpful AI assistant.",
+ ),
+ UserPromptMessage(
+ content="what's the weather today in London?",
+ ),
+ ],
+ model_parameters={"temperature": 0.0, "max_tokens": 100},
+ tools=[
+ PromptMessageTool(
+ name="get_weather",
+ description="Determine weather in my location",
+ parameters={
+ "type": "object",
+ "properties": {
+                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
+ "unit": {"type": "string", "enum": ["c", "f"]},
+ },
+ "required": ["location"],
+ },
+ ),
+ PromptMessageTool(
+ name="get_stock_price",
+ description="Get the current stock price",
+ parameters={
+ "type": "object",
+ "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
+ "required": ["symbol"],
+ },
+ ),
+ ],
+ stream=False,
+ user="foo",
+ )
+
+ assert isinstance(result, LLMResult)
+ assert isinstance(result.message, AssistantPromptMessage)
+ assert len(result.message.tool_calls) > 0
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_invoke_stream_chat_model(setup_openai_mock):
+ model = FireworksLargeLanguageModel()
+
+ result = model.invoke(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+ credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+ prompt_messages=[
+ SystemPromptMessage(
+ content="You are a helpful AI assistant.",
+ ),
+ UserPromptMessage(content="Hello World!"),
+ ],
+ model_parameters={"temperature": 0.0, "max_tokens": 100},
+ stream=True,
+ user="foo",
+ )
+
+ assert isinstance(result, Generator)
+
+ for chunk in result:
+ assert isinstance(chunk, LLMResultChunk)
+ assert isinstance(chunk.delta, LLMResultChunkDelta)
+ assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        if chunk.delta.finish_reason is None:
+            assert len(chunk.delta.message.content) > 0
+ if chunk.delta.finish_reason is not None:
+ assert chunk.delta.usage is not None
+ assert chunk.delta.usage.completion_tokens > 0
+
+
+def test_get_num_tokens():
+ model = FireworksLargeLanguageModel()
+
+ num_tokens = model.get_num_tokens(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+ credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+ prompt_messages=[UserPromptMessage(content="Hello World!")],
+ )
+
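+    # the expected counts come from the GPT-2-based approximation in
+    # _num_tokens_from_messages, not from a Fireworks tokenizer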
+ assert num_tokens == 10
+
+ num_tokens = model.get_num_tokens(
+ model="accounts/fireworks/models/llama-v3p1-8b-instruct",
+ credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
+ prompt_messages=[
+ SystemPromptMessage(
+ content="You are a helpful AI assistant.",
+ ),
+ UserPromptMessage(content="Hello World!"),
+ ],
+ tools=[
+ PromptMessageTool(
+ name="get_weather",
+ description="Determine weather in my location",
+ parameters={
+ "type": "object",
+ "properties": {
+                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
+ "unit": {"type": "string", "enum": ["c", "f"]},
+ },
+ "required": ["location"],
+ },
+ ),
+ ],
+ )
+
+ assert num_tokens == 77
diff --git a/api/tests/integration_tests/model_runtime/fireworks/test_provider.py b/api/tests/integration_tests/model_runtime/fireworks/test_provider.py
new file mode 100644
index 00000000000000..a68cf1a1a8fbda
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/fireworks/test_provider.py
@@ -0,0 +1,17 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fireworks.fireworks import FireworksProvider
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_validate_provider_credentials(setup_openai_mock):
+ provider = FireworksProvider()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ provider.validate_provider_credentials(credentials={})
+
+ provider.validate_provider_credentials(credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")})
diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh
index aba13292ab8315..4c1c6bf4f3ab19 100755
--- a/dev/pytest/pytest_model_runtime.sh
+++ b/dev/pytest/pytest_model_runtime.sh
@@ -6,5 +6,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \
api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \
api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
- api/tests/integration_tests/model_runtime/upstage
-
+ api/tests/integration_tests/model_runtime/upstage \
+ api/tests/integration_tests/model_runtime/fireworks