diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index cc5d64d3bf..88fdc6aaaf 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Literal, Optional +from enum import Enum +from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport @@ -16,24 +17,91 @@ from sentence_transformers import SentenceTransformer +class DiversityRankingStrategy(Enum): + """ + The strategy to use for diversity ranking. + """ + + GREEDY_DIVERSITY_ORDER = "greedy_diversity_order" + MAXIMUM_MARGIN_RELEVANCE = "maximum_margin_relevance" + + def __str__(self) -> str: + """ + Convert a Strategy enum to a string. + """ + return self.value + + @staticmethod + def from_str(string: str) -> "DiversityRankingStrategy": + """ + Convert a string to a Strategy enum. + """ + enum_map = {e.value: e for e in DiversityRankingStrategy} + strategy = enum_map.get(string) + if strategy is None: + msg = f"Unknown strategy '{string}'. Supported strategies are: {list(enum_map.keys())}" + raise ValueError(msg) + return strategy + + +class DiversityRankingSimilarity(Enum): + """ + The similarity metric to use for comparing embeddings. + """ + + DOT_PRODUCT = "dot_product" + COSINE = "cosine" + + def __str__(self) -> str: + """ + Convert a Similarity enum to a string. + """ + return self.value + + @staticmethod + def from_str(string: str) -> "DiversityRankingSimilarity": + """ + Convert a string to a Similarity enum. + """ + enum_map = {e.value: e for e in DiversityRankingSimilarity} + similarity = enum_map.get(string) + if similarity is None: + msg = f"Unknown similarity metric '{string}'. Supported metrics are: {list(enum_map.keys())}" + raise ValueError(msg) + return similarity + + @component class SentenceTransformersDiversityRanker: """ A Diversity Ranker based on Sentence Transformers. - Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity - of the documents. + Applies a document ranking algorithm based on one of the two strategies: + + 1. Greedy Diversity Order: + + Implements a document ranking algorithm that orders documents in a way that maximizes the overall diversity + of the documents based on their similarity to the query. - This component provides functionality to rank a list of documents based on their similarity with respect to the - query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and - the Documents. + It uses a pre-trained Sentence Transformers model to embed the query and + the documents. - Usage example: + 2. Maximum Margin Relevance: + + Implements a document ranking algorithm that orders documents based on their Maximum Margin Relevance (MMR) + scores. + + MMR scores are calculated for each document based on their relevance to the query and diversity from already + selected documents. The algorithm iteratively selects documents based on their MMR scores, balancing between + relevance to the query and diversity from already selected documents. The 'lambda_threshold' controls the + trade-off between relevance and diversity. 
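+    A 'lambda_threshold' closer to 0 favors diversity, while a value closer to 1 favors relevance to the query.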
+ + ### Usage example ```python from haystack import Document from haystack.components.rankers import SentenceTransformersDiversityRanker - ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine") + ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="greedy_diversity_order") ranker.warm_up() docs = [Document(content="Paris"), Document(content="Berlin")] @@ -41,7 +109,7 @@ class SentenceTransformersDiversityRanker: output = ranker.run(query=query, documents=docs) docs = output["documents"] ``` - """ + """ # noqa: E501 def __init__( self, @@ -49,14 +117,16 @@ def __init__( top_k: int = 10, device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), - similarity: Literal["dot_product", "cosine"] = "cosine", + similarity: Union[str, DiversityRankingSimilarity] = "cosine", query_prefix: str = "", query_suffix: str = "", document_prefix: str = "", document_suffix: str = "", meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", - ): + strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order", + lambda_threshold: float = 0.5, + ): # pylint: disable=too-many-positional-arguments """ Initialize a SentenceTransformersDiversityRanker. @@ -78,6 +148,10 @@ def __init__( :param document_suffix: A string to add to the end of each Document text before ranking. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + :param strategy: The strategy to use for diversity ranking. Can be either "greedy_diversity_order" or + "maximum_margin_relevance". + :param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is + "maximum_margin_relevance". 
""" torch_and_sentence_transformers_import.check() @@ -88,15 +162,16 @@ def __init__( self.device = ComponentDevice.resolve_device(device) self.token = token self.model = None - if similarity not in ["dot_product", "cosine"]: - raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.") - self.similarity = similarity + self.similarity = DiversityRankingSimilarity.from_str(similarity) if isinstance(similarity, str) else similarity self.query_prefix = query_prefix self.document_prefix = document_prefix self.query_suffix = query_suffix self.document_suffix = document_suffix self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy + self.lambda_threshold = lambda_threshold or 0.5 + self._check_lambda_threshold(self.lambda_threshold, self.strategy) def warm_up(self): """ @@ -119,16 +194,18 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model_name_or_path, + top_k=self.top_k, device=self.device.to_dict(), token=self.token.to_dict() if self.token else None, - top_k=self.top_k, - similarity=self.similarity, + similarity=str(self.similarity), query_prefix=self.query_prefix, - document_prefix=self.document_prefix, query_suffix=self.query_suffix, + document_prefix=self.document_prefix, document_suffix=self.document_suffix, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + strategy=str(self.strategy), + lambda_threshold=self.lambda_threshold, ) @classmethod @@ -182,14 +259,7 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List """ texts_to_embed = self._prepare_texts_to_embed(documents) - # Calculate embeddings - doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined] - query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined] - - # Normalize embeddings to unit length for computing cosine similarity - if self.similarity == "cosine": - doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1) - query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1) + doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed) n = len(documents) selected: List[int] = [] @@ -218,14 +288,84 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List return ranked_docs + def _embed_and_normalize(self, query, texts_to_embed): + # Calculate embeddings + doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined] + query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined] + + # Normalize embeddings to unit length for computing cosine similarity + if self.similarity == DiversityRankingSimilarity.COSINE: + doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1) + query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1) + return doc_embeddings, query_embedding + + def _maximum_margin_relevance( + self, query: str, documents: List[Document], lambda_threshold: float, top_k: int + ) -> List[Document]: + """ + Orders the given list of documents according to the Maximum Margin Relevance (MMR) scores. 
+ + MMR scores are calculated for each document based on their relevance to the query and diversity from already + selected documents. + + The algorithm iteratively selects documents based on their MMR scores, balancing between relevance to the query + and diversity from already selected documents. The 'lambda_threshold' controls the trade-off between relevance + and diversity. + + A closer value to 0 favors diversity, while a closer value to 1 favors relevance to the query. + + See : "The Use of MMR, Diversity-Based Reranking for Reordering Documents and Producing Summaries" + https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf + """ + + texts_to_embed = self._prepare_texts_to_embed(documents) + doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed) + top_k = top_k if top_k else len(documents) + + selected: List[int] = [] + query_similarities_as_tensor = query_embedding @ doc_embeddings.T + query_similarities = query_similarities_as_tensor.reshape(-1) + idx = int(torch.argmax(query_similarities)) + selected.append(idx) + while len(selected) < top_k: + best_idx = None + best_score = -float("inf") + for idx, _ in enumerate(documents): + if idx in selected: + continue + relevance_score = query_similarities[idx] + diversity_score = max(doc_embeddings[idx] @ doc_embeddings[j].T for j in selected) + mmr_score = lambda_threshold * relevance_score - (1 - lambda_threshold) * diversity_score + if mmr_score > best_score: + best_score = mmr_score + best_idx = idx + if best_idx is None: + raise ValueError("No best document found, check if the documents list contains any documents.") + selected.append(best_idx) + + return [documents[i] for i in selected] + + @staticmethod + def _check_lambda_threshold(lambda_threshold: float, strategy: DiversityRankingStrategy): + if (strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE) and not 0 <= lambda_threshold <= 1: + raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.") + @component.output_types(documents=List[Document]) - def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + def run( + self, + query: str, + documents: List[Document], + top_k: Optional[int] = None, + lambda_threshold: Optional[float] = None, + ) -> Dict[str, List[Document]]: """ Rank the documents based on their diversity. :param query: The search query. :param documents: List of Document objects to be ranker. :param top_k: Optional. An integer to override the top_k set during initialization. + :param lambda_threshold: Override the trade-off parameter between relevance and diversity. Only used when + strategy is "maximum_margin_relevance". :returns: A dictionary with the following key: - `documents`: List of Document objects that have been selected based on the diversity ranking. 
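
To make the selection loop added above easier to follow, here is a minimal standalone sketch of the same MMR scoring on toy unit-normalized vectors. The helper name `mmr_select` and the toy embeddings are illustrative assumptions, not part of the patch:

```python
import torch


def mmr_select(query_emb, doc_embs, lambda_threshold=0.5, top_k=None):
    """Illustrative MMR selection over pre-normalized embeddings (not the component's code)."""
    top_k = top_k or doc_embs.shape[0]
    query_sims = (query_emb @ doc_embs.T).reshape(-1)
    # Seed with the single most relevant document, as the patched method does.
    selected = [int(torch.argmax(query_sims))]
    while len(selected) < top_k:
        best_idx, best_score = None, -float("inf")
        for idx in range(doc_embs.shape[0]):
            if idx in selected:
                continue
            relevance = query_sims[idx]
            diversity = max(doc_embs[idx] @ doc_embs[j] for j in selected)
            # MMR score: lambda * relevance - (1 - lambda) * similarity to the closest already-selected doc.
            score = lambda_threshold * relevance - (1 - lambda_threshold) * diversity
            if score > best_score:
                best_idx, best_score = idx, score
        selected.append(best_idx)
    return selected


# Toy unit-normalized embeddings: docs 0 and 1 are near-duplicates, doc 2 is unrelated.
docs = torch.nn.functional.normalize(torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]), dim=-1)
query = torch.nn.functional.normalize(torch.tensor([[1.0, 0.2]]), dim=-1)
print(mmr_select(query, docs, lambda_threshold=0.0))  # favours diversity after the first pick: [1, 2, 0]
print(mmr_select(query, docs, lambda_threshold=1.0))  # pure relevance ordering: [1, 0, 2]
```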
@@ -245,9 +385,17 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None if top_k is None: top_k = self.top_k - elif top_k <= 0: - raise ValueError(f"top_k must be > 0, but got {top_k}") - - diversity_sorted = self._greedy_diversity_order(query=query, documents=documents) + elif not 0 < top_k <= len(documents): + raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}") + + if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE: + if lambda_threshold is None: + lambda_threshold = self.lambda_threshold + self._check_lambda_threshold(lambda_threshold, self.strategy) + re_ranked_docs = self._maximum_margin_relevance( + query=query, documents=documents, lambda_threshold=lambda_threshold, top_k=top_k + ) + else: + re_ranked_docs = self._greedy_diversity_order(query=query, documents=documents) - return {"documents": diversity_sorted[:top_k]} + return {"documents": re_ranked_docs[:top_k]} diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py index d380768d3d..06608a4b59 100644 --- a/haystack/components/rankers/transformers_similarity.py +++ b/haystack/components/rankers/transformers_similarity.py @@ -43,7 +43,7 @@ class TransformersSimilarityRanker: ``` """ - def __init__( # noqa: PLR0913 + def __init__( # noqa: PLR0913, pylint: disable=too-many-positional-arguments self, model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: Optional[ComponentDevice] = None, @@ -201,7 +201,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TransformersSimilarityRanker": return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run( + def run( # pylint: disable=too-many-positional-arguments self, query: str, documents: List[Document], diff --git a/releasenotes/notes/add-maximum-margin-relevance-ranker-9d6d71c6a408c6d1.yaml b/releasenotes/notes/add-maximum-margin-relevance-ranker-9d6d71c6a408c6d1.yaml new file mode 100644 index 0000000000..51501a17b8 --- /dev/null +++ b/releasenotes/notes/add-maximum-margin-relevance-ranker-9d6d71c6a408c6d1.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Added the Maximum Margin Relevance (MMR) strategy to the `SentenceTransformersDiversityRanker`. MMR scores are calculated for each document based on their relevance to the query and diversity from already selected documents. 
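
For reference, a short usage sketch of the new strategy through the public API, using only the parameters introduced in this patch (`strategy`, `lambda_threshold`) together with the existing ones; the document contents and threshold values are illustrative, and the model name is the one used in the component's own docstring example:

```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    similarity="cosine",
    strategy="maximum_margin_relevance",  # new strategy added in this patch
    lambda_threshold=0.7,                 # lean towards relevance
)
ranker.warm_up()

docs = [
    Document(content="Solar power generation"),
    Document(content="Wind turbine technology"),
    Document(content="Baking sourdough bread"),
]
result = ranker.run(query="renewable energy sources", documents=docs, top_k=2)
print([d.content for d in result["documents"]])

# The trade-off can also be overridden per call via the new run() parameter:
result = ranker.run(query="renewable energy sources", documents=docs, lambda_threshold=0.0)
```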
diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index ba3b10ae5c..eabd2ac375 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -6,8 +6,12 @@ import pytest import torch -from haystack import Document +from haystack import Document, Pipeline from haystack.components.rankers import SentenceTransformersDiversityRanker +from haystack.components.rankers.sentence_transformers_diversity import ( + DiversityRankingSimilarity, + DiversityRankingStrategy, +) from haystack.utils import ComponentDevice from haystack.utils.auth import Secret @@ -27,7 +31,7 @@ def test_init(self): assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" assert component.top_k == 10 assert component.device == ComponentDevice.resolve_device(None) - assert component.similarity == "cosine" + assert component.similarity == DiversityRankingSimilarity.COSINE assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) assert component.query_prefix == "" assert component.document_prefix == "" @@ -36,7 +40,7 @@ def test_init(self): assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" - def test_init_with_custom_init_parameters(self): + def test_init_with_custom_parameters(self): component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4", top_k=5, @@ -53,7 +57,7 @@ def test_init_with_custom_init_parameters(self): assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" assert component.top_k == 5 assert component.device == ComponentDevice.from_str("cuda:0") - assert component.similarity == "dot_product" + assert component.similarity == DiversityRankingSimilarity.DOT_PRODUCT assert component.token == Secret.from_token("fake-api-token") assert component.query_prefix == "query:" assert component.document_prefix == "document:" @@ -65,22 +69,26 @@ def test_init_with_custom_init_parameters(self): def test_to_dict(self): component = SentenceTransformersDiversityRanker() data = component.to_dict() - assert data == { - "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", - "init_parameters": { - "model": "sentence-transformers/all-MiniLM-L6-v2", - "top_k": 10, - "device": ComponentDevice.resolve_device(None).to_dict(), - "similarity": "cosine", - "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}, - "query_prefix": "", - "document_prefix": "", - "query_suffix": "", - "document_suffix": "", - "meta_fields_to_embed": [], - "embedding_separator": "\n", - }, + assert ( + data["type"] + == "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker" + ) + assert data["init_parameters"]["model"] == "sentence-transformers/all-MiniLM-L6-v2" + assert data["init_parameters"]["top_k"] == 10 + assert data["init_parameters"]["device"] == ComponentDevice.resolve_device(None).to_dict() + assert data["init_parameters"]["similarity"] == "cosine" + assert data["init_parameters"]["token"] == { + "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], + "strict": False, + "type": "env_var", } + assert data["init_parameters"]["query_prefix"] == "" + assert data["init_parameters"]["document_prefix"] == "" + assert data["init_parameters"]["query_suffix"] == "" + assert 
data["init_parameters"]["document_suffix"] == "" + assert data["init_parameters"]["meta_fields_to_embed"] == [] + assert data["init_parameters"]["embedding_separator"] == "\n" + assert data["init_parameters"]["strategy"] == "greedy_diversity_order" def test_from_dict(self): data = { @@ -104,7 +112,7 @@ def test_from_dict(self): assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" assert ranker.top_k == 10 assert ranker.device == ComponentDevice.resolve_device(None) - assert ranker.similarity == "cosine" + assert ranker.similarity == DiversityRankingSimilarity.COSINE assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) assert ranker.query_prefix == "" assert ranker.document_prefix == "" @@ -135,7 +143,7 @@ def test_from_dict_none_device(self): assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" assert ranker.top_k == 10 assert ranker.device == ComponentDevice.resolve_device(None) - assert ranker.similarity == "cosine" + assert ranker.similarity == DiversityRankingSimilarity.COSINE assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) assert ranker.query_prefix == "" assert ranker.document_prefix == "" @@ -154,7 +162,7 @@ def test_from_dict_no_default_parameters(self): assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" assert ranker.top_k == 10 assert ranker.device == ComponentDevice.resolve_device(None) - assert ranker.similarity == "cosine" + assert ranker.similarity == DiversityRankingSimilarity.COSINE assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) assert ranker.query_prefix == "" assert ranker.document_prefix == "" @@ -163,7 +171,7 @@ def test_from_dict_no_default_parameters(self): assert ranker.meta_fields_to_embed == [] assert ranker.embedding_separator == "\n" - def test_to_dict_with_custom_init_parameters(self): + def test_to_dict_with_custom_parameters(self): component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4", top_k=5, @@ -178,22 +186,23 @@ def test_to_dict_with_custom_init_parameters(self): embedding_separator="--", ) data = component.to_dict() - assert data == { - "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", - "init_parameters": { - "model": "sentence-transformers/msmarco-distilbert-base-v4", - "top_k": 5, - "device": ComponentDevice.from_str("cuda:0").to_dict(), - "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, - "similarity": "dot_product", - "query_prefix": "query:", - "document_prefix": "document:", - "query_suffix": "query suffix", - "document_suffix": "document suffix", - "meta_fields_to_embed": ["meta_field"], - "embedding_separator": "--", - }, - } + + assert ( + data["type"] + == "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker" + ) + assert data["init_parameters"]["model"] == "sentence-transformers/msmarco-distilbert-base-v4" + assert data["init_parameters"]["top_k"] == 5 + assert data["init_parameters"]["device"] == ComponentDevice.from_str("cuda:0").to_dict() + assert data["init_parameters"]["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"} + assert data["init_parameters"]["similarity"] == "dot_product" + assert data["init_parameters"]["query_prefix"] == "query:" + assert data["init_parameters"]["document_prefix"] == "document:" + assert data["init_parameters"]["query_suffix"] == 
"query suffix" + assert data["init_parameters"]["document_suffix"] == "document suffix" + assert data["init_parameters"]["meta_fields_to_embed"] == ["meta_field"] + assert data["init_parameters"]["embedding_separator"] == "--" + assert data["init_parameters"]["strategy"] == "greedy_diversity_order" def test_from_dict_with_custom_init_parameters(self): data = { @@ -217,7 +226,7 @@ def test_from_dict_with_custom_init_parameters(self): assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" assert ranker.top_k == 5 assert ranker.device == ComponentDevice.from_str("cuda:0") - assert ranker.similarity == "dot_product" + assert ranker.similarity == DiversityRankingSimilarity.DOT_PRODUCT assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False) assert ranker.query_prefix == "query:" assert ranker.document_prefix == "document:" @@ -226,16 +235,24 @@ def test_from_dict_with_custom_init_parameters(self): assert ranker.meta_fields_to_embed == ["meta_field"] assert ranker.embedding_separator == "--" - def test_run_incorrect_similarity(self): + def test_run_invalid_similarity(self): """ Tests that run method raises ValueError if similarity is incorrect """ similarity = "incorrect" - with pytest.raises( - ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}." - ): + with pytest.raises(ValueError, match=f"Unknown similarity metric"): SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + def test_run_invalid_strategy(self): + """ + Tests that run method raises ValueError if strategy is incorrect + """ + strategy = "incorrect" + with pytest.raises(ValueError, match=f"Unknown strategy"): + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy=strategy + ) + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_without_warm_up(self, similarity): """ @@ -362,7 +379,7 @@ def test_run_negative_top_k(self, similarity): query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] - with pytest.raises(ValueError, match="top_k must be > 0, but got"): + with pytest.raises(ValueError, match="top_k must be between"): ranker.run(query=query, documents=documents, top_k=-5) @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) @@ -509,6 +526,52 @@ def test_run_greedy_diversity_order(self, similarity): assert ranked_text == "Berlin Eiffel Tower Bananas" + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_maximum_margin_relevance(self, similarity): + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + + query = "city" + documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")] + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + + ranked_docs = ranker._maximum_margin_relevance(query=query, documents=documents, lambda_threshold=0, top_k=3) + ranked_text = " ".join([doc.content for doc in ranked_docs]) + + assert ranked_text == "Berlin Eiffel Tower Bananas" + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_maximum_margin_relevance_with_given_lambda_threshold(self, similarity): + ranker = SentenceTransformersDiversityRanker( + 
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + + query = "city" + documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")] + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + + ranked_docs = ranker._maximum_margin_relevance(query=query, documents=documents, lambda_threshold=1, top_k=3) + ranked_text = " ".join([doc.content for doc in ranked_docs]) + + assert ranked_text == "Berlin Eiffel Tower Bananas" + + def test_pipeline_serialise_deserialise(self): + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", top_k=5 + ) + + pipe = Pipeline() + pipe.add_component("ranker", ranker) + pipe_serialized = pipe.dumps() + assert Pipeline.loads(pipe_serialized) == pipe + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run(self, similarity): @@ -607,3 +670,51 @@ def test_run_real_world_use_case(self, similarity): # Check the order of ranked documents by comparing the content of the ranked documents assert result_content == expected_content + + @pytest.mark.integration + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_with_maximum_margin_relevance_strategy(self, similarity): + query = "renewable energy sources" + docs = [ + Document(content="18th-century French literature"), + Document(content="Solar power generation"), + Document(content="Ancient Egyptian hieroglyphics"), + Document(content="Wind turbine technology"), + Document(content="Baking sourdough bread"), + Document(content="Hydroelectric dam systems"), + Document(content="Geothermal energy extraction"), + Document(content="Biomass fuel production"), + ] + + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, strategy="maximum_margin_relevance" + ) + ranker.warm_up() + + # lambda_threshold=1, the most relevant document should be returned first + results = ranker.run(query=query, documents=docs, lambda_threshold=1, top_k=len(docs)) + expected = [ + "Solar power generation", + "Wind turbine technology", + "Geothermal energy extraction", + "Hydroelectric dam systems", + "Biomass fuel production", + "Ancient Egyptian hieroglyphics", + "Baking sourdough bread", + "18th-century French literature", + ] + assert [doc.content for doc in results["documents"]] == expected + + # lambda_threshold=0, after the most relevant one, diverse documents should be returned + results = ranker.run(query=query, documents=docs, lambda_threshold=0, top_k=len(docs)) + expected = [ + "Solar power generation", + "Ancient Egyptian hieroglyphics", + "Baking sourdough bread", + "18th-century French literature", + "Biomass fuel production", + "Hydroelectric dam systems", + "Geothermal energy extraction", + "Wind turbine technology", + ] + assert [doc.content for doc in results["documents"]] == expected