feat: adding Maximum Margin Relevance Ranker #8554

Merged
42 commits merged into main from adding-MMR-score-ranker on Nov 22, 2024
Changes from 27 commits
Commits (42)
2c75d00
initial import
davidsbatista Nov 15, 2024
648e230
linting
davidsbatista Nov 18, 2024
c174cce
adding MRR tests
davidsbatista Nov 18, 2024
e2251cb
adding release notes
davidsbatista Nov 18, 2024
07d8229
fixing tests
davidsbatista Nov 18, 2024
8f6f18f
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 18, 2024
2ea5ca4
adding linting ignore to cross-encoder ranker
davidsbatista Nov 18, 2024
0b22c3c
update docstring
davidsbatista Nov 18, 2024
dd675a2
refactoring
davidsbatista Nov 18, 2024
f56bebd
making strategy Optional instead of Literal
davidsbatista Nov 19, 2024
5e5c71a
wip: adding unit tests
davidsbatista Nov 19, 2024
02c4f2d
refactoring MMR algorithm
davidsbatista Nov 19, 2024
3c19732
refactoring tests
davidsbatista Nov 19, 2024
eabddf1
cleaning up and updating tests
davidsbatista Nov 19, 2024
b1035e6
adding empty line between license + code
davidsbatista Nov 19, 2024
302589b
bug in tests
davidsbatista Nov 19, 2024
e1f742d
using Enum for strategy and similarity metric
davidsbatista Nov 19, 2024
6c75b9f
adding more tests
davidsbatista Nov 19, 2024
981149f
adding empty line between license + code
davidsbatista Nov 19, 2024
ef13a6f
removing run time params
davidsbatista Nov 19, 2024
a5f0d75
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 19, 2024
b933ec3
PR comments
davidsbatista Nov 19, 2024
0e3ab5d
PR comments
davidsbatista Nov 20, 2024
ab5ea66
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 20, 2024
a3e589b
fixing
davidsbatista Nov 20, 2024
2fb4951
fixing serialisation
davidsbatista Nov 20, 2024
03f4ac5
fixing serialisation tests
davidsbatista Nov 20, 2024
3ec8204
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
ffbfe90
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
500be14
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
2cc20ae
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
fea260d
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
e1fc2ed
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
1e25225
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
8d30a41
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 20, 2024
680f8ad
fixing tests
davidsbatista Nov 21, 2024
141aea2
PR comments
davidsbatista Nov 21, 2024
b3404b2
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 21, 2024
5d1edb8
PR comments
davidsbatista Nov 22, 2024
a744707
PR comments
davidsbatista Nov 22, 2024
180f445
PR comments
davidsbatista Nov 22, 2024
730bc87
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 22, 2024
234 changes: 206 additions & 28 deletions haystack/components/rankers/sentence_transformers_diversity.py
@@ -2,7 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Literal, Optional
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
@@ -16,47 +17,105 @@
from sentence_transformers import SentenceTransformer


class Strategy(Enum):
"""
The strategy to use for diversity ranking.
"""

GREEDY_DIVERSITY_ORDER = "greedy_diversity_order"
MAXIMUM_MARGIN_RELEVANCE = "maximum_margin_relevance"

@staticmethod
def from_str(value: str) -> "Strategy":
"""
Convert a string to a Strategy enum.
"""
if value == "greedy_diversity_order":
return Strategy.GREEDY_DIVERSITY_ORDER
if value == "maximum_margin_relevance":
return Strategy.MAXIMUM_MARGIN_RELEVANCE
raise ValueError(
f"Invalid value for Strategy: {value}. Choose from 'greedy_diversity_order' or 'maximum_margin_relevance'."
)


class Similarity(Enum):
"""
The similarity metric to use for comparing embeddings.
"""

DOT_PRODUCT = "dot_product"
COSINE = "cosine"

@staticmethod
def from_str(value: str) -> "Similarity":
"""
Convert a string to a Similarity enum.
"""
if value == "dot_product":
return Similarity.DOT_PRODUCT
if value == "cosine":
return Similarity.COSINE
raise ValueError(f"Invalid value for Similarity: {value} choose from 'dot_product' or 'cosine'")


@component
class SentenceTransformersDiversityRanker:
"""
A Diversity Ranker based on Sentence Transformers.

Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
of the documents.
It applies a document ranking algorithm using one of two strategies:

This component provides functionality to rank a list of documents based on their similarity with respect to the
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
the Documents.
1. Greedy Diversity Order:

Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
of the documents.

This component provides functionality to rank a list of documents based on their similarity with respect to the
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
the Documents.

2. Maximum Margin Relevance:

Implements a document ranking algorithm that orders documents based on their Maximum Margin Relevance (MMR)
scores.

MMR scores are calculated for each document based on their relevance to the query and diversity from already
selected documents. The algorithm iteratively selects documents based on their MMR scores, balancing between
relevance to the query and diversity from already selected documents. The 'lambda_threshold' controls the
trade-off between relevance and diversity.
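
In the notation of the implementation below, the MMR score of a candidate document `d`, given the set `S` of
already selected documents and the configured similarity function `sim`, is:

MMR(d) = lambda_threshold * sim(d, query) - (1 - lambda_threshold) * max(sim(d, s) for s in S)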

Usage example:
```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="greedy_diversity_order")
ranker.warm_up()

docs = [Document(content="Paris"), Document(content="Berlin")]
query = "What is the capital of germany?"
output = ranker.run(query=query, documents=docs)
docs = output["documents"]
```
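
To use the Maximum Margin Relevance strategy instead (the `lambda_threshold` values shown are illustrative):

```python
mmr_ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="maximum_margin_relevance", lambda_threshold=0.7)
mmr_ranker.warm_up()
output = mmr_ranker.run(query=query, documents=docs)
# the trade-off can also be overridden per call:
output = mmr_ranker.run(query=query, documents=docs, lambda_threshold=0.3)
```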
"""
""" # noqa: E501

def __init__(
self,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
top_k: int = 10,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
similarity: Literal["dot_product", "cosine"] = "cosine",
similarity: Union[str, Similarity] = "cosine",
query_prefix: str = "",
query_suffix: str = "",
document_prefix: str = "",
document_suffix: str = "",
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
strategy: Union[str, Strategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5,
): # pylint: disable=too-many-positional-arguments
"""
Initialize a SentenceTransformersDiversityRanker.

@@ -78,6 +137,10 @@ def __init__(
:param document_suffix: A string to add to the end of each Document text before ranking.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
:param strategy: The strategy to use for diversity ranking. Can be one of "greedy_diversity_order" or
"maximum_margin_relevance".
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
"maximum_margin_relevance".
"""
torch_and_sentence_transformers_import.check()

@@ -88,15 +151,50 @@
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.model = None
if similarity not in ["dot_product", "cosine"]:
raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
self.similarity = similarity
self.similarity = self.parse_similarity(similarity)
self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_suffix = query_suffix
self.document_suffix = document_suffix
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.strategy = self.parse_strategy(strategy)
# validate against the parsed strategy so that string values are also checked
self._check_lambda_threshold(lambda_threshold, self.strategy)
self.lambda_threshold = lambda_threshold

@staticmethod
def parse_similarity(similarity: Union[str, Similarity]) -> "Similarity":
"""
Parse the similarity metric to use for comparing embeddings.

:param similarity: The similarity metric, either as a string ("dot_product" or "cosine") or a Similarity enum.
:returns:
The Similarity enum.
"""

if isinstance(similarity, str):
return Similarity.from_str(similarity)
elif isinstance(similarity, Similarity):
return similarity
else:
return Similarity.COSINE

@staticmethod
def parse_strategy(strategy: Union[str, Strategy]) -> "Strategy":
"""
Parse the strategy to use for diversity ranking.

:param strategy: The ranking strategy, either as a string ("greedy_diversity_order" or
"maximum_margin_relevance") or a Strategy enum.
:returns:
The Strategy enum.
"""

if isinstance(strategy, str):
return Strategy.from_str(strategy)
elif isinstance(strategy, Strategy):
return strategy
else:
return Strategy.GREEDY_DIVERSITY_ORDER

def warm_up(self):
"""
@@ -119,16 +217,18 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model=self.model_name_or_path,
top_k=self.top_k,
device=self.device.to_dict(),
token=self.token.to_dict() if self.token else None,
top_k=self.top_k,
similarity=self.similarity,
query_prefix=self.query_prefix,
document_prefix=self.document_prefix,
query_suffix=self.query_suffix,
document_prefix=self.document_prefix,
document_suffix=self.document_suffix,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
strategy=self.strategy,
lambda_threshold=self.lambda_threshold,
)

@classmethod
@@ -182,14 +282,7 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List
"""
texts_to_embed = self._prepare_texts_to_embed(documents)

# Calculate embeddings
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]

# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == "cosine":
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)

n = len(documents)
selected: List[int] = []
@@ -218,17 +311,91 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List

return ranked_docs

def _embed_and_normalize(self, query, texts_to_embed):
# Calculate embeddings
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]

# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == Similarity.COSINE:
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
return doc_embeddings, query_embedding

def _maximum_margin_relevance(
self, query: str, documents: List[Document], lambda_threshold: float, top_k: int
) -> List[Document]:
"""
Orders the given list of documents according to the Maximum Margin Relevance (MMR) scores.

MMR scores are calculated for each document based on their relevance to the query and diversity from already
selected documents.

The algorithm iteratively selects documents based on their MMR scores, balancing between relevance to the query
and diversity from already selected documents. The 'lambda_threshold' controls the trade-off between relevance
and diversity.

A value closer to 0 favors diversity, while a value closer to 1 favors relevance to the query.

See : "The Use of MMR, Diversity-Based Reranking for Reordering Documents and Producing Summaries"
https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf
"""

texts_to_embed = self._prepare_texts_to_embed(documents)
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
top_k = top_k if top_k else len(documents)

selected: List[int] = []
mmr_scores = []

tensor = query_embedding @ doc_embeddings.T
query_similarities = tensor.reshape(-1)
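# seed the selection with the single most query-relevant document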
idx = int(torch.argmax(query_similarities))
selected.append(idx)
mmr_scores.append(query_similarities[idx])
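# greedily add the candidate with the highest MMR score until top_k documents are selected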
while len(selected) < top_k:
best_idx = None
best_score = -float("inf")
for idx, _ in enumerate(documents):
if idx in selected:
continue
relevance_score = query_similarities[idx]
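# diversity penalty: similarity to the closest already selected document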
diversity_score = max(doc_embeddings[idx] @ doc_embeddings[j].T for j in selected)
mmr_score = lambda_threshold * relevance_score - (1 - lambda_threshold) * diversity_score
if mmr_score > best_score:
best_score = mmr_score
best_idx = idx
if best_idx is None:
raise ValueError("No best document found check if the documents list contains any documents.")
mmr_scores.append(best_score)
selected.append(best_idx)

return [documents[i] for i in selected]

@staticmethod
def _check_lambda_threshold(lambda_threshold: float, strategy: Strategy):
if (strategy == Strategy.MAXIMUM_MARGIN_RELEVANCE) and not 0 <= lambda_threshold <= 1:
raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.")

@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
def run(
self,
query: str,
documents: List[Document],
top_k: Optional[int] = None,
lambda_threshold: Optional[float] = None,
) -> Dict[str, List[Document]]:
"""
Rank the documents based on their diversity.

:param query: The search query.
:param documents: List of Document objects to be ranked.
:param top_k: Optional. An integer to override the top_k set during initialization.
:param lambda_threshold: Override the trade-off parameter between relevance and diversity. Only used when
strategy is "maximum_margin_relevance".

:returns: A dictionary with the following key:
- `documents`: List of Document objects that have been selected based on the diversity ranking.
- `documents`: List of Document objects that have been selected based on the diversity-ranking.

:raises ValueError: If the top_k value is less than or equal to 0.
:raises RuntimeError: If the component has not been warmed up.
@@ -247,7 +414,18 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
top_k = self.top_k
elif top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
elif top_k > len(documents):
raise ValueError(f"top_k must be <= number of documents, but got {top_k}")

if self.strategy == Strategy.MAXIMUM_MARGIN_RELEVANCE:
# use lambda_threshold provided at runtime or the one set during initialization
if lambda_threshold is None:
lambda_threshold = self.lambda_threshold
self._check_lambda_threshold(lambda_threshold, self.strategy)
re_ranked_docs = self._maximum_margin_relevance(
query=query, documents=documents, lambda_threshold=lambda_threshold, top_k=top_k
)
else:
re_ranked_docs = self._greedy_diversity_order(query=query, documents=documents)

diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)

return {"documents": diversity_sorted[:top_k]}
return {"documents": re_ranked_docs[:top_k]}
4 changes: 2 additions & 2 deletions haystack/components/rankers/transformers_similarity.py
@@ -43,7 +43,7 @@ class TransformersSimilarityRanker:
```
"""

def __init__( # noqa: PLR0913
def __init__( # noqa: PLR0913, pylint: disable=too-many-positional-arguments
self,
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[ComponentDevice] = None,
@@ -201,7 +201,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TransformersSimilarityRanker":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
query: str,
documents: List[Document],
4 changes: 4 additions & 0 deletions (new release note file)
@@ -0,0 +1,4 @@
---
enhancements:
- |
Added the Maximum Margin Relevance (MMR) strategy to the `SentenceTransformersDiversityRanker`. MMR scores are calculated for each document based on their relevance to the query and diversity from already selected documents.
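
A minimal sketch of the new strategy inside a Haystack pipeline. The documents, query, and `lambda_threshold` value are illustrative assumptions, not part of this PR:

```python
from haystack import Document, Pipeline
from haystack.components.rankers import SentenceTransformersDiversityRanker

# in practice, documents would come from a retriever; created inline here for brevity
docs = [
    Document(content="Berlin is the capital of Germany."),
    Document(content="Germany's capital city is Berlin."),
    Document(content="Munich hosts the Oktoberfest."),
]

ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    strategy="maximum_margin_relevance",
    lambda_threshold=0.5,  # 0 favors diversity, 1 favors relevance
)

pipeline = Pipeline()
pipeline.add_component("ranker", ranker)
result = pipeline.run({"ranker": {"query": "What is the capital of Germany?", "documents": docs}})
print([d.content for d in result["ranker"]["documents"]])
```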