From 5ca18f2cd1ab44f05d14c9902f75312fc121fc4d Mon Sep 17 00:00:00 2001
From: zhuhao <37029601+hwzhuhao@users.noreply.github.com>
Date: Sun, 29 Sep 2024 16:55:59 +0800
Subject: [PATCH] feat: add voyage ai as a new model provider (#8747)
---
.../model_providers/_position.yaml | 1 +
.../model_providers/voyage/__init__.py | 0
.../voyage/_assets/icon_l_en.svg | 21 +++
.../voyage/_assets/icon_s_en.svg | 8 +
.../model_providers/voyage/rerank/__init__.py | 0
.../voyage/rerank/rerank-1.yaml | 4 +
.../voyage/rerank/rerank-lite-1.yaml | 4 +
.../model_providers/voyage/rerank/rerank.py | 123 +++++++++++++
.../voyage/text_embedding/__init__.py | 0
.../voyage/text_embedding/text_embedding.py | 172 ++++++++++++++++++
.../voyage/text_embedding/voyage-3-lite.yaml | 8 +
.../voyage/text_embedding/voyage-3.yaml | 8 +
.../model_providers/voyage/voyage.py | 28 +++
.../model_providers/voyage/voyage.yaml | 31 ++++
api/pyproject.toml | 1 +
.../model_runtime/voyage/__init__.py | 0
.../model_runtime/voyage/test_provider.py | 25 +++
.../model_runtime/voyage/test_rerank.py | 92 ++++++++++
.../voyage/test_text_embedding.py | 70 +++++++
dev/pytest/pytest_model_runtime.sh | 3 +-
20 files changed, 598 insertions(+), 1 deletion(-)
create mode 100644 api/core/model_runtime/model_providers/voyage/__init__.py
create mode 100644 api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg
create mode 100644 api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg
create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/__init__.py
create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml
create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml
create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/rerank.py
create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py
create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml
create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml
create mode 100644 api/core/model_runtime/model_providers/voyage/voyage.py
create mode 100644 api/core/model_runtime/model_providers/voyage/voyage.yaml
create mode 100644 api/tests/integration_tests/model_runtime/voyage/__init__.py
create mode 100644 api/tests/integration_tests/model_runtime/voyage/test_provider.py
create mode 100644 api/tests/integration_tests/model_runtime/voyage/test_rerank.py
create mode 100644 api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py
diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index 80db22ea84fe63..89fccef6598fdd 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -40,3 +40,4 @@
- fireworks
- mixedbread
- nomic
+- voyage
diff --git a/api/core/model_runtime/model_providers/voyage/__init__.py b/api/core/model_runtime/model_providers/voyage/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg
new file mode 100644
index 00000000000000..a961f5e4355eea
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg
@@ -0,0 +1,21 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg
new file mode 100644
index 00000000000000..2c4e121dd71f0b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg
@@ -0,0 +1,8 @@
+
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/__init__.py b/api/core/model_runtime/model_providers/voyage/rerank/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml
new file mode 100644
index 00000000000000..9c894eda85203b
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml
@@ -0,0 +1,4 @@
+model: rerank-1
+model_type: rerank
+model_properties:
+ context_size: 8000
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml
new file mode 100644
index 00000000000000..b052d6f00028cb
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml
@@ -0,0 +1,4 @@
+model: rerank-lite-1
+model_type: rerank
+model_properties:
+ context_size: 4000
diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank.py b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py
new file mode 100644
index 00000000000000..33fdebbb45ef36
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py
@@ -0,0 +1,123 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class VoyageRerankModel(RerankModel):
+ """
+ Model class for Voyage rerank model.
+ """
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ query: str,
+ docs: list[str],
+ score_threshold: Optional[float] = None,
+ top_n: Optional[int] = None,
+ user: Optional[str] = None,
+ ) -> RerankResult:
+ """
+ Invoke rerank model
+ :param model: model name
+ :param credentials: model credentials
+ :param query: search query
+ :param docs: docs for reranking
+ :param score_threshold: score threshold
+ :param top_n: top n documents to return
+ :param user: unique user id
+ :return: rerank result
+ """
+ if len(docs) == 0:
+ return RerankResult(model=model, docs=[])
+
+ base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
+ base_url = base_url.removesuffix("/")
+
+ try:
+ response = httpx.post(
+ base_url + "/rerank",
+ json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
+ headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ results = response.json()
+
+ rerank_documents = []
+ for result in results["data"]:
+ rerank_document = RerankDocument(
+ index=result["index"],
+ text=result["document"],
+ score=result["relevance_score"],
+ )
+ if score_threshold is None or result["relevance_score"] >= score_threshold:
+ rerank_documents.append(rerank_document)
+
+ return RerankResult(model=model, docs=rerank_documents)
+ except httpx.HTTPStatusError as e:
+ raise InvokeServerUnavailableError(str(e))
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(
+ model=model,
+ credentials=credentials,
+ query="What is the capital of the United States?",
+ docs=[
+ "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+ "Census, Carson City had a population of 55,274.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+ "are a political division controlled by the United States. Its capital is Saipan.",
+ ],
+ score_threshold=0.8,
+ )
+ except Exception as ex:
+ raise CredentialsValidateFailedError(str(ex))
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ """
+ Map model invoke error to unified error
+ """
+ return {
+ InvokeConnectionError: [httpx.ConnectError],
+ InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+ InvokeRateLimitError: [],
+ InvokeAuthorizationError: [httpx.HTTPStatusError],
+ InvokeBadRequestError: [httpx.RequestError],
+ }
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.RERANK,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py b/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
new file mode 100644
index 00000000000000..a8a4d3c15bbb13
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
@@ -0,0 +1,172 @@
+import time
+from json import JSONDecodeError, dumps
+from typing import Optional
+
+import requests
+
+from core.embedding.embedding_constant import EmbeddingInputType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+ InvokeAuthorizationError,
+ InvokeBadRequestError,
+ InvokeConnectionError,
+ InvokeError,
+ InvokeRateLimitError,
+ InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class VoyageTextEmbeddingModel(TextEmbeddingModel):
+ """
+ Model class for Voyage text embedding model.
+ """
+
+ api_base: str = "https://api.voyageai.com/v1"
+
+ def _invoke(
+ self,
+ model: str,
+ credentials: dict,
+ texts: list[str],
+ user: Optional[str] = None,
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+ ) -> TextEmbeddingResult:
+ """
+ Invoke text embedding model
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :param user: unique user id
+ :param input_type: input type
+ :return: embeddings result
+ """
+ api_key = credentials["api_key"]
+ if not api_key:
+ raise CredentialsValidateFailedError("api_key is required")
+
+ base_url = credentials.get("base_url", self.api_base)
+ base_url = base_url.removesuffix("/")
+
+ url = base_url + "/embeddings"
+ headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
+ voyage_input_type = "null"
+ if input_type is not None:
+ voyage_input_type = input_type.value
+ data = {"model": model, "input": texts, "input_type": voyage_input_type}
+
+ try:
+ response = requests.post(url, headers=headers, data=dumps(data))
+ except Exception as e:
+ raise InvokeConnectionError(str(e))
+
+ if response.status_code != 200:
+ try:
+ resp = response.json()
+ msg = resp["detail"]
+ if response.status_code == 401:
+ raise InvokeAuthorizationError(msg)
+ elif response.status_code == 429:
+ raise InvokeRateLimitError(msg)
+ elif response.status_code == 500:
+ raise InvokeServerUnavailableError(msg)
+ else:
+ raise InvokeBadRequestError(msg)
+ except JSONDecodeError as e:
+ raise InvokeServerUnavailableError(
+ f"Failed to convert response to json: {e} with text: {response.text}"
+ )
+
+ try:
+ resp = response.json()
+ embeddings = resp["data"]
+ usage = resp["usage"]
+ except Exception as e:
+ raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
+
+ usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
+
+ result = TextEmbeddingResult(
+ model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
+ )
+
+ return result
+
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ """
+ Get number of tokens for given prompt messages
+
+ :param model: model name
+ :param credentials: model credentials
+ :param texts: texts to embed
+ :return:
+ """
+ return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+
+ def validate_credentials(self, model: str, credentials: dict) -> None:
+ """
+ Validate model credentials
+
+ :param model: model name
+ :param credentials: model credentials
+ :return:
+ """
+ try:
+ self._invoke(model=model, credentials=credentials, texts=["ping"])
+ except Exception as e:
+ raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
+
+ @property
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ return {
+ InvokeConnectionError: [InvokeConnectionError],
+ InvokeServerUnavailableError: [InvokeServerUnavailableError],
+ InvokeRateLimitError: [InvokeRateLimitError],
+ InvokeAuthorizationError: [InvokeAuthorizationError],
+ InvokeBadRequestError: [KeyError, InvokeBadRequestError],
+ }
+
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+ """
+ Calculate response usage
+
+ :param model: model name
+ :param credentials: model credentials
+ :param tokens: input tokens
+ :return: usage
+ """
+ # get input price info
+ input_price_info = self.get_price(
+ model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
+ )
+
+ # transform usage
+ usage = EmbeddingUsage(
+ tokens=tokens,
+ total_tokens=tokens,
+ unit_price=input_price_info.unit_price,
+ price_unit=input_price_info.unit,
+ total_price=input_price_info.total_amount,
+ currency=input_price_info.currency,
+ latency=time.perf_counter() - self.started_at,
+ )
+
+ return usage
+
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+ """
+ generate custom model entities from credentials
+ """
+ entity = AIModelEntity(
+ model=model,
+ label=I18nObject(en_US=model),
+ model_type=ModelType.TEXT_EMBEDDING,
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+ model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+ )
+
+ return entity
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml
new file mode 100644
index 00000000000000..a06bb7639feacd
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml
@@ -0,0 +1,8 @@
+model: voyage-3-lite
+model_type: text-embedding
+model_properties:
+ context_size: 32000
+pricing:
+ input: '0.00002'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml
new file mode 100644
index 00000000000000..117afbcaf3c808
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml
@@ -0,0 +1,8 @@
+model: voyage-3
+model_type: text-embedding
+model_properties:
+ context_size: 32000
+pricing:
+ input: '0.00006'
+ unit: '0.001'
+ currency: USD
diff --git a/api/core/model_runtime/model_providers/voyage/voyage.py b/api/core/model_runtime/model_providers/voyage/voyage.py
new file mode 100644
index 00000000000000..3e33b45e110d56
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/voyage.py
@@ -0,0 +1,28 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VoyageProvider(ModelProvider):
+ def validate_provider_credentials(self, credentials: dict) -> None:
+ """
+ Validate provider credentials
+ if validate failed, raise exception
+
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+ """
+ try:
+ model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
+
+ # Use `voyage-3` model for validate,
+ # no matter what model you pass in, text completion model or chat model
+ model_instance.validate_credentials(model="voyage-3", credentials=credentials)
+ except CredentialsValidateFailedError as ex:
+ raise ex
+ except Exception as ex:
+ logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
+ raise ex
diff --git a/api/core/model_runtime/model_providers/voyage/voyage.yaml b/api/core/model_runtime/model_providers/voyage/voyage.yaml
new file mode 100644
index 00000000000000..c64707800eebe0
--- /dev/null
+++ b/api/core/model_runtime/model_providers/voyage/voyage.yaml
@@ -0,0 +1,31 @@
+provider: voyage
+label:
+ en_US: Voyage
+description:
+ en_US: Embedding and Rerank Model Supported
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#EFFDFD"
+help:
+ title:
+ en_US: Get your API key from Voyage AI
+ zh_Hans: 从 Voyage 获取 API Key
+ url:
+ en_US: https://dash.voyageai.com/
+supported_model_types:
+ - text-embedding
+ - rerank
+configurate_methods:
+ - predefined-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 64b35621b2ec60..e737761f3b2c0b 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -123,6 +123,7 @@ FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
+VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa"
[tool.poetry]
name = "dify-api"
diff --git a/api/tests/integration_tests/model_runtime/voyage/__init__.py b/api/tests/integration_tests/model_runtime/voyage/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/tests/integration_tests/model_runtime/voyage/test_provider.py b/api/tests/integration_tests/model_runtime/voyage/test_provider.py
new file mode 100644
index 00000000000000..08978c88a961e7
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/voyage/test_provider.py
@@ -0,0 +1,25 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
+
+
+def test_validate_provider_credentials():
+ provider = VoyageProvider()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
+ with patch("requests.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
+ "model": "voyage-3",
+ "usage": {"total_tokens": 1},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
diff --git a/api/tests/integration_tests/model_runtime/voyage/test_rerank.py b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py
new file mode 100644
index 00000000000000..e97a9e4c811c82
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py
@@ -0,0 +1,92 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
+
+
+def test_validate_credentials():
+ model = VoyageRerankModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(
+ model="rerank-lite-1",
+ credentials={"api_key": "invalid_key"},
+ )
+ with patch("httpx.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [
+ {
+ "relevance_score": 0.546875,
+ "index": 0,
+ "document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
+ "States Census, Carson City had a population of 55,274.",
+ },
+ {
+ "relevance_score": 0.4765625,
+ "index": 1,
+ "document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
+ "Pacific Ocean that are a political division controlled by the United States. Its "
+ "capital is Saipan.",
+ },
+ ],
+ "model": "rerank-lite-1",
+ "usage": {"total_tokens": 96},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ model.validate_credentials(
+ model="rerank-lite-1",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ )
+
+
+def test_invoke_model():
+ model = VoyageRerankModel()
+ with patch("httpx.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [
+ {
+ "relevance_score": 0.84375,
+ "index": 0,
+ "document": "Kasumi is a girl name of Japanese origin meaning mist.",
+ },
+ {
+ "relevance_score": 0.4765625,
+ "index": 1,
+ "document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
+ "leads a team named PopiParty.",
+ },
+ ],
+ "model": "rerank-lite-1",
+ "usage": {"total_tokens": 59},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ result = model.invoke(
+ model="rerank-lite-1",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ query="Who is Kasumi?",
+ docs=[
+ "Kasumi is a girl name of Japanese origin meaning mist.",
+ "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
+ "PopiParty.",
+ ],
+ score_threshold=0.5,
+ )
+
+ assert isinstance(result, RerankResult)
+ assert len(result.docs) == 1
+ assert result.docs[0].index == 0
+ assert result.docs[0].score >= 0.5
diff --git a/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py
new file mode 100644
index 00000000000000..75719672a9ecc9
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py
@@ -0,0 +1,70 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
+
+
+def test_validate_credentials():
+ model = VoyageTextEmbeddingModel()
+
+ with pytest.raises(CredentialsValidateFailedError):
+ model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
+ with patch("requests.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
+ "model": "voyage-3",
+ "usage": {"total_tokens": 1},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
+
+
+def test_invoke_model():
+ model = VoyageTextEmbeddingModel()
+
+ with patch("requests.post") as mock_post:
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "object": "list",
+ "data": [
+ {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
+ {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
+ ],
+ "model": "voyage-3",
+ "usage": {"total_tokens": 2},
+ }
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+ result = model.invoke(
+ model="voyage-3",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ texts=["hello", "world"],
+ user="abc-123",
+ )
+
+ assert isinstance(result, TextEmbeddingResult)
+ assert len(result.embeddings) == 2
+ assert result.usage.total_tokens == 2
+
+
+def test_get_num_tokens():
+ model = VoyageTextEmbeddingModel()
+
+ num_tokens = model.get_num_tokens(
+ model="voyage-3",
+ credentials={
+ "api_key": os.environ.get("VOYAGE_API_KEY"),
+ },
+ texts=["ping"],
+ )
+
+ assert num_tokens == 1
diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh
index b60ff64fdcd901..63891eb9f8d13f 100755
--- a/dev/pytest/pytest_model_runtime.sh
+++ b/dev/pytest/pytest_model_runtime.sh
@@ -9,4 +9,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
api/tests/integration_tests/model_runtime/upstage \
api/tests/integration_tests/model_runtime/fireworks \
api/tests/integration_tests/model_runtime/nomic \
- api/tests/integration_tests/model_runtime/mixedbread
+ api/tests/integration_tests/model_runtime/mixedbread \
+ api/tests/integration_tests/model_runtime/voyage
\ No newline at end of file