diff --git a/dataloader/dataloader.py b/dataloader/dataloader.py index d2318a3..740bb47 100644 --- a/dataloader/dataloader.py +++ b/dataloader/dataloader.py @@ -11,7 +11,16 @@ from tqdm import tqdm import re import os - +from utils.config import ( + hotpot_qa_path, + mwh_qa_path, + trivia_qa_path, + cbt_path, + math_path, + gsm8k_math, + arc_path, + mmlu_path +) split_map = { "train": { "hotpot_qa": "train", @@ -45,7 +54,7 @@ def sample(self, count: int): class DataloaderForHotpotQA(BaseDataloader): def __init__( - self, dataset: str = "hotpot_qa", name: str = "distractor", split: str = "train" + self, dataset: str = hotpot_qa_path, name: str = "distractor", split: str = "train" ): self.dataset = dataset self.name = name @@ -105,7 +114,7 @@ def sample_once(self): class DataloaderForMWHQA(BaseDataloader): def __init__( self, - dataset_path: str = "/home/test/test04/yuanjiarui/project/json_dataset/2MultiWikiQA", + dataset_path: str = mwh_qa_path, name: str = "", split: str = "train", ): @@ -153,7 +162,7 @@ def sample_once(self): class DataloaderForTrivalQA(BaseDataloader): def __init__( self, - dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/trivia_qa_dataset", + dataset: str = trivia_qa_path, name: str = "rc", split: str = "train", ): @@ -194,7 +203,7 @@ def sample_once(self): class DataloaderForCBT(BaseDataloader): def __init__( self, - dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/cbt_dataset", + dataset: str = cbt_path, name: str = "CN", split: str = "train", ): @@ -230,7 +239,7 @@ def sample_once(self): class DataloaderForGSM8K(BaseDataloader): def __init__( self, - dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/gsm8k_dataset", + dataset: str = gsm8k_math, name: str = "main", split: str = "train", ): @@ -263,7 +272,7 @@ def sample_once(self): class DataloaderForMATH(BaseDataloader): def __init__( self, - dataset: str = "/home/test/test04/yuanjiarui/project/json_dataset/MATH", + dataset: str = math_path, split="train", ): self.split = split @@ -325,7 +334,7 @@ def sample_once(self): class DataloaderForARC(BaseDataloader): def __init__( self, - dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/arc_dataset", + dataset: str = arc_path, name: str = "ARC-Challenge", split: str = "train", ): @@ -366,7 +375,7 @@ def sample_once(self): class DataloaderForMMLU(BaseDataloader): def __init__( self, - dataset: str = "/home/test/test04/yuanjiarui/project/huggingface_cache/mmlu_dataset", + dataset: str = mmlu_path, name: str = "all", split: str = "auxiliary_train", ): diff --git a/train/dpo_recipes/arc.yaml b/train/dpo_recipes/arc.yaml index 36b371b..ee8b437 100644 --- a/train/dpo_recipes/arc.yaml +++ b/train/dpo_recipes/arc.yaml @@ -19,9 +19,9 @@ sample_count: 0 origin_dpo_yaml_path: ./alignment-handbook/recipes/Llama3-8b/arc_dpo/base.yaml mid_dpo_jsonl_root_path: ./results/arc_dpo mid_dpo_dataset_root_path: ./my_datasets/arc_dpo -initial_dpo_min_value: 0 -initial_dpo_episilon: -100 -monte_sample_count: 10 +initial_dpo_min_value: 0.2 +initial_dpo_episilon: 0.45 +monte_sample_count: 10000 cal_ppl: 1 from_initial: False lambda1: -0.4 diff --git a/train/sft_dpo_recipes/arc.yaml b/train/sft_dpo_recipes/arc.yaml index 4982fe7..40007f7 100644 --- a/train/sft_dpo_recipes/arc.yaml +++ b/train/sft_dpo_recipes/arc.yaml @@ -17,14 +17,14 @@ origin_sft_yaml_path: ./alignment-handbook/recipes/Llama3-8b/arc_sft_dpo/sft/bas mid_sft_jsonl_root_path: ./results/arc_sft_dpo/sft mid_sft_dataset_root_path: ./my_datasets/arc_sft_dpo/sft initial_episilon: 0.6 -sample_count: 10 +sample_count: 10000 # dpo origin_dpo_yaml_path: ./alignment-handbook/recipes/Llama3-8b/arc_sft_dpo/dpo/base.yaml mid_dpo_jsonl_root_path: ./results/arc_sft_dpo/dpo mid_dpo_dataset_root_path: ./my_datasets/arc_sft_dpo/dpo -initial_dpo_min_value: 0 -initial_dpo_episilon: -100 -monte_sample_count: 10 +initial_dpo_min_value: 0.2 +initial_dpo_episilon: 0.45 +monte_sample_count: 10000 cal_ppl: 1 lambda1: -0.5 lambda2: 0.6 diff --git a/utils/config.py b/utils/config.py index d68eff2..76a730c 100644 --- a/utils/config.py +++ b/utils/config.py @@ -7,3 +7,11 @@ llama3_path_aistation = "" llama3_path_a800 = "" prompt_pool_path_a800 = "" +hotpot_qa_path="" +mwh_qa_path="" +trivia_qa_path="" +cbt_path="" +math_path="" +gsm8k_math="" +arc_path="" +mmlu_path="" \ No newline at end of file