Skip to content

Commit

Permalink
Fix a bug in the dataset path #1
Browse files Browse the repository at this point in the history
  • Loading branch information
1rubbishyuan committed Oct 15, 2024
1 parent 7f1a51e commit 2ca2e2e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
27 changes: 18 additions & 9 deletions dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@
from tqdm import tqdm
import re
import os

from utils.config import (
hotpot_qa_path,
mwh_qa_path,
trivia_qa_path,
cbt_path,
math_path,
gsm8k_math,
arc_path,
mmlu_path
)
split_map = {
"train": {
"hotpot_qa": "train",
Expand Down Expand Up @@ -45,7 +54,7 @@ def sample(self, count: int):
class DataloaderForHotpotQA(BaseDataloader):

def __init__(
self, dataset: str = "hotpot_qa", name: str = "distractor", split: str = "train"
self, dataset: str = hotpot_qa_path, name: str = "distractor", split: str = "train"
):
self.dataset = dataset
self.name = name
Expand Down Expand Up @@ -105,7 +114,7 @@ def sample_once(self):
class DataloaderForMWHQA(BaseDataloader):
def __init__(
self,
dataset_path: str = "/home/test/test04/yuanjiarui/project/json_dataset/2MultiWikiQA",
dataset_path: str = mwh_qa_path,
name: str = "",
split: str = "train",
):
Expand Down Expand Up @@ -153,7 +162,7 @@ def sample_once(self):
class DataloaderForTrivalQA(BaseDataloader):
def __init__(
self,
dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/trivia_qa_dataset",
dataset: str = trivia_qa_path,
name: str = "rc",
split: str = "train",
):
Expand Down Expand Up @@ -194,7 +203,7 @@ def sample_once(self):
class DataloaderForCBT(BaseDataloader):
def __init__(
self,
dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/cbt_dataset",
dataset: str = cbt_path,
name: str = "CN",
split: str = "train",
):
Expand Down Expand Up @@ -230,7 +239,7 @@ def sample_once(self):
class DataloaderForGSM8K(BaseDataloader):
def __init__(
self,
dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/gsm8k_dataset",
dataset: str = gsm8k_math,
name: str = "main",
split: str = "train",
):
Expand Down Expand Up @@ -263,7 +272,7 @@ def sample_once(self):
class DataloaderForMATH(BaseDataloader):
def __init__(
self,
dataset: str = "/home/test/test04/yuanjiarui/project/json_dataset/MATH",
dataset: str = math_path,
split="train",
):
self.split = split
Expand Down Expand Up @@ -325,7 +334,7 @@ def sample_once(self):
class DataloaderForARC(BaseDataloader):
def __init__(
self,
dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/arc_dataset",
dataset: str = arc_path,
name: str = "ARC-Challenge",
split: str = "train",
):
Expand Down Expand Up @@ -366,7 +375,7 @@ def sample_once(self):
class DataloaderForMMLU(BaseDataloader):
def __init__(
self,
dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/mmlu_dataset",
dataset: str = mmlu_path,
name: str = "all",
split: str = "auxiliary_train",
):
Expand Down
6 changes: 3 additions & 3 deletions train/dpo_recipes/arc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ sample_count: 0
origin_dpo_yaml_path: ./alignment-handbook/recipes/Llama3-8b/arc_dpo/base.yaml
mid_dpo_jsonl_root_path: ./results/arc_dpo
mid_dpo_dataset_root_path: ./my_datasets/arc_dpo
initial_dpo_min_value: 0
initial_dpo_episilon: -100
monte_sample_count: 10
initial_dpo_min_value: 0.2
initial_dpo_episilon: 0.45
monte_sample_count: 10000
cal_ppl: 1
from_initial: False
lambda1: -0.4
Expand Down
8 changes: 4 additions & 4 deletions train/sft_dpo_recipes/arc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ origin_sft_yaml_path: ./alignment-handbook/recipes/Llama3-8b/arc_sft_dpo/sft/bas
mid_sft_jsonl_root_path: ./results/arc_sft_dpo/sft
mid_sft_dataset_root_path: ./my_datasets/arc_sft_dpo/sft
initial_episilon: 0.6
sample_count: 10
sample_count: 10000
# dpo
origin_dpo_yaml_path: ./alignment-handbook/recipes/Llama3-8b/arc_sft_dpo/dpo/base.yaml
mid_dpo_jsonl_root_path: ./results/arc_sft_dpo/dpo
mid_dpo_dataset_root_path: ./my_datasets/arc_sft_dpo/dpo
initial_dpo_min_value: 0
initial_dpo_episilon: -100
monte_sample_count: 10
initial_dpo_min_value: 0.2
initial_dpo_episilon: 0.45
monte_sample_count: 10000
cal_ppl: 1
lambda1: -0.5
lambda2: 0.6
Expand Down
8 changes: 8 additions & 0 deletions utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@
llama3_path_aistation = ""
llama3_path_a800 = ""
prompt_pool_path_a800 = ""
# Dataset locations imported by dataloader/dataloader.py, where each one is
# the default `dataset` / `dataset_path` argument of a dataloader class.
# All are empty strings here; presumably each must be set to a real location
# for the local environment before the defaults are usable -- TODO confirm.
# NOTE(review): the previous default for DataloaderForHotpotQA was the string
# "hotpot_qa", not a filesystem path.
hotpot_qa_path=""  # default for DataloaderForHotpotQA
mwh_qa_path=""  # default for DataloaderForMWHQA
trivia_qa_path=""  # default for DataloaderForTrivalQA
cbt_path=""  # default for DataloaderForCBT
math_path=""  # default for DataloaderForMATH
# NOTE(review): this name does not follow the `*_path` pattern of its
# siblings; it is imported under this exact name by dataloader.py.
gsm8k_math=""  # default for DataloaderForGSM8K
arc_path=""  # default for DataloaderForARC
mmlu_path=""  # default for DataloaderForMMLU

0 comments on commit 2ca2e2e

Please sign in to comment.