Commit: evaluation
ZiyiXia committed Nov 11, 2024
1 parent 1df8b36 commit 3f07173
Showing 15 changed files with 342 additions and 2 deletions.
2 changes: 1 addition & 1 deletion FlagEmbedding/abc/evaluation/data_loader.py
@@ -113,7 +113,7 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
return self._load_remote_corpus(dataset_name=dataset_name)

def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Load the corpus from the dataset.
"""Load the qrels from the dataset.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to :data:`None`.
47 changes: 47 additions & 0 deletions FlagEmbedding/evaluation/miracl/data_loader.py
@@ -11,10 +11,28 @@


class MIRACLEvalDataLoader(AbsEvalDataLoader):
"""
Data loader class for MIRACL.
"""
def available_dataset_names(self) -> List[str]:
"""
Get the available dataset names.
Returns:
List[str]: All the available dataset names.
"""
return ["ar", "bn", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "zh", "de", "yo"]

def available_splits(self, dataset_name: str) -> List[str]:
"""
Get the available splits.
Args:
dataset_name (str): Dataset name.
Returns:
List[str]: All the available splits for the dataset.
"""
if dataset_name in ["de", "yo"]:
return ["dev"]
else:
@@ -25,6 +43,15 @@ def _load_remote_corpus(
dataset_name: str,
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load the corpus dataset from HF.
Args:
dataset_name (str): Name of the dataset.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of corpus.
"""
corpus = datasets.load_dataset(
"miracl/miracl-corpus", dataset_name,
cache_dir=self.cache_dir,
@@ -60,6 +87,16 @@ def _load_remote_qrels(
split: str = 'dev',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load the qrels from HF.
Args:
dataset_name (str): Name of the dataset.
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of qrel.
"""
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl"
qrels_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/qrels/qrels.miracl-v1.0-{dataset_name}-{split}.tsv"

@@ -101,6 +138,16 @@ def _load_remote_queries(
split: str = 'dev',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load the queries from HF.
Args:
dataset_name (str): Name of the dataset.
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of queries.
"""
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl"
queries_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/topics/topics.miracl-v1.0-{dataset_name}-{split}.tsv"

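For orientation, a minimal usage sketch of the loader above (not part of the commit): the constructor keywords are assumed to mirror those passed by MIRACLEvalRunner below, and load_queries is assumed to be inherited from the AbsEvalDataLoader base class.

    from FlagEmbedding.evaluation.miracl import MIRACLEvalDataLoader

    # Hypothetical paths; adjust to your environment.
    loader = MIRACLEvalDataLoader(
        eval_name="miracl",
        dataset_dir="./miracl/data",
        cache_dir="/root/.cache/huggingface/hub",
    )

    print(loader.available_dataset_names())  # ['ar', 'bn', ..., 'de', 'yo']
    print(loader.available_splits("yo"))     # ['dev'] -- surprise languages ship only a dev split

    corpus = loader.load_corpus(dataset_name="bn")                 # fetches miracl/miracl-corpus on first use
    qrels = loader.load_qrels(dataset_name="bn", split="dev")
    queries = loader.load_queries(dataset_name="bn", split="dev")  # assumed base-class helper
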
8 changes: 8 additions & 0 deletions FlagEmbedding/evaluation/miracl/runner.py
@@ -4,7 +4,15 @@


class MIRACLEvalRunner(AbsEvalRunner):
"""
Evaluation runner of MIRACL.
"""
def load_data_loader(self) -> MIRACLEvalDataLoader:
"""Load the data loader instance by args.
Returns:
MIRACLEvalDataLoader: The MIRACL data loader instance.
"""
data_loader = MIRACLEvalDataLoader(
eval_name=self.eval_args.eval_name,
dataset_dir=self.eval_args.dataset_dir,
2 changes: 2 additions & 0 deletions FlagEmbedding/evaluation/mkqa/__init__.py
@@ -4,11 +4,13 @@
)

from .data_loader import MKQAEvalDataLoader
from .evaluator import MKQAEvaluator
from .runner import MKQAEvalRunner

__all__ = [
"MKQAEvalArgs",
"MKQAEvalModelArgs",
"MKQAEvalRunner",
"MKQAEvalDataLoader",
"MKQAEvaluator"
]
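
With MKQAEvaluator and MKQAEvalDataLoader now exported, every MKQA component can be imported from the package root; a quick sketch (illustration only):

    from FlagEmbedding.evaluation.mkqa import (
        MKQAEvalArgs,
        MKQAEvalModelArgs,
        MKQAEvalDataLoader,
        MKQAEvaluator,
        MKQAEvalRunner,
    )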
59 changes: 59 additions & 0 deletions FlagEmbedding/evaluation/mkqa/data_loader.py
@@ -13,13 +13,39 @@


class MKQAEvalDataLoader(AbsEvalDataLoader):
"""
Data loader class for MKQA.
"""
def available_dataset_names(self) -> List[str]:
"""
Get the available dataset names.
Returns:
List[str]: All the available dataset names.
"""
return ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']

def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
"""
Get the available splits.
Args:
dataset_name (Optional[str], optional): Dataset name. Defaults to ``None``.
Returns:
List[str]: All the available splits for the dataset.
"""
return ["test"]

def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
"""Load the corpus.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of corpus.
"""
if self.dataset_dir is not None:
# same corpus for all languages
save_dir = self.dataset_dir
@@ -28,6 +54,19 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
return self._load_remote_corpus(dataset_name=dataset_name)

def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Try to load qrels from local datasets.
Args:
save_dir (str): Directory that saves the data files.
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
split (str, optional): Split of the dataset. Defaults to ``'test'``.
Raises:
ValueError: No local qrels found; will try to download from remote.
Returns:
datasets.DatasetDict: Loaded datasets instance of qrels.
"""
checked_split = self.check_splits(split)
if len(checked_split) == 0:
raise ValueError(f"Split {split} not found in the dataset.")
@@ -96,6 +135,16 @@ def _load_remote_qrels(
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load remote qrels from HF.
Args:
dataset_name (str): Name of the dataset.
split (str, optional): Split of the dataset. Defaults to ``'test'``.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of qrel.
"""
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"

@@ -137,6 +186,16 @@ def _load_remote_queries(
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load the queries from HF.
Args:
dataset_name (str): Name of the dataset.
split (str, optional): Split of the dataset. Defaults to ``'test'``.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of queries.
"""
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"

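A hedged sketch of how this loader differs from the MIRACL one (constructor keywords assumed as above): MKQA shares a single corpus across all languages and only exposes a test split.

    from FlagEmbedding.evaluation.mkqa import MKQAEvalDataLoader

    loader = MKQAEvalDataLoader(
        eval_name="mkqa",
        dataset_dir="./mkqa/data",   # hypothetical path
        cache_dir="/root/.cache/huggingface/hub",
    )

    print(loader.available_splits())   # ['test']
    corpus = loader.load_corpus()      # the same corpus is reused for every language
    qrels = loader.load_qrels(dataset_name="zh_cn", split="test")
    queries = loader.load_queries(dataset_name="zh_cn", split="test")  # assumed base-class helper
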
30 changes: 30 additions & 0 deletions FlagEmbedding/evaluation/mkqa/evaluator.py
@@ -8,12 +8,25 @@


class MKQAEvaluator(AbsEvaluator):
"""
The evaluator class of MKQA.
"""
def get_corpus_embd_save_dir(
self,
retriever_name: str,
corpus_embd_save_dir: Optional[str] = None,
dataset_name: Optional[str] = None
):
"""Get the directory to save the corpus embedding.
Args:
retriever_name (str): Name of the retriever.
corpus_embd_save_dir (Optional[str], optional): Directory to save the corpus embedding. Defaults to ``None``.
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
Returns:
str: The final directory to save the corpus embedding.
"""
if corpus_embd_save_dir is not None:
# Save the corpus embeddings in the same directory for all dataset_name
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, retriever_name)
@@ -24,6 +37,15 @@ def evaluate_results(
search_results_save_dir: str,
k_values: List[int] = [1, 3, 5, 10, 100, 1000]
):
"""Compute the metrics and get the eval results.
Args:
search_results_save_dir (str): Directory that saves the search results.
k_values (List[int], optional): Cutoffs. Defaults to ``[1, 3, 5, 10, 100, 1000]``.
Returns:
dict: The evaluation results.
"""
eval_results_dict = {}

corpus = self.data_loader.load_corpus()
@@ -70,6 +92,14 @@ def compute_metrics(
):
"""
Compute Recall@k for the QA task. The definition of recall in QA tasks differs from the one in IR tasks; please refer to the RocketQA paper: https://aclanthology.org/2021.naacl-main.466.pdf.
Args:
corpus_dict (Dict[str, str]): Dictionary of the corpus with doc id and contents.
qrels (Dict[str, List[str]]): Relevance labels between queries and passages.
search_results (Dict[str, Dict[str, float]]): Search results of the model to evaluate.
Returns:
dict: The model's scores of the metrics.
"""
contexts = []
answers = []
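The docstring above points to the RocketQA definition of recall for QA: a query counts as a hit at cutoff k if any of its gold answer strings appears in its top-k retrieved passages. A minimal, self-contained sketch of that metric follows; it illustrates the definition and is not the repository's exact implementation (the answers mapping is a hypothetical shape).

    from typing import Dict, List

    def qa_recall_at_k(
        corpus_dict: Dict[str, str],                  # doc id -> passage text
        answers: Dict[str, List[str]],                # query id -> gold answer strings (hypothetical shape)
        search_results: Dict[str, Dict[str, float]],  # query id -> {doc id: score}
        k_values: List[int] = [1, 5, 10, 100],
    ) -> Dict[str, float]:
        hits = {k: 0 for k in k_values}
        for qid, doc_scores in search_results.items():
            # Rank retrieved doc ids by descending score.
            ranked = sorted(doc_scores, key=doc_scores.get, reverse=True)
            for k in k_values:
                top_text = " ".join(corpus_dict.get(did, "").lower() for did in ranked[:k])
                # Hit if ANY gold answer string occurs in ANY top-k passage.
                if any(ans.lower() in top_text for ans in answers.get(qid, [])):
                    hits[k] += 1
        n = max(len(search_results), 1)
        return {f"qa_recall_at_{k}": round(hits[k] / n, 5) for k in k_values}
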
13 changes: 13 additions & 0 deletions FlagEmbedding/evaluation/mkqa/runner.py
@@ -5,7 +5,15 @@


class MKQAEvalRunner(AbsEvalRunner):
"""
Evaluation runner of MKQA.
"""
def load_data_loader(self) -> MKQAEvalDataLoader:
"""Load the data loader instance by args.
Returns:
MKQAEvalDataLoader: The MKQA data loader instance.
"""
data_loader = MKQAEvalDataLoader(
eval_name=self.eval_args.eval_name,
dataset_dir=self.eval_args.dataset_dir,
@@ -16,6 +24,11 @@ def load_data_loader(self) -> MKQAEvalDataLoader:
return data_loader

def load_evaluator(self) -> MKQAEvaluator:
"""Load the evaluator instance by args.
Returns:
MKQAEvaluator: The MKQA evaluator instance.
"""
evaluator = MKQAEvaluator(
eval_name=self.eval_args.eval_name,
data_loader=self.data_loader,
6 changes: 5 additions & 1 deletion docs/source/API/evaluation.rst
@@ -1,2 +1,6 @@
Evaluation
==========
==========

.. toctree::
   evaluation/miracl
   evaluation/mkqa
48 changes: 48 additions & 0 deletions docs/source/API/evaluation/miracl.rst
@@ -0,0 +1,48 @@
MIRACL
======

`MIRACL <https://project-miracl.github.io/>`_ (Multilingual Information Retrieval Across a Continuum of Languages)
is a WSDM 2023 Cup challenge that focuses on search across 18 different languages.
It releases a multilingual retrieval dataset containing train and dev sets for 16 "known languages" and a dev set only for the 2 "surprise languages".
The topics are generated by native speakers of each language, who also label the relevance between the topics and a given document list.
You can find the `dataset <https://huggingface.co/datasets/miracl/miracl-corpus>`_ on Hugging Face.

You can evaluate a model's performance on MIRACL simply by running our provided shell script:

.. code:: bash

    chmod +x ./examples/evaluation/miracl/eval_miracl.sh
    ./examples/evaluation/miracl/eval_miracl.sh

Or by running:

.. code:: bash

    python -m FlagEmbedding.evaluation.miracl \
        --eval_name miracl \
        --dataset_dir ./miracl/data \
        --dataset_names bn hi sw te th yo \
        --splits dev \
        --corpus_embd_save_dir ./miracl/corpus_embd \
        --output_dir ./miracl/search_results \
        --search_top_k 1000 \
        --rerank_top_k 100 \
        --cache_path /root/.cache/huggingface/hub \
        --overwrite False \
        --k_values 10 100 \
        --eval_output_method markdown \
        --eval_output_path ./miracl/miracl_eval_results.md \
        --eval_metrics ndcg_at_10 recall_at_100 \
        --embedder_name_or_path BAAI/bge-m3 \
        --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
        --devices cuda:0 cuda:1 \
        --cache_dir /root/.cache/huggingface/hub \
        --reranker_max_length 1024

Change the embedder, reranker, devices, and cache directory to your preference.
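
If you prefer to launch the evaluation from Python instead of the command line, a hedged sketch is shown below. It assumes the module exposes ``MIRACLEvalArgs`` and ``MIRACLEvalModelArgs`` dataclasses whose fields mirror the flags above (the usual ``HfArgumentParser`` convention) and that the runner exposes a ``run()`` entry point; verify these names against your installed version.

.. code:: python

    from transformers import HfArgumentParser
    from FlagEmbedding.evaluation.miracl import (
        MIRACLEvalArgs,        # assumed export, analogous to MKQAEvalArgs
        MIRACLEvalModelArgs,   # assumed export
        MIRACLEvalRunner,
    )

    # Parse a small subset of the CLI flags shown above into the two dataclasses.
    parser = HfArgumentParser((MIRACLEvalArgs, MIRACLEvalModelArgs))
    eval_args, model_args = parser.parse_args_into_dataclasses(args=[
        "--eval_name", "miracl",
        "--dataset_dir", "./miracl/data",
        "--dataset_names", "bn", "hi",
        "--splits", "dev",
        "--output_dir", "./miracl/search_results",
        "--embedder_name_or_path", "BAAI/bge-m3",
        "--devices", "cuda:0",
    ])

    runner = MIRACLEvalRunner(eval_args=eval_args, model_args=model_args)
    runner.run()   # assumed entry point inherited from AbsEvalRunner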

.. toctree::
   :hidden:

   miracl/data_loader
   miracl/runner
13 changes: 13 additions & 0 deletions docs/source/API/evaluation/miracl/data_loader.rst
@@ -0,0 +1,13 @@
data_loader
===========

.. autoclass:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader

Methods
-------

.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader.available_dataset_names
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader.available_splits
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_corpus
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_qrels
.. automethod:: FlagEmbedding.evaluation.miracl.MIRACLEvalDataLoader._load_remote_queries
5 changes: 5 additions & 0 deletions docs/source/API/evaluation/miracl/runner.rst
@@ -0,0 +1,5 @@
runner
======

.. autoclass:: FlagEmbedding.evaluation.miracl.MIRACLEvalRunner
:members:
