-
Notifications
You must be signed in to change notification settings - Fork 8.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add voyage ai as a new model provider (#8747)
- Loading branch information
Showing
20 changed files
with
598 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,3 +40,4 @@ | |
- fireworks | ||
- mixedbread | ||
- nomic | ||
- voyage |
Empty file.
21 changes: 21 additions & 0 deletions
21
api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions
8
api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
4 changes: 4 additions & 0 deletions
4
api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
model: rerank-1 | ||
model_type: rerank | ||
model_properties: | ||
context_size: 8000 |
4 changes: 4 additions & 0 deletions
4
api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
model: rerank-lite-1 | ||
model_type: rerank | ||
model_properties: | ||
context_size: 4000 |
123 changes: 123 additions & 0 deletions
123
api/core/model_runtime/model_providers/voyage/rerank/rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from typing import Optional | ||
|
||
import httpx | ||
|
||
from core.model_runtime.entities.common_entities import I18nObject | ||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType | ||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel | ||
|
||
|
||
class VoyageRerankModel(RerankModel): | ||
""" | ||
Model class for Voyage rerank model. | ||
""" | ||
|
||
def _invoke( | ||
self, | ||
model: str, | ||
credentials: dict, | ||
query: str, | ||
docs: list[str], | ||
score_threshold: Optional[float] = None, | ||
top_n: Optional[int] = None, | ||
user: Optional[str] = None, | ||
) -> RerankResult: | ||
""" | ||
Invoke rerank model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param query: search query | ||
:param docs: docs for reranking | ||
:param score_threshold: score threshold | ||
:param top_n: top n documents to return | ||
:param user: unique user id | ||
:return: rerank result | ||
""" | ||
if len(docs) == 0: | ||
return RerankResult(model=model, docs=[]) | ||
|
||
base_url = credentials.get("base_url", "https://api.voyageai.com/v1") | ||
base_url = base_url.removesuffix("/") | ||
|
||
try: | ||
response = httpx.post( | ||
base_url + "/rerank", | ||
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True}, | ||
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}, | ||
) | ||
response.raise_for_status() | ||
results = response.json() | ||
|
||
rerank_documents = [] | ||
for result in results["data"]: | ||
rerank_document = RerankDocument( | ||
index=result["index"], | ||
text=result["document"], | ||
score=result["relevance_score"], | ||
) | ||
if score_threshold is None or result["relevance_score"] >= score_threshold: | ||
rerank_documents.append(rerank_document) | ||
|
||
return RerankResult(model=model, docs=rerank_documents) | ||
except httpx.HTTPStatusError as e: | ||
raise InvokeServerUnavailableError(str(e)) | ||
|
||
def validate_credentials(self, model: str, credentials: dict) -> None: | ||
""" | ||
Validate model credentials | ||
:param model: model name | ||
:param credentials: model credentials | ||
:return: | ||
""" | ||
try: | ||
self._invoke( | ||
model=model, | ||
credentials=credentials, | ||
query="What is the capital of the United States?", | ||
docs=[ | ||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States " | ||
"Census, Carson City had a population of 55,274.", | ||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " | ||
"are a political division controlled by the United States. Its capital is Saipan.", | ||
], | ||
score_threshold=0.8, | ||
) | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) | ||
|
||
@property | ||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | ||
""" | ||
Map model invoke error to unified error | ||
""" | ||
return { | ||
InvokeConnectionError: [httpx.ConnectError], | ||
InvokeServerUnavailableError: [httpx.RemoteProtocolError], | ||
InvokeRateLimitError: [], | ||
InvokeAuthorizationError: [httpx.HTTPStatusError], | ||
InvokeBadRequestError: [httpx.RequestError], | ||
} | ||
|
||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: | ||
""" | ||
generate custom model entities from credentials | ||
""" | ||
entity = AIModelEntity( | ||
model=model, | ||
label=I18nObject(en_US=model), | ||
model_type=ModelType.RERANK, | ||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))}, | ||
) | ||
|
||
return entity |
Empty file.
172 changes: 172 additions & 0 deletions
172
api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import time | ||
from json import JSONDecodeError, dumps | ||
from typing import Optional | ||
|
||
import requests | ||
|
||
from core.embedding.embedding_constant import EmbeddingInputType | ||
from core.model_runtime.entities.common_entities import I18nObject | ||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType | ||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | ||
|
||
|
||
class VoyageTextEmbeddingModel(TextEmbeddingModel): | ||
""" | ||
Model class for Voyage text embedding model. | ||
""" | ||
|
||
api_base: str = "https://api.voyageai.com/v1" | ||
|
||
def _invoke( | ||
self, | ||
model: str, | ||
credentials: dict, | ||
texts: list[str], | ||
user: Optional[str] = None, | ||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, | ||
) -> TextEmbeddingResult: | ||
""" | ||
Invoke text embedding model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param texts: texts to embed | ||
:param user: unique user id | ||
:param input_type: input type | ||
:return: embeddings result | ||
""" | ||
api_key = credentials["api_key"] | ||
if not api_key: | ||
raise CredentialsValidateFailedError("api_key is required") | ||
|
||
base_url = credentials.get("base_url", self.api_base) | ||
base_url = base_url.removesuffix("/") | ||
|
||
url = base_url + "/embeddings" | ||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} | ||
voyage_input_type = "null" | ||
if input_type is not None: | ||
voyage_input_type = input_type.value | ||
data = {"model": model, "input": texts, "input_type": voyage_input_type} | ||
|
||
try: | ||
response = requests.post(url, headers=headers, data=dumps(data)) | ||
except Exception as e: | ||
raise InvokeConnectionError(str(e)) | ||
|
||
if response.status_code != 200: | ||
try: | ||
resp = response.json() | ||
msg = resp["detail"] | ||
if response.status_code == 401: | ||
raise InvokeAuthorizationError(msg) | ||
elif response.status_code == 429: | ||
raise InvokeRateLimitError(msg) | ||
elif response.status_code == 500: | ||
raise InvokeServerUnavailableError(msg) | ||
else: | ||
raise InvokeBadRequestError(msg) | ||
except JSONDecodeError as e: | ||
raise InvokeServerUnavailableError( | ||
f"Failed to convert response to json: {e} with text: {response.text}" | ||
) | ||
|
||
try: | ||
resp = response.json() | ||
embeddings = resp["data"] | ||
usage = resp["usage"] | ||
except Exception as e: | ||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") | ||
|
||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) | ||
|
||
result = TextEmbeddingResult( | ||
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage | ||
) | ||
|
||
return result | ||
|
||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: | ||
""" | ||
Get number of tokens for given prompt messages | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param texts: texts to embed | ||
:return: | ||
""" | ||
return sum(self._get_num_tokens_by_gpt2(text) for text in texts) | ||
|
||
def validate_credentials(self, model: str, credentials: dict) -> None: | ||
""" | ||
Validate model credentials | ||
:param model: model name | ||
:param credentials: model credentials | ||
:return: | ||
""" | ||
try: | ||
self._invoke(model=model, credentials=credentials, texts=["ping"]) | ||
except Exception as e: | ||
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") | ||
|
||
@property | ||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | ||
return { | ||
InvokeConnectionError: [InvokeConnectionError], | ||
InvokeServerUnavailableError: [InvokeServerUnavailableError], | ||
InvokeRateLimitError: [InvokeRateLimitError], | ||
InvokeAuthorizationError: [InvokeAuthorizationError], | ||
InvokeBadRequestError: [KeyError, InvokeBadRequestError], | ||
} | ||
|
||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: | ||
""" | ||
Calculate response usage | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param tokens: input tokens | ||
:return: usage | ||
""" | ||
# get input price info | ||
input_price_info = self.get_price( | ||
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens | ||
) | ||
|
||
# transform usage | ||
usage = EmbeddingUsage( | ||
tokens=tokens, | ||
total_tokens=tokens, | ||
unit_price=input_price_info.unit_price, | ||
price_unit=input_price_info.unit, | ||
total_price=input_price_info.total_amount, | ||
currency=input_price_info.currency, | ||
latency=time.perf_counter() - self.started_at, | ||
) | ||
|
||
return usage | ||
|
||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: | ||
""" | ||
generate custom model entities from credentials | ||
""" | ||
entity = AIModelEntity( | ||
model=model, | ||
label=I18nObject(en_US=model), | ||
model_type=ModelType.TEXT_EMBEDDING, | ||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, | ||
) | ||
|
||
return entity |
8 changes: 8 additions & 0 deletions
8
api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
model: voyage-3-lite | ||
model_type: text-embedding | ||
model_properties: | ||
context_size: 32000 | ||
pricing: | ||
input: '0.00002' | ||
unit: '0.001' | ||
currency: USD |
8 changes: 8 additions & 0 deletions
8
api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
model: voyage-3 | ||
model_type: text-embedding | ||
model_properties: | ||
context_size: 32000 | ||
pricing: | ||
input: '0.00006' | ||
unit: '0.001' | ||
currency: USD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import logging | ||
|
||
from core.model_runtime.entities.model_entities import ModelType | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VoyageProvider(ModelProvider): | ||
def validate_provider_credentials(self, credentials: dict) -> None: | ||
""" | ||
Validate provider credentials | ||
if validate failed, raise exception | ||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`. | ||
""" | ||
try: | ||
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) | ||
|
||
# Use `voyage-3` model for validate, | ||
# no matter what model you pass in, text completion model or chat model | ||
model_instance.validate_credentials(model="voyage-3", credentials=credentials) | ||
except CredentialsValidateFailedError as ex: | ||
raise ex | ||
except Exception as ex: | ||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") | ||
raise ex |
Oops, something went wrong.