Skip to content

Commit

Permalink
Merge pull request #29 from icecream-and-tea/main
Browse files — browse the repository at this point in the history
fixed bug: gsm8k download
  • Loading branch information
Immortalise authored Dec 16, 2023
2 parents 258aa2e + 8df6313 commit 73f68e3
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions promptbench/dataload/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,17 +35,20 @@ def __init__(self, dataset_name):
os.mkdir(self.data_dir)

# check if the dataset exists, if not, download it
self.filepath = os.path.join(self.data_dir, f"{dataset_name}.json")
self.filepath2 = os.path.join(self.data_dir, f"{dataset_name}.jsonl")
if dataset_name == "gsm8k":
self.filepath = os.path.join(self.data_dir, f"{dataset_name}.jsonl")
else:
self.filepath = os.path.join(self.data_dir, f"{dataset_name}.json")

if not os.path.exists(self.filepath):
if os.path.exists(self.filepath2):
self.filepath = self.filepath2
if dataset_name == "gsm8k":
url = f'https://wjdcloud.blob.core.windows.net/dataset/promptbench/dataset/{dataset_name}.jsonl'
else:
url = f'https://wjdcloud.blob.core.windows.net/dataset/promptbench/dataset/{dataset_name}.json'
print(f"Downloading {dataset_name} dataset...")
response = requests.get(url)
with open(self.filepath, 'wb') as f:
f.write(response.content)
print(f"Downloading {dataset_name} dataset...")
response = requests.get(url)
with open(self.filepath, 'wb') as f:
f.write(response.content)

def __len__(self):
assert len(self.data) > 0, "Empty dataset. Please load data first."
Expand Down

0 comments on commit 73f68e3

Please sign in to comment.