From d5d0e842e6ed1efbd622e82be56d2561b00b666b Mon Sep 17 00:00:00 2001 From: Adversarian Date: Wed, 18 Dec 2024 20:39:23 +0330 Subject: [PATCH 1/3] Expose output dims and output dtype as kwargs. --- .../llama_index/embeddings/voyageai/base.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py index 860ad1994e84f..065672d84e225 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py @@ -1,19 +1,19 @@ """Voyage embeddings file.""" + import logging import os +from io import BytesIO +from pathlib import Path from typing import Any, List, Optional, Union +import voyageai +from PIL import Image + from llama_index.core.base.embeddings.base import Embedding from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.callbacks.base import CallbackManager - -import voyageai from llama_index.core.embeddings import MultiModalEmbedding -from io import BytesIO -from pathlib import Path from llama_index.core.schema import ImageType -from PIL import Image - logger = logging.getLogger(__name__) @@ -36,6 +36,8 @@ class VoyageEmbedding(MultiModalEmbedding): _client: voyageai.Client = PrivateAttr(None) _aclient: voyageai.client_async.AsyncClient = PrivateAttr() truncation: Optional[bool] = None + output_dtype: Optional[str] = None + output_dimension: Optional[int] = None def __init__( self, @@ -43,6 +45,8 @@ def __init__( voyage_api_key: Optional[str] = None, embed_batch_size: Optional[int] = None, truncation: Optional[bool] = None, + output_dtype: Optional[str] = None, + output_dimension: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ): @@ -73,6 +77,8 @@ def __init__( self._client = voyageai.Client(api_key=voyage_api_key) self._aclient = voyageai.AsyncClient(api_key=voyage_api_key) self.truncation = truncation + self.output_dtype = output_dtype + self.output_dimension = output_dimension @classmethod def class_name(cls) -> str: @@ -161,6 +167,8 @@ def _embed(self, texts: List[str], input_type: str) -> List[List[float]]: model=self.model_name, input_type=input_type, truncation=self.truncation, + output_dtype=self.output_dtype, + output_dimension=self.output_dimension, ).embeddings async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]: @@ -177,6 +185,8 @@ async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]: model=self.model_name, input_type=input_type, truncation=self.truncation, + output_dtype=self.output_dtype, + output_dimension=self.output_dimension, ) return r.embeddings From afd7db9950c0ee51b6401c27cd122febd11ccd20 Mon Sep 17 00:00:00 2001 From: Adversarian Date: Wed, 18 Dec 2024 20:39:36 +0330 Subject: [PATCH 2/3] Add unit tests for new kwargs. --- .../tests/test_embeddings_voyageai.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py index 678016600f37f..2fdb09f5cf4d6 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py @@ -17,6 +17,8 @@ def test_embedding_class_voyage_2(): assert emb.embed_batch_size == 72 assert emb.model_name == "voyage-2" assert emb.truncation + assert emb.output_dimension is None + assert emb.output_dtype is None def test_embedding_class_voyage_2_with_batch_size(): @@ -27,6 +29,36 @@ def test_embedding_class_voyage_2_with_batch_size(): assert emb.embed_batch_size == 49 assert emb.model_name == "voyage-2" assert emb.truncation is None + assert emb.output_dimension is None + assert emb.output_dtype is None + + +def test_embedding_class_voyage_3_large_with_output_dimension(): + emb = VoyageEmbedding( + model_name="voyage-3-large", + voyage_api_key="NOT_A_VALID_KEY", + output_dimension=512, + ) + assert isinstance(emb, BaseEmbedding) + assert emb.embed_batch_size == 7 + assert emb.model_name == "voyage-3-large" + assert emb.truncation is None + assert emb.output_dimension == 512 + assert emb.output_dtype is None + + +def test_embedding_class_voyage_3_large_with_output_dtype(): + emb = VoyageEmbedding( + model_name="voyage-3-large", + voyage_api_key="NOT_A_VALID_KEY", + output_dtype="float", + ) + assert isinstance(emb, BaseEmbedding) + assert emb.embed_batch_size == 7 + assert emb.model_name == "voyage-3-large" + assert emb.truncation is None + assert emb.output_dimension is None + assert emb.output_dtype == "float" def test_voyageai_embedding_class(): From be9332eb9aaca60c89c32ff2fd04ab491bdf3ec1 Mon Sep 17 00:00:00 2001 From: Adversarian Date: Wed, 18 Dec 2024 20:42:05 +0330 Subject: [PATCH 3/3] Version bump. --- .../embeddings/llama-index-embeddings-voyageai/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/pyproject.toml index 09b4b3cb06d8b..6f29987d6b85f 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-voyageai" readme = "README.md" -version = "0.3.3" +version = "0.3.4" [tool.poetry.dependencies] python = ">=3.9,<4.0"