diff --git a/mteb/abstasks/AbsTaskRetrieval.py b/mteb/abstasks/AbsTaskRetrieval.py index 41b070250..5345a50a5 100644 --- a/mteb/abstasks/AbsTaskRetrieval.py +++ b/mteb/abstasks/AbsTaskRetrieval.py @@ -132,42 +132,98 @@ def load_data(self, **kwargs): return self.corpus, self.queries, self.relevant_docs = {}, {}, {} self.instructions, self.top_ranked = None, None - dataset_path = self.metadata_dict["dataset"]["path"] + dataset_path = self.metadata.dataset["path"] hf_repo_qrels = ( dataset_path + "-qrels" if "clarin-knext" in dataset_path else None ) - for split in kwargs.get("eval_splits", self.metadata_dict["eval_splits"]): - corpus, queries, qrels, instructions, top_ranked = HFDataLoader( - hf_repo=dataset_path, - hf_repo_qrels=hf_repo_qrels, - streaming=False, - keep_in_memory=False, - trust_remote_code=self.metadata_dict["dataset"].get( - "trust_remote_code", False - ), - ).load(split=split) - # Conversion from DataSet - queries = {query["id"]: query["text"] for query in queries} - corpus = { - doc["id"]: doc.get("title", "") + " " + doc["text"] for doc in corpus - } - self.corpus[split], self.queries[split], self.relevant_docs[split] = ( - corpus, - queries, - qrels, - ) + if not self.is_multilingual: + for split in kwargs.get("eval_splits", self.metadata.eval_splits): + corpus, queries, qrels, instructions, top_ranked = HFDataLoader( + hf_repo=dataset_path, + hf_repo_qrels=hf_repo_qrels, + streaming=False, + keep_in_memory=False, + trust_remote_code=self.metadata.dataset.get( + "trust_remote_code", False + ), + ).load(split=split) + # Conversion from DataSet + queries = {query["id"]: query["text"] for query in queries} + corpus = { + doc["id"]: doc.get("title", "") + " " + doc["text"] + for doc in corpus + } + self.corpus[split], self.queries[split], self.relevant_docs[split] = ( + corpus, + queries, + qrels, + ) - # optional args - if instructions: - self.instructions = { - split: { - inst["query-id"]: inst["instruction"] for inst in instructions + # optional args + if instructions: + self.instructions = { + split: { + inst["query-id"]: inst["instruction"] + for inst in instructions + } } - } - if top_ranked: - self.top_ranked = { - split: {tr["query-id"]: tr["corpus-ids"] for tr in top_ranked} - } + if top_ranked: + self.top_ranked = { + split: {tr["query-id"]: tr["corpus-ids"] for tr in top_ranked} + } + else: + if not isinstance(self.metadata.eval_langs, dict): + raise ValueError("eval_langs must be a dict for multilingual tasks") + for lang in self.metadata.eval_langs: + self.corpus[lang], self.queries[lang], self.relevant_docs[lang] = ( + {}, + {}, + {}, + ) + for split in kwargs.get("eval_splits", self.metadata.eval_splits): + corpus, queries, qrels, instructions, top_ranked = HFDataLoader( + hf_repo=dataset_path, + hf_repo_qrels=hf_repo_qrels, + streaming=False, + keep_in_memory=False, + trust_remote_code=self.metadata.dataset.get( + "trust_remote_code", False + ), + ).load(split=split, config=lang) + # Conversion from DataSet + queries = {query["id"]: query["text"] for query in queries} + corpus = { + doc["id"]: doc.get("title", "") + " " + doc["text"] + for doc in corpus + } + ( + self.corpus[lang][split], + self.queries[lang][split], + self.relevant_docs[lang][split], + ) = ( + corpus, + queries, + qrels, + ) + + # optional args + if instructions: + if self.instructions is None: + self.instructions = {} + self.instructions[lang] = { + split: { + inst["query-id"]: inst["instruction"] + for inst in instructions + } + } + if top_ranked: + if self.top_ranked is None: + self.top_ranked = {} + self.top_ranked = { + split: { + tr["query-id"]: tr["corpus-ids"] for tr in top_ranked + } + } self.data_loaded = True diff --git a/mteb/abstasks/dataloaders.py b/mteb/abstasks/dataloaders.py index 25a6150a5..a8c165007 100644 --- a/mteb/abstasks/dataloaders.py +++ b/mteb/abstasks/dataloaders.py @@ -93,7 +93,7 @@ def check(fIn: str, ext: str): raise ValueError(f"File {fIn} must be present with extension {ext}") def load( - self, split: str = "test" + self, split: str = "test", config: str | None = None ) -> tuple[ dict[str, dict[str, str]], # corpus dict[str, str | list[str]], # queries @@ -118,33 +118,37 @@ def load( if not len(self.corpus): logger.info("Loading Corpus...") - self._load_corpus() + self._load_corpus(config) logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper()) logger.info("Doc Example: %s", self.corpus[0]) if not len(self.queries): logger.info("Loading Queries...") - self._load_queries() + self._load_queries(config) - if "top_ranked" in configs or (not self.hf_repo and self.top_ranked_file): + if any(c.endswith("top_ranked") for c in configs) in configs or ( + not self.hf_repo and self.top_ranked_file + ): logger.info("Loading Top Ranked") - self._load_top_ranked() + self._load_top_ranked(config) logger.info( f"Top ranked loaded: {len(self.top_ranked) if self.top_ranked else 0}" ) else: self.top_ranked = None - if "instruction" in configs or (not self.hf_repo and self.instructions_file): + if any(c.endswith("instruction") for c in configs) or ( + not self.hf_repo and self.instructions_file + ): logger.info("Loading Instructions") - self._load_instructions() + self._load_instructions(config) logger.info( f"Instructions loaded: {len(self.instructions) if self.instructions else 0}" ) else: self.instructions = None - self._load_qrels(split) + self._load_qrels(split, config) # filter queries with no qrels qrels_dict = defaultdict(dict) @@ -159,23 +163,24 @@ def qrels_dict_init(row): return self.corpus, self.queries, self.qrels, self.instructions, self.top_ranked - def load_corpus(self) -> dict[str, dict[str, str]]: + def load_corpus(self, config: str | None = None) -> dict[str, dict[str, str]]: if not self.hf_repo: self.check(fIn=self.corpus_file, ext="jsonl") if not len(self.corpus): logger.info("Loading Corpus...") - self._load_corpus() + self._load_corpus(config) logger.info("Loaded %d %s Documents.", len(self.corpus)) logger.info("Doc Example: %s", self.corpus[0]) return self.corpus - def _load_corpus(self): + def _load_corpus(self, config: str | None = None): + config = f"{config}-corpus" if config is not None else "corpus" if self.hf_repo: corpus_ds = load_dataset( self.hf_repo, - "corpus", + config, keep_in_memory=self.keep_in_memory, streaming=self.streaming, trust_remote_code=self.trust_remote_code, @@ -200,11 +205,12 @@ def _load_corpus(self): ) self.corpus = corpus_ds - def _load_queries(self): + def _load_queries(self, config: str | None = None): + config = f"{config}-queries" if config is not None else "queries" if self.hf_repo: queries_ds = load_dataset( self.hf_repo, - "queries", + config, keep_in_memory=self.keep_in_memory, streaming=self.streaming, trust_remote_code=self.trust_remote_code, @@ -224,10 +230,12 @@ def _load_queries(self): ) self.queries = queries_ds - def _load_qrels(self, split): + def _load_qrels(self, split: str, config: str | None = None): + config = f"{config}-qrels" if config is not None else None if self.hf_repo: qrels_ds = load_dataset( self.hf_repo_qrels, + name=config, keep_in_memory=self.keep_in_memory, streaming=self.streaming, trust_remote_code=self.trust_remote_code, @@ -249,11 +257,12 @@ def _load_qrels(self, split): qrels_ds = qrels_ds.cast(features) self.qrels = qrels_ds - def _load_top_ranked(self): + def _load_top_ranked(self, config: str | None = None): + config = f"top_ranked-{config}" if config is not None else "top_ranked" if self.hf_repo: top_ranked_ds = load_dataset( self.hf_repo, - "top_ranked", + config, keep_in_memory=self.keep_in_memory, streaming=self.streaming, trust_remote_code=self.trust_remote_code, @@ -293,11 +302,12 @@ def _load_top_ranked(self): ) self.top_ranked = top_ranked_ds - def _load_instructions(self): + def _load_instructions(self, config: str | None = None): + config = f"instruction-{config}" if config is not None else "instruction" if self.hf_repo: instructions_ds = load_dataset( self.hf_repo, - "instruction", + config, keep_in_memory=self.keep_in_memory, streaming=self.streaming, trust_remote_code=self.trust_remote_code,