diff --git a/docs/tutorials/adding_asset.md b/docs/tutorials/adding_asset.md
index aa10db8f..65e0e044 100644
--- a/docs/tutorials/adding_asset.md
+++ b/docs/tutorials/adding_asset.md
@@ -27,6 +27,13 @@ def config():
     #         }
     #     }
     # }
+    #
+    # If the Dataset you are using does not provide your required splits by
+    # default, "custom_test_split" and "custom_train_split" can be used instead.
+    # These are usually strings, but their exact structure is dictated by each
+    # dataset's data loader. The framework supports absolute paths, relative
+    # paths (resolved relative to `data_dir/*Dataset/`), or special paths
+    # prefixed with `:data_dir:`, which are resolved relative to `data_dir`.
 
 def prompt(input_sample):
     # This function receives an input_sample and pre-processes it into the
diff --git a/llmebench/datasets/TyDiQA.py b/llmebench/datasets/TyDiQA.py
index cec6edce..5c07c0f9 100644
--- a/llmebench/datasets/TyDiQA.py
+++ b/llmebench/datasets/TyDiQA.py
@@ -22,7 +22,7 @@ def metadata():
             "license": "Apache License Version 2.0",
             "splits": {
                 "dev": "tydiqa-goldp-dev-arabic.json",
-                "train": ":depends:ARCD/arcd-train.json",
+                "train": ":data_dir:ARCD/arcd-train.json",
             },
             "task_type": TaskType.QuestionAnswering,
         }
diff --git a/llmebench/datasets/UnifiedFCFactuality.py b/llmebench/datasets/UnifiedFCFactuality.py
index d106d982..a6e5ff2a 100644
--- a/llmebench/datasets/UnifiedFCFactuality.py
+++ b/llmebench/datasets/UnifiedFCFactuality.py
@@ -25,7 +25,7 @@ def metadata():
             "license": "Research Purpose Only",
             "splits": {
                 "test": "ramy_arabic_fact_checking.tsv",
-                "train": ":depends:ANSStance/claim/train.csv",
+                "train": ":data_dir:ANSStance/claim/train.csv",
             },
             "task_type": TaskType.Classification,
             "class_labels": ["true", "false"],
diff --git a/llmebench/datasets/UnifiedFCStance.py b/llmebench/datasets/UnifiedFCStance.py
index 288599f6..915e1215 100644
--- a/llmebench/datasets/UnifiedFCStance.py
+++ b/llmebench/datasets/UnifiedFCStance.py
@@ -27,7 +27,7 @@ def metadata():
             "license": "Research Purpose Only",
             "splits": {
                 "test": "ramy_arabic_stance.jsonl",
-                "train": ":depends:ANSStance/stance/train.csv",
+                "train": ":data_dir:ANSStance/stance/train.csv",
             },
             "task_type": TaskType.Classification,
             "class_labels": ["agree", "disagree", "discuss", "unrelated"],
diff --git a/llmebench/datasets/XQuAD.py b/llmebench/datasets/XQuAD.py
index fc5ae0d0..9f5e6d64 100644
--- a/llmebench/datasets/XQuAD.py
+++ b/llmebench/datasets/XQuAD.py
@@ -23,7 +23,7 @@ def metadata():
             "license": "CC-BY-SA4.0",
             "splits": {
                 "test": "xquad.ar.json",
-                "train": ":depends:ARCD/arcd-train.json",
+                "train": ":data_dir:ARCD/arcd-train.json",
             },
             "task_type": TaskType.QuestionAnswering,
         }
diff --git a/llmebench/utils.py b/llmebench/utils.py
index beeca1f3..65fbe7b9 100644
--- a/llmebench/utils.py
+++ b/llmebench/utils.py
@@ -51,6 +51,7 @@ def get_data_paths(config, split):
     assert split in ["train", "test"]
 
     dataset_args = config.get("dataset_args", {})
+    dataset_args["data_dir"] = ""
     dataset = config["dataset"](**dataset_args)
 
     if split == "test":
@@ -129,10 +130,10 @@ def resolve_path(path, dataset, data_dir):
     if not isinstance(data_dir, Path):
         data_dir = Path(data_dir)
 
-    if not str(path).startswith(":depends:") and path.is_absolute():
+    if not str(path).startswith(":data_dir:") and path.is_absolute():
         return path
-    elif str(path).startswith(":depends:"):
-        return data_dir / str(path)[len(":depends:") :]
+    elif str(path).startswith(":data_dir:"):
+        return data_dir / str(path)[len(":data_dir:") :]
     else:
         dataset_name = dataset.__class__.__name__
         if dataset_name.endswith("Dataset"):
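
For illustration, a minimal sketch of how the three path forms named in the
tutorial text are handled by `resolve_path` after this rename. This is not part
of the patch: the stand-in `TyDiQADataset` class, the `data` directory, and the
expected outputs are assumptions for demonstration; only the
`resolve_path(path, dataset, data_dir)` signature and the ":data_dir:" handling
come from the hunk above.

    from pathlib import Path

    from llmebench.utils import resolve_path


    # resolve_path only inspects the dataset's class name, so a stand-in
    # class is enough here; any real llmebench dataset instance would do.
    class TyDiQADataset:
        pass


    dataset = TyDiQADataset()
    data_dir = Path("data")

    # ":data_dir:"-prefixed paths resolve directly under data_dir:
    resolve_path(Path(":data_dir:ARCD/arcd-train.json"), dataset, data_dir)
    # -> data/ARCD/arcd-train.json

    # Absolute paths pass through unchanged:
    resolve_path(Path("/abs/path/train.json"), dataset, data_dir)
    # -> /abs/path/train.json

    # Anything else resolves under data_dir/<DatasetName>/; the hunk's trailing
    # context suggests the "Dataset" suffix is stripped from the class name first:
    resolve_path(Path("tydiqa-goldp-dev-arabic.json"), dataset, data_dir)
    # -> data/TyDiQA/tydiqa-goldp-dev-arabic.json (presumably)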