From 5ca18f2cd1ab44f05d14c9902f75312fc121fc4d Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Sun, 29 Sep 2024 16:55:59 +0800 Subject: [PATCH] feat: add voyage ai as a new model provider (#8747) --- .../model_providers/_position.yaml | 1 + .../model_providers/voyage/__init__.py | 0 .../voyage/_assets/icon_l_en.svg | 21 +++ .../voyage/_assets/icon_s_en.svg | 8 + .../model_providers/voyage/rerank/__init__.py | 0 .../voyage/rerank/rerank-1.yaml | 4 + .../voyage/rerank/rerank-lite-1.yaml | 4 + .../model_providers/voyage/rerank/rerank.py | 123 +++++++++++++ .../voyage/text_embedding/__init__.py | 0 .../voyage/text_embedding/text_embedding.py | 172 ++++++++++++++++++ .../voyage/text_embedding/voyage-3-lite.yaml | 8 + .../voyage/text_embedding/voyage-3.yaml | 8 + .../model_providers/voyage/voyage.py | 28 +++ .../model_providers/voyage/voyage.yaml | 31 ++++ api/pyproject.toml | 1 + .../model_runtime/voyage/__init__.py | 0 .../model_runtime/voyage/test_provider.py | 25 +++ .../model_runtime/voyage/test_rerank.py | 92 ++++++++++ .../voyage/test_text_embedding.py | 70 +++++++ dev/pytest/pytest_model_runtime.sh | 3 +- 20 files changed, 598 insertions(+), 1 deletion(-) create mode 100644 api/core/model_runtime/model_providers/voyage/__init__.py create mode 100644 api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg create mode 100644 api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml create mode 100644 api/core/model_runtime/model_providers/voyage/rerank/rerank.py create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml create mode 100644 api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml create mode 100644 api/core/model_runtime/model_providers/voyage/voyage.py create mode 100644 api/core/model_runtime/model_providers/voyage/voyage.yaml create mode 100644 api/tests/integration_tests/model_runtime/voyage/__init__.py create mode 100644 api/tests/integration_tests/model_runtime/voyage/test_provider.py create mode 100644 api/tests/integration_tests/model_runtime/voyage/test_rerank.py create mode 100644 api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 80db22ea84fe63..89fccef6598fdd 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -40,3 +40,4 @@ - fireworks - mixedbread - nomic +- voyage diff --git a/api/core/model_runtime/model_providers/voyage/__init__.py b/api/core/model_runtime/model_providers/voyage/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..a961f5e4355eea --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg @@ -0,0 +1,21 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..2c4e121dd71f0b --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg @@ -0,0 +1,8 @@ + + + voyage + + + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/voyage/rerank/__init__.py b/api/core/model_runtime/model_providers/voyage/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml new file mode 100644 index 00000000000000..9c894eda85203b --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml @@ -0,0 +1,4 @@ +model: rerank-1 +model_type: rerank +model_properties: + context_size: 8000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml new file mode 100644 index 00000000000000..b052d6f00028cb --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml @@ -0,0 +1,4 @@ +model: rerank-lite-1 +model_type: rerank +model_properties: + context_size: 4000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank.py b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py new file mode 100644 index 00000000000000..33fdebbb45ef36 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py @@ -0,0 +1,123 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class VoyageRerankModel(RerankModel): + """ + Model class for Voyage rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://api.voyageai.com/v1") + base_url = base_url.removesuffix("/") + + try: + response = httpx.post( + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}, + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["data"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["document"], + score=result["relevance_score"], + ) + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py b/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..a8a4d3c15bbb13 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -0,0 +1,172 @@ +import time +from json import JSONDecodeError, dumps +from typing import Optional + +import requests + +from core.embedding.embedding_constant import EmbeddingInputType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + + +class VoyageTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Voyage text embedding model. + """ + + api_base: str = "https://api.voyageai.com/v1" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + api_key = credentials["api_key"] + if not api_key: + raise CredentialsValidateFailedError("api_key is required") + + base_url = credentials.get("base_url", self.api_base) + base_url = base_url.removesuffix("/") + + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} + voyage_input_type = "null" + if input_type is not None: + voyage_input_type = input_type.value + data = {"model": model, "input": texts, "input_type": voyage_input_type} + + try: + response = requests.post(url, headers=headers, data=dumps(data)) + except Exception as e: + raise InvokeConnectionError(str(e)) + + if response.status_code != 200: + try: + resp = response.json() + msg = resp["detail"] + if response.status_code == 401: + raise InvokeAuthorizationError(msg) + elif response.status_code == 429: + raise InvokeRateLimitError(msg) + elif response.status_code == 500: + raise InvokeServerUnavailableError(msg) + else: + raise InvokeBadRequestError(msg) + except JSONDecodeError as e: + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) + + try: + resp = response.json() + embeddings = resp["data"] + usage = resp["usage"] + except Exception as e: + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) + + result = TextEmbeddingResult( + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as e: + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml new file mode 100644 index 00000000000000..a06bb7639feacd --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml @@ -0,0 +1,8 @@ +model: voyage-3-lite +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00002' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml new file mode 100644 index 00000000000000..117afbcaf3c808 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml @@ -0,0 +1,8 @@ +model: voyage-3 +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/voyage.py b/api/core/model_runtime/model_providers/voyage/voyage.py new file mode 100644 index 00000000000000..3e33b45e110d56 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/voyage.py @@ -0,0 +1,28 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VoyageProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) + + # Use `voyage-3` model for validate, + # no matter what model you pass in, text completion model or chat model + model_instance.validate_credentials(model="voyage-3", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/voyage/voyage.yaml b/api/core/model_runtime/model_providers/voyage/voyage.yaml new file mode 100644 index 00000000000000..c64707800eebe0 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/voyage.yaml @@ -0,0 +1,31 @@ +provider: voyage +label: + en_US: Voyage +description: + en_US: Embedding and Rerank Model Supported +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#EFFDFD" +help: + title: + en_US: Get your API key from Voyage AI + zh_Hans: 从 Voyage 获取 API Key + url: + en_US: https://dash.voyageai.com/ +supported_model_types: + - text-embedding + - rerank +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/pyproject.toml b/api/pyproject.toml index 64b35621b2ec60..e737761f3b2c0b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -123,6 +123,7 @@ FIRECRAWL_API_KEY = "fc-" TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451" TEI_RERANK_SERVER_URL = "http://a.abc.com:11451" MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa" +VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa" [tool.poetry] name = "dify-api" diff --git a/api/tests/integration_tests/model_runtime/voyage/__init__.py b/api/tests/integration_tests/model_runtime/voyage/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/voyage/test_provider.py b/api/tests/integration_tests/model_runtime/voyage/test_provider.py new file mode 100644 index 00000000000000..08978c88a961e7 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_provider.py @@ -0,0 +1,25 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.voyage import VoyageProvider + + +def test_validate_provider_credentials(): + provider = VoyageProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={"api_key": "hahahaha"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}], + "model": "voyage-3", + "usage": {"total_tokens": 1}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/voyage/test_rerank.py b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py new file mode 100644 index 00000000000000..e97a9e4c811c82 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_rerank.py @@ -0,0 +1,92 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel + + +def test_validate_credentials(): + model = VoyageRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="rerank-lite-1", + credentials={"api_key": "invalid_key"}, + ) + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + { + "relevance_score": 0.546875, + "index": 0, + "document": "Carson City is the capital city of the American state of Nevada. At the 2010 United " + "States Census, Carson City had a population of 55,274.", + }, + { + "relevance_score": 0.4765625, + "index": 1, + "document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the " + "Pacific Ocean that are a political division controlled by the United States. Its " + "capital is Saipan.", + }, + ], + "model": "rerank-lite-1", + "usage": {"total_tokens": 96}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials( + model="rerank-lite-1", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = VoyageRerankModel() + with patch("httpx.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + { + "relevance_score": 0.84375, + "index": 0, + "document": "Kasumi is a girl name of Japanese origin meaning mist.", + }, + { + "relevance_score": 0.4765625, + "index": 1, + "document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she " + "leads a team named PopiParty.", + }, + ], + "model": "rerank-lite-1", + "usage": {"total_tokens": 59}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="rerank-lite-1", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + "Kasumi is a girl name of Japanese origin meaning mist.", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named " + "PopiParty.", + ], + score_threshold=0.5, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.5 diff --git a/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py new file mode 100644 index 00000000000000..75719672a9ecc9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/voyage/test_text_embedding.py @@ -0,0 +1,70 @@ +import os +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel + + +def test_validate_credentials(): + model = VoyageTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"}) + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}], + "model": "voyage-3", + "usage": {"total_tokens": 1}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")}) + + +def test_invoke_model(): + model = VoyageTextEmbeddingModel() + + with patch("requests.post") as mock_post: + mock_response = Mock() + mock_response.json.return_value = { + "object": "list", + "data": [ + {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}, + {"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1}, + ], + "model": "voyage-3", + "usage": {"total_tokens": 2}, + } + mock_response.status_code = 200 + mock_post.return_value = mock_response + result = model.invoke( + model="voyage-3", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 2 + + +def test_get_num_tokens(): + model = VoyageTextEmbeddingModel() + + num_tokens = model.get_num_tokens( + model="voyage-3", + credentials={ + "api_key": os.environ.get("VOYAGE_API_KEY"), + }, + texts=["ping"], + ) + + assert num_tokens == 1 diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh index b60ff64fdcd901..63891eb9f8d13f 100755 --- a/dev/pytest/pytest_model_runtime.sh +++ b/dev/pytest/pytest_model_runtime.sh @@ -9,4 +9,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \ api/tests/integration_tests/model_runtime/upstage \ api/tests/integration_tests/model_runtime/fireworks \ api/tests/integration_tests/model_runtime/nomic \ - api/tests/integration_tests/model_runtime/mixedbread + api/tests/integration_tests/model_runtime/mixedbread \ + api/tests/integration_tests/model_runtime/voyage \ No newline at end of file