diff --git a/llmebench/datasets/ANERcorp.py b/llmebench/datasets/ANERcorp.py index 4e36cd7b..6fe3fa2d 100644 --- a/llmebench/datasets/ANERcorp.py +++ b/llmebench/datasets/ANERcorp.py @@ -53,7 +53,7 @@ def metadata(): "test": "data/sequence_tagging_ner_pos_etc/NER/AnerCorp/ANERCorp_CamelLab_test.txt", "train": "data/sequence_tagging_ner_pos_etc/NER/AnerCorp/ANERCorp_CamelLab_train.txt", }, - "task_type": TaskType.Labeling, + "task_type": TaskType.SequenceLabeling, "class_labels": [ "B-PERS", "I-PERS", diff --git a/llmebench/datasets/Aqmar.py b/llmebench/datasets/Aqmar.py index c3e04d67..ce66e1f9 100644 --- a/llmebench/datasets/Aqmar.py +++ b/llmebench/datasets/Aqmar.py @@ -77,7 +77,7 @@ def metadata(): "path": "data/sequence_tagging_ner_pos_etc/NER/aqmar/AQMAR_Arabic_NER_corpus-1.0", }, }, - "task_type": TaskType.Labeling, + "task_type": TaskType.SequenceLabeling, "class_labels": [ "B-PERS", "I-PERS", diff --git a/llmebench/datasets/MGBWords.py b/llmebench/datasets/MGBWords.py index 3cb6e9ea..0a40f357 100644 --- a/llmebench/datasets/MGBWords.py +++ b/llmebench/datasets/MGBWords.py @@ -31,7 +31,7 @@ def metadata(): "splits": { "test": "data/sequence_tagging_ner_pos_etc/NER/mgb/MGB-words.txt" }, - "task_type": TaskType.Labeling, + "task_type": TaskType.SequenceLabeling, "class_labels": [ "B-PERS", "I-PERS", diff --git a/llmebench/datasets/QCRIDialectalArabicPOS.py b/llmebench/datasets/QCRIDialectalArabicPOS.py index c934ee12..8146fa26 100644 --- a/llmebench/datasets/QCRIDialectalArabicPOS.py +++ b/llmebench/datasets/QCRIDialectalArabicPOS.py @@ -45,7 +45,7 @@ def metadata(): }, "default": ["glf.data_5", "lev.data_5", "egy.data_5", "mgr.data_5"], }, - "task_type": TaskType.Labeling, + "task_type": TaskType.SequenceLabeling, "class_labels": [ "ADJ", "ADV", diff --git a/llmebench/datasets/WikiNewsPOS.py b/llmebench/datasets/WikiNewsPOS.py index 0ffcb19a..dec64539 100644 --- a/llmebench/datasets/WikiNewsPOS.py +++ b/llmebench/datasets/WikiNewsPOS.py @@ -23,7 +23,7 @@ def metadata(): "test": "data/sequence_tagging_ner_pos_etc/POS/WikiNewsTruth.txt.POS.tab", "train": "data/sequence_tagging_ner_pos_etc/POS/WikiNewsTruthDev.txt", }, - "task_type": TaskType.Labeling, + "task_type": TaskType.SequenceLabeling, "class_labels": [ "ABBREV", "ADJ", diff --git a/llmebench/datasets/XGLUEPOS.py b/llmebench/datasets/XGLUEPOS.py index daf080f5..b2e12385 100644 --- a/llmebench/datasets/XGLUEPOS.py +++ b/llmebench/datasets/XGLUEPOS.py @@ -23,7 +23,7 @@ def metadata(): "dev": "data/sequence_tagging_ner_pos_etc/POS/XGLUE/ar.dev.src-trg.txt", "test": "data/sequence_tagging_ner_pos_etc/POS/XGLUE/ar.test.src-trg.txt", }, - "task_type": TaskType.Labeling, + "task_type": TaskType.SequenceLabeling, "class_labels": [ "ADJ", "ADP", diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 6bcbda68..c59843f9 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -97,7 +97,8 @@ def metadata(): Model. "class_labels" : list (optional) List of class labels, must be provided when `task_type` is - `Classification`, `MultiLabelClassification` or `Labeling`. + `Classification`, `MultiLabelClassification` or + `SequenceLabeling`. "score_range" : tuple (optional) Score range defining (min_val, max_val). Must be defined when `task_type` is `Regression` diff --git a/llmebench/tasks/__init__.py b/llmebench/tasks/__init__.py index 122d7524..0510eab6 100644 --- a/llmebench/tasks/__init__.py +++ b/llmebench/tasks/__init__.py @@ -37,7 +37,7 @@ [ "Classification", "MultiLabelClassification", - "Labeling", + "SequenceLabeling", "QuestionAnswering", "SequenceToSequence", "Regression", diff --git a/tests/datasets/test_metadata.py b/tests/datasets/test_metadata.py index 26ed0fa6..97b19a5c 100644 --- a/tests/datasets/test_metadata.py +++ b/tests/datasets/test_metadata.py @@ -57,7 +57,7 @@ def test_dataset_metadata(self): if metadata["task_type"] in [ TaskType.Classification, - TaskType.Labeling, + TaskType.SequenceLabeling, TaskType.MultiLabelClassification, ]: self.assertIn("class_labels", metadata) diff --git a/tests/models/test_HuggingFaceInferenceAPI.py b/tests/models/test_HuggingFaceInferenceAPI.py index 0f93cf3f..41ee5663 100644 --- a/tests/models/test_HuggingFaceInferenceAPI.py +++ b/tests/models/test_HuggingFaceInferenceAPI.py @@ -4,6 +4,8 @@ from llmebench import Benchmark from llmebench.models import HuggingFaceInferenceAPIModel, HuggingFaceTaskTypes +from llmebench.utils import is_fewshot_asset + class TestAssetsForHuggingFaceInferenceAPIPrompts(unittest.TestCase): @classmethod @@ -29,7 +31,7 @@ def test_huggingface_inference_api_prompts(self): config = asset["config"] dataset = config["dataset"](**config["dataset_args"]) data_sample = dataset.get_data_sample() - if "fewshot" in config["general_args"]: + if is_fewshot_asset(config, asset["module"].prompt): prompt = asset["module"].prompt( data_sample["input"], [data_sample for _ in range(n_shots)],