
feat: adding Maximum Margin Relevance Ranker #8554

Merged
42 commits merged into main from adding-MMR-score-ranker on Nov 22, 2024
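This PR adds a maximum margin relevance (MMR) strategy to the SentenceTransformersDiversityRanker: documents are picked greedily, trading off similarity to the query against similarity to the documents already selected, with lambda_threshold controlling the balance. As a reference while reading the diff below, here is a minimal, self-contained sketch of the standard MMR selection rule — illustrative only, not the PR's code (the exact scoring line inside `_maximum_margin_relevance` is collapsed in this diff view):

```python
# Illustrative sketch of Maximal Marginal Relevance (MMR) selection, assuming
# unit-normalized embeddings so a dot product equals cosine similarity.
# Standard rule: MMR(d) = lambda * sim(d, query) - (1 - lambda) * max_{s in selected} sim(d, s)
import torch


def mmr_select(query_emb: torch.Tensor, doc_embs: torch.Tensor, top_k: int, lambda_threshold: float = 0.5):
    query_sims = (query_emb @ doc_embs.T).reshape(-1)   # relevance of every document to the query
    selected = [int(torch.argmax(query_sims))]          # start with the most relevant document
    while len(selected) < top_k:
        best_idx, best_score = None, -float("inf")
        for idx in range(len(doc_embs)):
            if idx in selected:
                continue
            redundancy = torch.max(doc_embs[idx] @ doc_embs[selected].T)  # similarity to already-picked docs
            score = lambda_threshold * query_sims[idx] - (1 - lambda_threshold) * redundancy
            if score > best_score:
                best_score, best_idx = score, idx
        selected.append(best_idx)
    return selected


# Toy example: documents 0 and 1 are near-duplicates, document 2 is different.
docs = torch.nn.functional.normalize(torch.tensor([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]), dim=-1)
query = torch.nn.functional.normalize(torch.tensor([[1.0, 0.2]]), dim=-1)
print(mmr_select(query, docs, top_k=2))  # [1, 2]: most relevant document first, then the dissimilar one
```

A lambda_threshold close to 1.0 makes the ordering approach a pure relevance ranking, while values close to 0.0 favour diversity; the component validates that the value lies in [0, 1].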
Changes from 4 commits
Commits (42)
2c75d00
initial import
davidsbatista Nov 15, 2024
648e230
linting
davidsbatista Nov 18, 2024
c174cce
adding MRR tests
davidsbatista Nov 18, 2024
e2251cb
adding release notes
davidsbatista Nov 18, 2024
07d8229
fixing tests
davidsbatista Nov 18, 2024
8f6f18f
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 18, 2024
2ea5ca4
adding linting ignore to cross-encoder ranker
davidsbatista Nov 18, 2024
0b22c3c
update docstring
davidsbatista Nov 18, 2024
dd675a2
refactoring
davidsbatista Nov 18, 2024
f56bebd
making strategy Optional instead of Literal
davidsbatista Nov 19, 2024
5e5c71a
wip: adding unit tests
davidsbatista Nov 19, 2024
02c4f2d
refactoring MMR algorithm
davidsbatista Nov 19, 2024
3c19732
refactoring tests
davidsbatista Nov 19, 2024
eabddf1
cleaning up and updating tests
davidsbatista Nov 19, 2024
b1035e6
adding empty line between license + code
davidsbatista Nov 19, 2024
302589b
bug in tests
davidsbatista Nov 19, 2024
e1f742d
using Enum for strategy and similarity metric
davidsbatista Nov 19, 2024
6c75b9f
adding more tests
davidsbatista Nov 19, 2024
981149f
adding empty line between license + code
davidsbatista Nov 19, 2024
ef13a6f
removing run time params
davidsbatista Nov 19, 2024
a5f0d75
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 19, 2024
b933ec3
PR comments
davidsbatista Nov 19, 2024
0e3ab5d
PR comments
davidsbatista Nov 20, 2024
ab5ea66
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 20, 2024
a3e589b
fixing
davidsbatista Nov 20, 2024
2fb4951
fixing serialisation
davidsbatista Nov 20, 2024
03f4ac5
fixing serialisation tests
davidsbatista Nov 20, 2024
3ec8204
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
ffbfe90
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
500be14
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
2cc20ae
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
fea260d
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
e1fc2ed
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
1e25225
Update haystack/components/rankers/sentence_transformers_diversity.py
davidsbatista Nov 20, 2024
8d30a41
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 20, 2024
680f8ad
fixing tests
davidsbatista Nov 21, 2024
141aea2
PR comments
davidsbatista Nov 21, 2024
b3404b2
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 21, 2024
5d1edb8
PR comments
davidsbatista Nov 22, 2024
a744707
PR comments
davidsbatista Nov 22, 2024
180f445
PR comments
davidsbatista Nov 22, 2024
730bc87
Merge branch 'main' into adding-MMR-score-ranker
davidsbatista Nov 22, 2024
79 changes: 19 additions & 60 deletions haystack/components/rankers/sentence_transformers_diversity.py
@@ -17,7 +17,7 @@
from sentence_transformers import SentenceTransformer


class Strategy(Enum):
class DiversityRankingStrategy(Enum):
"""
The strategy to use for diversity ranking.
"""
@@ -32,19 +32,19 @@ def __str__(self) -> str:
return self.value

@staticmethod
def from_str(string: str) -> "Strategy":
def from_str(string: str) -> "DiversityRankingStrategy":
"""
Convert a string to a Strategy enum.
"""
enum_map = {e.value: e for e in Strategy}
enum_map = {e.value: e for e in DiversityRankingStrategy}
strategy = enum_map.get(string)
if strategy is None:
msg = f"Unknown strategy '{string}'. Supported strategies are: {list(enum_map.keys())}"
raise ValueError(msg)
return strategy


class Similarity(Enum):
class DiversityRankingSimilarity(Enum):
"""
The similarity metric to use for comparing embeddings.
"""
@@ -59,11 +59,11 @@ def __str__(self) -> str:
return self.value

@staticmethod
def from_str(string: str) -> "Similarity":
def from_str(string: str) -> "DiversityRankingSimilarity":
"""
Convert a string to a Similarity enum.
"""
enum_map = {e.value: e for e in Similarity}
enum_map = {e.value: e for e in DiversityRankingSimilarity}
similarity = enum_map.get(string)
if similarity is None:
msg = f"Unknown similarity metric '{string}'. Supported metrics are: {list(enum_map.keys())}"
@@ -117,14 +117,14 @@ def __init__(
top_k: int = 10,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
similarity: Union[str, Similarity] = "cosine",
similarity: Union[str, DiversityRankingSimilarity] = "cosine",
query_prefix: str = "",
query_suffix: str = "",
document_prefix: str = "",
document_suffix: str = "",
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
strategy: Union[str, Strategy] = "greedy_diversity_order",
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5,
): # pylint: disable=too-many-positional-arguments
"""
@@ -162,50 +162,16 @@ def __init__(
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.model = None
self.similarity = self._parse_similarity(similarity)
self.similarity = DiversityRankingSimilarity.from_str(similarity) if isinstance(similarity, str) else similarity
self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_suffix = query_suffix
self.document_suffix = document_suffix
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.strategy = self._parse_strategy(strategy)
self._check_lambda_threshold(lambda_threshold, strategy) # type: ignore
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
self.lambda_threshold = lambda_threshold or 0.5

@staticmethod
def _parse_similarity(similarity: Union[str, Similarity]) -> "Similarity":
"""
Parse the similarity metric to use for comparing embeddings.

:param similarity:
:returns:
The Similarity enum.
"""

if isinstance(similarity, str):
return Similarity.from_str(similarity)
elif isinstance(similarity, Similarity):
return similarity
else:
return Similarity.COSINE

@staticmethod
def _parse_strategy(strategy: Union[str, Strategy]) -> "Strategy":
"""
Parse the strategy to use for diversity ranking.

:param strategy:
:returns:
The Strategy enum.
"""

if isinstance(strategy, str):
return Strategy.from_str(strategy)
elif isinstance(strategy, Strategy):
return strategy
else:
return Strategy.GREEDY_DIVERSITY_ORDER
self._check_lambda_threshold(lambda_threshold, self.strategy)

def warm_up(self):
"""
@@ -328,7 +294,7 @@ def _embed_and_normalize(self, query, texts_to_embed):
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]

# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == Similarity.COSINE:
if self.similarity == DiversityRankingSimilarity.COSINE:
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
return doc_embeddings, query_embedding
@@ -357,13 +323,10 @@ def _maximum_margin_relevance(
top_k = top_k if top_k else len(documents)

selected: List[int] = []
mmr_scores = []

tensor = query_embedding @ doc_embeddings.T
query_similarities = tensor.reshape(-1)
query_similarities_as_tensor = query_embedding @ doc_embeddings.T
query_similarities = query_similarities_as_tensor.reshape(-1)
idx = int(torch.argmax(query_similarities))
selected.append(idx)
mmr_scores.append(query_similarities[idx])
while len(selected) < top_k:
best_idx = None
best_score = -float("inf")
@@ -378,14 +341,13 @@
best_idx = idx
if best_idx is None:
raise ValueError("No best document found, check if the documents list contains any documents.")
mmr_scores.append(best_score)
selected.append(best_idx)

return [documents[i] for i in selected]

@staticmethod
def _check_lambda_threshold(lambda_threshold: float, strategy: Strategy):
if (strategy == Strategy.MAXIMUM_MARGIN_RELEVANCE) and not 0 <= lambda_threshold <= 1:
def _check_lambda_threshold(lambda_threshold: float, strategy: DiversityRankingStrategy):
if (strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE) and not 0 <= lambda_threshold <= 1:
raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.")

@component.output_types(documents=List[Document])
@@ -423,13 +385,10 @@ def run(

if top_k is None:
top_k = self.top_k
elif top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
elif top_k > len(documents):
raise ValueError(f"top_k must be <= number of documents, but got {top_k}")
elif not 0 < top_k <= len(documents):
raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")

if self.strategy == Strategy.MAXIMUM_MARGIN_RELEVANCE:
# use lambda_threshold provided at runtime or the one set during initialization
if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
if lambda_threshold is None:
lambda_threshold = self.lambda_threshold
self._check_lambda_threshold(lambda_threshold, self.strategy)
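For orientation after the component diff, a hedged usage sketch of the new strategy. It relies only on the constructor, warm_up(), and run() signatures visible above; the strategy string "maximum_margin_relevance" is assumed to be the value behind DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE, and running the snippet downloads the sentence-transformers model:

```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

# Parameter names are taken from the diff above; the strategy string is an
# assumption matching the enum member MAXIMUM_MARGIN_RELEVANCE.
ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    similarity="cosine",
    strategy="maximum_margin_relevance",
    lambda_threshold=0.7,  # closer to 1.0 favours relevance, closer to 0.0 favours diversity
    top_k=2,
)
ranker.warm_up()  # loads the sentence-transformers model

docs = [
    Document(content="Berlin is the capital of Germany."),
    Document(content="Germany's capital city is Berlin."),
    Document(content="Paris is the capital of France."),
]
result = ranker.run(query="What is the capital of Germany?", documents=docs)
print([d.content for d in result["documents"]])  # two documents, re-ordered for relevance plus diversity
```

As the run() hunk above shows, a lambda_threshold passed to run() overrides the value set at initialization.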
40 changes: 17 additions & 23 deletions test/components/rankers/test_sentence_transformers_diversity.py
@@ -6,9 +6,12 @@
import pytest
import torch

from haystack import Document
from haystack import Document, Pipeline
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack.components.rankers.sentence_transformers_diversity import Similarity, Strategy
from haystack.components.rankers.sentence_transformers_diversity import (
DiversityRankingSimilarity,
DiversityRankingStrategy,
)
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret

@@ -28,7 +31,7 @@ def test_init(self):
assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert component.top_k == 10
assert component.device == ComponentDevice.resolve_device(None)
assert component.similarity == Similarity.COSINE
assert component.similarity == DiversityRankingSimilarity.COSINE
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert component.query_prefix == ""
assert component.document_prefix == ""
@@ -54,7 +57,7 @@ def test_init_with_custom_parameters(self):
assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
assert component.top_k == 5
assert component.device == ComponentDevice.from_str("cuda:0")
assert component.similarity == Similarity.DOT_PRODUCT
assert component.similarity == DiversityRankingSimilarity.DOT_PRODUCT
assert component.token == Secret.from_token("fake-api-token")
assert component.query_prefix == "query:"
assert component.document_prefix == "document:"
@@ -109,7 +112,7 @@ def test_from_dict(self):
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == Similarity.COSINE
assert ranker.similarity == DiversityRankingSimilarity.COSINE
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
@@ -140,7 +143,7 @@ def test_from_dict_none_device(self):
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == Similarity.COSINE
assert ranker.similarity == DiversityRankingSimilarity.COSINE
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
@@ -159,7 +162,7 @@ def test_from_dict_no_default_parameters(self):
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == Similarity.COSINE
assert ranker.similarity == DiversityRankingSimilarity.COSINE
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
@@ -223,7 +226,7 @@ def test_from_dict_with_custom_init_parameters(self):
assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
assert ranker.top_k == 5
assert ranker.device == ComponentDevice.from_str("cuda:0")
assert ranker.similarity == Similarity.DOT_PRODUCT
assert ranker.similarity == DiversityRankingSimilarity.DOT_PRODUCT
assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False)
assert ranker.query_prefix == "query:"
assert ranker.document_prefix == "document:"
@@ -376,7 +379,7 @@ def test_run_negative_top_k(self, similarity):
query = "test"
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]

with pytest.raises(ValueError, match="top_k must be > 0, but got"):
with pytest.raises(ValueError, match="top_k must be between"):
ranker.run(query=query, documents=documents, top_k=-5)

@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
Expand Down Expand Up @@ -563,20 +566,11 @@ def test_pipeline_serialise_deserialise(self):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", top_k=5
)
ranker_serialized = ranker.to_dict()
ranker_deserialized = SentenceTransformersDiversityRanker.from_dict(ranker_serialized)
assert ranker.model_name_or_path == ranker_deserialized.model_name_or_path
assert ranker.top_k == ranker_deserialized.top_k
assert ranker.device == ranker_deserialized.device
assert ranker.similarity == ranker_deserialized.similarity
assert ranker.token == ranker_deserialized.token
assert ranker.query_prefix == ranker_deserialized.query_prefix
assert ranker.document_prefix == ranker_deserialized.document_prefix
assert ranker.query_suffix == ranker_deserialized.query_suffix
assert ranker.document_suffix == ranker_deserialized.document_suffix
assert ranker.meta_fields_to_embed == ranker_deserialized.meta_fields_to_embed
assert ranker.embedding_separator == ranker_deserialized.embedding_separator
assert ranker.strategy == ranker_deserialized.strategy

pipe = Pipeline()
pipe.add_component("ranker", ranker)
pipe_serialized = pipe.to_dict()
assert Pipeline.from_dict(pipe_serialized) == pipe

@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
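Closing note: the serialisation test above already wires the ranker into a Pipeline; below is a hedged sketch of actually running it inside one. The input/output routing by component name follows the usual Haystack 2.x Pipeline.run convention and is an assumption here rather than something this diff shows:

```python
from haystack import Document, Pipeline
from haystack.components.rankers import SentenceTransformersDiversityRanker

# Minimal pipeline wiring, mirroring the serialisation test; running it loads the model.
pipe = Pipeline()
pipe.add_component(
    "ranker",
    SentenceTransformersDiversityRanker(
        model="sentence-transformers/all-MiniLM-L6-v2",
        strategy="maximum_margin_relevance",
        top_k=2,
    ),
)

docs = [
    Document(content="MMR balances relevance and diversity."),
    Document(content="Maximal marginal relevance trades off relevance against redundancy."),
    Document(content="Sentence transformers produce dense embeddings."),
]
# Inputs are routed to the component by name (assumed standard Pipeline.run usage).
result = pipe.run({"ranker": {"query": "What does MMR do?", "documents": docs}})
print([d.content for d in result["ranker"]["documents"]])
```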