diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/llama_index/postprocessor/voyageai_rerank/base.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/llama_index/postprocessor/voyageai_rerank/base.py index 7a6cf670b39f0..5ef168e7af68b 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/llama_index/postprocessor/voyageai_rerank/base.py +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/llama_index/postprocessor/voyageai_rerank/base.py @@ -15,11 +15,13 @@ class VoyageAIRerank(BaseNodePostprocessor): model: str = Field(description="Name of the model to use.") - top_n: int = Field( - description="The number of most relevant documents to return. If not specified, the reranking results of all documents will be returned." + top_n: Optional[int] = Field( + description="The number of most relevant documents to return. If not specified, the reranking results of all documents will be returned.", + default=None, ) truncation: bool = Field( - description="Whether to truncate the input to satisfy the 'context length limit' on the query and the documents." 
+ description="Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.", + default=True, ) _client: Any = PrivateAttr() @@ -29,7 +31,7 @@ def __init__( model: str, api_key: Optional[str] = None, top_n: Optional[int] = None, - truncation: Optional[bool] = None, + truncation: bool = True, # deprecated top_k: Optional[int] = None, ): @@ -55,7 +57,10 @@ def _postprocess_nodes( ) -> List[NodeWithScore]: dispatcher.event( ReRankStartEvent( - query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model + query=query_bundle, + nodes=nodes, + top_n=self.top_n or len(nodes), + model_name=self.model, ) ) @@ -70,7 +75,7 @@ def _postprocess_nodes( EventPayload.NODES: nodes, EventPayload.MODEL_NAME: self.model, EventPayload.QUERY_STR: query_bundle.query_str, - EventPayload.TOP_K: self.top_n, + EventPayload.TOP_K: self.top_n or len(nodes), }, ) as event: texts = [ diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/pyproject.toml b/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/pyproject.toml index d15dd97efe8f9..7d48f7a7de7ae 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/pyproject.toml +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/pyproject.toml @@ -30,7 +30,7 @@ license = "MIT" name = "llama-index-postprocessor-voyageai-rerank" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.3.1" +version = "0.3.2" [tool.poetry.dependencies] python = ">=3.9,<4.0" diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/tests/test_postprocessor_voyageai-rerank.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/tests/test_postprocessor_voyageai-rerank.py index 48e221a6bf5dc..e339f17de5eeb 100644 --- 
a/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/tests/test_postprocessor_voyageai-rerank.py +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-voyageai-rerank/tests/test_postprocessor_voyageai-rerank.py @@ -1,10 +1,11 @@ -from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode +import pytest +from pytest_mock import MockerFixture from voyageai.api_resources import VoyageResponse +from voyageai.object.reranking import RerankingObject +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from llama_index.postprocessor.voyageai_rerank import VoyageAIRerank -from voyageai.object.reranking import RerankingObject -from pytest_mock import MockerFixture rerank_sample_response = { "object": "list", @@ -22,7 +23,11 @@ def test_class(): assert BaseNodePostprocessor.__name__ in names_of_base_classes -def test_rerank(mocker: MockerFixture) -> None: +@pytest.mark.parametrize( + "constructor_kwargs", + [{"top_n": 2, "truncation": True}, {"top_n": None}], +) +def test_rerank(mocker: MockerFixture, constructor_kwargs: dict) -> None: # Mocked client with the desired behavior for embed_documents result_object = RerankingObject( documents=["0", "1"], @@ -39,7 +44,7 @@ def test_rerank(mocker: MockerFixture) -> None: ) voyageai_rerank = VoyageAIRerank( - api_key="api_key", top_n=2, model="rerank-lite-1", truncation=True + api_key="api_key", model="rerank-lite-1", **constructor_kwargs ) result = voyageai_rerank.postprocess_nodes( nodes=[ @@ -51,3 +56,20 @@ def test_rerank(mocker: MockerFixture) -> None: assert len(result) == 2 assert result[0].text == "text2" assert result[1].text == "text1" + + +def test_rerank_construction_with_no_optional_kwargs(monkeypatch): + monkeypatch.setenv("VOYAGE_API_KEY", "mock_api_key") + reranker = VoyageAIRerank(model="rerank-2")
+ assert reranker.truncation + assert reranker.top_n is None + assert reranker.model == "rerank-2" + + +def test_rerank_construction_with_optional_kwargs(): + reranker = VoyageAIRerank( + model="rerank-2", api_key="mock_api_key", top_n=10, truncation=False + ) + assert not reranker.truncation + assert reranker.top_n == 10 + assert reranker.model == "rerank-2"