-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add retrieval evaluator for any modality
- Loading branch information
1 parent
5bb6331
commit fb04e0c
Showing
1 changed file
with
320 additions
and
0 deletions.
There are no files selected for viewing
320 changes: 320 additions & 0 deletions
320
llama-index-core/llama_index/core/evaluation/retrieval/omni_modal.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,320 @@ | ||
import asyncio | ||
import json | ||
from dataclasses import asdict, dataclass | ||
from typing import Any, Callable, Collection, Dict, Generic, List, Optional, Tuple | ||
|
||
from llama_index.core.async_utils import asyncio_run | ||
from llama_index.core.bridge.pydantic import BaseModel, Field | ||
from llama_index.core.embeddings.omni_modal_base import KD, KQ | ||
from llama_index.core.evaluation.retrieval.metrics import resolve_metrics | ||
from llama_index.core.evaluation.retrieval.metrics_base import ( | ||
BaseRetrievalMetric, | ||
RetrievalMetricResult, | ||
) | ||
from llama_index.core.indices.omni_modal import OmniModalVectorIndexRetriever | ||
from llama_index.core.postprocessor.types import BaseNodePostprocessor | ||
from llama_index.core.schema import BaseNode, QueryBundle | ||
|
||
|
||
class OmniModalEmbeddingQAFinetuneDataset(BaseModel): | ||
"""Omni-Modal Embedding QA Finetuning Dataset. | ||
Args: | ||
queries (Dict[str, QueryBundle]): Dict id -> query. | ||
corpus (Dict[str, BaseNode]): Dict id -> string. | ||
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids. | ||
""" | ||
|
||
queries: Dict[str, QueryBundle] # dict id -> query | ||
corpus: Dict[str, BaseNode] # dict id -> string | ||
relevant_docs: Dict[str, List[str]] # query id -> list of doc ids | ||
|
||
@property | ||
def query_docid_pairs(self) -> List[Tuple[QueryBundle, List[str]]]: | ||
"""Get query, relevant doc ids.""" | ||
return [ | ||
(query, self.relevant_docs[query_id]) | ||
for query_id, query in self.queries.items() | ||
] | ||
|
||
def save_json(self, path: str) -> None: | ||
"""Save json.""" | ||
data = { | ||
"queries": {k: asdict(v) for k, v in self.queries.items()}, | ||
"corpus": {k: v.to_dict() for k, v in self.corpus.items()}, | ||
"relevant_docs": self.relevant_docs, | ||
} | ||
|
||
with open(path, "w") as f: | ||
json.dump(data, f, indent=4) | ||
|
||
@classmethod | ||
def from_json( | ||
cls, | ||
path: str, | ||
*, | ||
query_loader: Callable[[Dict[str, Any]], QueryBundle] = QueryBundle.from_dict, | ||
node_loader: Callable[[Dict[str, Any]], BaseNode] = BaseNode.from_dict, | ||
) -> "OmniModalEmbeddingQAFinetuneDataset": | ||
"""Load json.""" | ||
with open(path) as f: | ||
data = json.load(f) | ||
|
||
return cls( | ||
queries={k: query_loader(v) for k, v in data["queries"].items()}, | ||
corpus={k: node_loader(v) for k, v in data["corpus"].items()}, | ||
relevant_docs=data["relevant_docs"], | ||
) | ||
|
||
|
||
class OmniModalRetrievalEvalResult(BaseModel): | ||
"""Retrieval eval result. | ||
NOTE: this abstraction might change in the future. | ||
Attributes: | ||
query_bundle (QueryBundle): Input query | ||
expected_ids (List[str]): Expected ids | ||
retrieved_ids (List[str]): Retrieved ids | ||
metric_dict (Dict[str, BaseRetrievalMetric]): \ | ||
Metric dictionary for the evaluation | ||
""" | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
query_bundle: QueryBundle = Field(..., description="Input query") | ||
query_type: str = Field(..., description="Modality type of query") | ||
expected_ids: List[str] = Field(..., description="Expected ids") | ||
expected_docs: Optional[List[BaseNode]] = Field( | ||
default=None, | ||
description="Expected documents associated with nodes provided in `expected_ids`", | ||
) | ||
retrieved_ids: List[str] = Field(..., description="Retrieved ids") | ||
retrieved_docs: List[BaseNode] = Field(..., description="Retrieved documents") | ||
doc_types: Optional[Collection[str]] = Field( | ||
default=None, | ||
description="Modality types of documents to match. `None` means all.", | ||
) | ||
metric_dict: Dict[str, RetrievalMetricResult] = Field( | ||
..., description="Metric dictionary for the evaluation" | ||
) | ||
|
||
@property | ||
def metric_vals_dict(self) -> Dict[str, float]: | ||
"""Dictionary of metric values.""" | ||
return {k: v.score for k, v in self.metric_dict.items()} | ||
|
||
def __str__(self) -> str: | ||
"""String representation.""" | ||
return f"Query: {self.query_bundle}\n" f"Metrics: {self.metric_vals_dict!s}\n" | ||
|
||
|
||
@dataclass | ||
class OmniModalRetrievalEvaluator(Generic[KD, KQ]): | ||
"""Omni-Modal Retrieval Evaluator class.""" | ||
|
||
metrics: List[BaseRetrievalMetric] | ||
"""List of metrics to evaluate""" | ||
|
||
retriever: OmniModalVectorIndexRetriever[KD, KQ] | ||
"""Retriever to evaluate""" | ||
|
||
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None | ||
""""Optional post-processor""" | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
@classmethod | ||
def from_metric_names( | ||
cls, metric_names: List[str], **kwargs: Any | ||
) -> "OmniModalRetrievalEvaluator": | ||
"""Create evaluator from metric names. | ||
Args: | ||
metric_names (List[str]): List of metric names | ||
**kwargs: Additional arguments for the evaluator | ||
""" | ||
metric_types = resolve_metrics(metric_names) | ||
return cls(metrics=[metric() for metric in metric_types], **kwargs) | ||
|
||
async def _aget_retrieved_ids_and_docs( | ||
self, | ||
query_bundle: QueryBundle, | ||
*, | ||
query_type: KQ, | ||
# Defaults to all document modalities | ||
doc_types: Optional[Collection[KD]] = None, | ||
) -> Tuple[List[str], List[BaseNode]]: | ||
"""Get retrieved ids and documents.""" | ||
scored_nodes = await self.retriever.aretrieve_multi_modal( | ||
query_bundle, | ||
query_type=query_type, | ||
doc_types=doc_types, | ||
) | ||
|
||
if self.node_postprocessors: | ||
for node_postprocessor in self.node_postprocessors: | ||
scored_nodes = node_postprocessor.postprocess_nodes( | ||
scored_nodes, | ||
query_bundle=query_bundle, | ||
) | ||
|
||
retrieved_nodes = [scored_node.node for scored_node in scored_nodes] | ||
|
||
return ( | ||
[node.node_id for node in retrieved_nodes], | ||
retrieved_nodes, | ||
) | ||
|
||
def _compute_metrics( | ||
self, | ||
query_bundle: QueryBundle, | ||
*, | ||
retrieved_ids: List[str], | ||
retrieved_docs: List[BaseNode], | ||
expected_ids: List[str], | ||
expected_docs: Optional[List[BaseNode]], | ||
) -> Dict[str, RetrievalMetricResult]: | ||
return { | ||
metric.metric_name: metric.compute( | ||
expected_ids=expected_ids, | ||
retrieved_ids=retrieved_ids, | ||
# Assume that the metric does not require this | ||
query=None, | ||
expected_texts=None, | ||
retrieved_texts=None, | ||
) | ||
for metric in self.metrics | ||
} | ||
|
||
def evaluate( | ||
self, | ||
query_bundle: QueryBundle, | ||
query_type: KQ, | ||
expected_ids: List[str], | ||
expected_docs: Optional[List[BaseNode]] = None, | ||
# Defaults to all document modalities | ||
doc_types: Optional[Collection[KD]] = None, | ||
) -> OmniModalRetrievalEvalResult: | ||
"""Run evaluation results with query string and expected ids. | ||
Args: | ||
query (str): Query string | ||
expected_ids (List[str]): Expected ids | ||
Returns: | ||
OmniModalRetrievalEvalResult: Evaluation result | ||
""" | ||
return asyncio_run( | ||
self.aevaluate( | ||
query_bundle=query_bundle, | ||
query_type=query_type, | ||
expected_ids=expected_ids, | ||
expected_docs=expected_docs, | ||
doc_types=doc_types, | ||
) | ||
) | ||
|
||
async def aevaluate( | ||
self, | ||
query_bundle: QueryBundle, | ||
query_type: KQ, | ||
expected_ids: List[str], | ||
expected_docs: Optional[List[BaseNode]] = None, | ||
# Defaults to all document modalities | ||
doc_types: Optional[Collection[KD]] = None, | ||
) -> OmniModalRetrievalEvalResult: | ||
"""Run evaluation with query string, retrieved contexts, | ||
and generated response string. | ||
Subclasses can override this method to provide custom evaluation logic and | ||
take in additional arguments. | ||
""" | ||
retrieved_ids, retrieved_docs = await self._aget_retrieved_ids_and_docs( | ||
query_bundle, | ||
query_type=query_type, | ||
doc_types=doc_types, | ||
) | ||
|
||
return OmniModalRetrievalEvalResult( | ||
query_bundle=query_bundle, | ||
query_type=query_type, | ||
expected_ids=expected_ids, | ||
expected_docs=expected_docs, | ||
retrieved_ids=retrieved_ids, | ||
retrieved_docs=retrieved_docs, | ||
doc_types=doc_types, | ||
metric_dict=self._compute_metrics( | ||
query_bundle, | ||
retrieved_ids=retrieved_ids, | ||
retrieved_docs=retrieved_docs, | ||
expected_ids=expected_ids, | ||
expected_docs=expected_docs, | ||
), | ||
) | ||
|
||
async def aevaluate_dataset( | ||
self, | ||
dataset: OmniModalEmbeddingQAFinetuneDataset, | ||
query_type: KQ, | ||
*, | ||
# Defaults to all document modalities | ||
doc_types: Optional[Collection[KD]] = None, | ||
workers: int = 2, | ||
show_progress: bool = False, | ||
) -> List[OmniModalRetrievalEvalResult]: | ||
"""Run evaluation with dataset.""" | ||
semaphore = asyncio.Semaphore(workers) | ||
|
||
async def eval_worker( | ||
query_bundle: QueryBundle, | ||
expected_ids: List[str], | ||
) -> OmniModalRetrievalEvalResult: | ||
async with semaphore: | ||
ret_ids, ret_docs = await self._aget_retrieved_ids_and_docs( | ||
query_bundle, | ||
query_type=query_type, | ||
doc_types=doc_types, | ||
) | ||
|
||
assert all(doc_id in dataset.corpus for doc_id in ret_ids), ( | ||
"Some retrieved documents do not belong in the dataset. " | ||
"Make sure the dataset and retriever are built from the " | ||
"same index." | ||
) | ||
|
||
return OmniModalRetrievalEvalResult( | ||
query_bundle=query_bundle, | ||
query_type=query_type, | ||
expected_ids=expected_ids, | ||
expected_docs=None, | ||
retrieved_ids=ret_ids, | ||
retrieved_docs=ret_docs, | ||
doc_types=doc_types, | ||
metric_dict=self._compute_metrics( | ||
query_bundle, | ||
retrieved_ids=ret_ids, | ||
retrieved_docs=ret_docs, | ||
expected_ids=expected_ids, | ||
expected_docs=None, | ||
), | ||
) | ||
|
||
response_jobs = [] | ||
for query_id, query in dataset.queries.items(): | ||
expected_ids = dataset.relevant_docs[query_id] | ||
response_jobs.append(eval_worker(query, expected_ids)) | ||
if show_progress: | ||
from tqdm.asyncio import tqdm_asyncio | ||
|
||
eval_results = await tqdm_asyncio.gather(*response_jobs) | ||
else: | ||
eval_results = await asyncio.gather(*response_jobs) | ||
|
||
return eval_results |