diff --git a/promptbench/dataload/dataset.py b/promptbench/dataload/dataset.py index 33d4be4..9d44018 100644 --- a/promptbench/dataload/dataset.py +++ b/promptbench/dataload/dataset.py @@ -35,17 +35,20 @@ def __init__(self, dataset_name): os.mkdir(self.data_dir) # check if the dataset exists, if not, download it - self.filepath = os.path.join(self.data_dir, f"{dataset_name}.json") - self.filepath2 = os.path.join(self.data_dir, f"{dataset_name}.jsonl") + if dataset_name == "gsm8k": + self.filepath = os.path.join(self.data_dir, f"{dataset_name}.jsonl") + else: + self.filepath = os.path.join(self.data_dir, f"{dataset_name}.json") + if not os.path.exists(self.filepath): - if os.path.exists(self.filepath2): - self.filepath = self.filepath2 + if dataset_name == "gsm8k": + url = f'https://wjdcloud.blob.core.windows.net/dataset/promptbench/dataset/{dataset_name}.jsonl' else: url = f'https://wjdcloud.blob.core.windows.net/dataset/promptbench/dataset/{dataset_name}.json' - print(f"Downloading {dataset_name} dataset...") - response = requests.get(url) - with open(self.filepath, 'wb') as f: - f.write(response.content) + print(f"Downloading {dataset_name} dataset...") + response = requests.get(url) + with open(self.filepath, 'wb') as f: + f.write(response.content) def __len__(self): assert len(self.data) > 0, "Empty dataset. Please load data first."