This is the official repository for Montessori-Instruct: Generate Influential Training Data Tailored for Student Learning. In this work, we propose a novel data synthesis framework that tailors the teacher's data synthesis toward the student's learning process.
Updates:
🎉 [2024-10-18] We release the paper, code, and synthetic datasets for Montessori-Instruct.
git clone git@github.com:cxcscmu/Montessori-Instruct.git montessori
cd montessori
# tested under Python 3.10.14
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
- If you are going to use models from Hugging Face, please set up your Hugging Face cache path.
export HF_HOME=xxx
- We use Weights and Biases to log the training process. Please log in to your wandb account first.
wandb login
- We use Hugging Face Accelerate to run full-parameter fine-tuning of the students. Please set up your Accelerate environment config file first:
# for Llama3-8B models
accelerate config --config_file ./configs/fsdp.yml
# for Tinyllama-1.1B models
accelerate config --config_file ./configs/ddp.yml
- 🤖 Models: In our main experiments, we use Llama3-8B-Instruct as the teacher model and Llama3-8B and TinyLlama-1.1B as the student models. Please make sure you have access to the Meta Llama3 series models beforehand.
- 🔎 Datasets: We use the alpaca_gpt4 dataset as the seed dataset (a quick way to inspect it is shown below).
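For reference, here is a quick way to inspect the seed data. This sketch assumes the alpaca_gpt4 data is pulled from the vicgalle/alpaca-gpt4 mirror on the Hugging Face Hub; the exact dataset path used by the configs in this repository may differ.

```python
from datasets import load_dataset

# Illustrative only: one public mirror of the alpaca_gpt4 data.
# The dataset path actually used by the configs may be different.
seed = load_dataset("vicgalle/alpaca-gpt4", split="train")
print(len(seed))
print(seed[0]["instruction"])
```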
Here we provide a quick start to go through the whole data synthesis pipeline.
Step 1: Set up your project config in project_config.yml
project_name: your_project_name # your project name
teacher_model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # the path of the teacher model
student_model_name_or_path: TinyLlama/TinyLlama_v1.1 # the path of the student model
fsdp_or_ddp: ./configs/ddp.yml # the path of the accelerate config file
training_dataset_num: 100 # the number of synthetic data for student training
probing_dataset_num: 20 # the number of probing data for local data influence collection
warmup_dataset_num: 100 # the number of warmup data for student warmup
num_gpus: 8 # the number of GPUs used for training
Step 2: Run the following command:
python gen_configs.py
This generates all the config files under the configs/your_project_name folder, including the training scripts under the scripts/your_project_name folder.
Step 3: Launch the whole pipeline by running:
bash ./scripts/your_project_name/run_all.sh
You can find the synthetic data under the ./data/your_project_name folder and the logs under the ./logs/your_project_name folder.
The following sections walk through each step of the pipeline in more detail. We require two datasets: one for warming up the student (the warmup dataset) and another for collecting local data influence (the probing dataset). The gen_warmup_and_probing_dataset.yml file contains the parameters for this step. You need to create gen_warmup_and_probing_dataset.yml first; a template config file is provided under configs/templates.
Then run the following command to generate the datasets:
# Generate the instructions first
python ./src/gen_instructions.py --config_file ./configs/gen_warmup_and_probing_dataset.yml
# Generate the responses then
python ./src/gen_responses.py --config_file ./configs/gen_warmup_and_probing_dataset.yml
# Divide the dataset into warmup and probing datasets separately
python ./src/divide_dataset.py --config_file ./configs/gen_warmup_and_probing_dataset.yml
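Conceptually, the last command simply partitions the synthesized instruction-response pairs into two disjoint subsets. Below is a minimal sketch of that split; the file names and sizes are illustrative, and the actual script reads them from gen_warmup_and_probing_dataset.yml.

```python
from datasets import load_dataset

# Illustrative file name; the real paths come from the YAML config.
synthetic = load_dataset("json", data_files="data/warmup_and_probing.jsonl", split="train")

# Disjoint warmup / probing subsets, sized as in project_config.yml
# (warmup_dataset_num: 100, probing_dataset_num: 20 in the example above).
split = synthetic.train_test_split(test_size=20, shuffle=True, seed=42)
split["train"].to_json("data/warmup.jsonl")
split["test"].to_json("data/probing.jsonl")
```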
The warm_up.yml file contains the parameters for warming up the student. Once you've configured this file, you can run the following command to start the warmup:
accelerate launch --config_file ./configs/ddp.yml ./src/warm_up.py --config_file ./configs/warm_up.yml # for the Tinyllama-1.1B model; change ddp to fsdp for Llama3-8B series models.
In this step, we load the warmup checkpoint and use it to collect the local data influence based on the probing dataset. The collect_local_data_influence.yml file contains the parameters for this step. By default, we use the alpaca_eval_gpt4_1106_preview reference dataset, which is downloaded with this repository.
After configuring, you can run the following command to collect the local data influence:
accelerate launch --config_file ./configs/ddp.yml ./src/collect_local_data_influence.py --config_file ./configs/collect_local_data_influence.yml # for Tinyllama-1.1B models; change ddp to fsdp for Llama3-8B models.
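For intuition, the local data influence of a candidate training example can be viewed as the drop in the warmed-up student's loss on the probing set after a single update on that example. The sketch below only illustrates this idea and is not the exact implementation in ./src/collect_local_data_influence.py (batching, optimizer settings, and distributed training are omitted).

```python
import copy
import torch

def probing_loss(model, probing_batches):
    """Average language-modeling loss of the student on the probing set."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in probing_batches:
            total += model(**batch, labels=batch["input_ids"]).loss.item()
    return total / len(probing_batches)

def local_data_influence(student, candidate_batch, probing_batches, lr=2e-5):
    """Influence = probing loss before minus after one step on the candidate."""
    before = probing_loss(student, probing_batches)

    updated = copy.deepcopy(student)   # keep the warmed-up checkpoint intact
    updated.train()
    optimizer = torch.optim.AdamW(updated.parameters(), lr=lr)
    loss = updated(**candidate_batch, labels=candidate_batch["input_ids"]).loss
    loss.backward()
    optimizer.step()

    after = probing_loss(updated, probing_batches)
    return before - after              # positive => the candidate helped the student
```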
After collecting the local data influence scores, we use them to construct a preference dataset and update the teacher on it with DPO. The dpo_teacher.yml file contains the parameters for this step.
Run the following command to update the teacher:
python ./src/dpo_teacher.py --config_file ./configs/dpo_teacher.yml
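Conceptually, the preference dataset pairs, for each data-generation prompt, the teacher's synthesized example with the highest local data influence (chosen) against the one with the lowest (rejected), and the teacher is then optimized with DPO on these pairs. Below is a simplified sketch of the pairing step; the field names and file paths are illustrative, not the exact keys produced by the collection script.

```python
import json
from collections import defaultdict

# Illustrative input: one record per synthesized example, carrying the prompt
# that produced it, the generated text, and its measured local data influence.
with open("data/influence_scores.jsonl") as f:
    records = [json.loads(line) for line in f]

by_prompt = defaultdict(list)
for r in records:
    by_prompt[r["prompt"]].append(r)

with open("data/dpo_pairs.jsonl", "w") as f:
    for prompt, group in by_prompt.items():
        group.sort(key=lambda r: r["influence"], reverse=True)
        if len(group) >= 2 and group[0]["influence"] > group[-1]["influence"]:
            pair = {
                "prompt": prompt,                     # teacher's data-generation prompt
                "chosen": group[0]["completion"],     # highest-influence synthesis
                "rejected": group[-1]["completion"],  # lowest-influence synthesis
            }
            f.write(json.dumps(pair) + "\n")
```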
At this stage, we have obtained an optimized teacher model tailored to the student's learning process. Now, we use this teacher to generate the synthetic dataset for the student's training. The gen_training_dataset.yml file contains the parameters for generating the synthetic dataset, and train_student.yml contains the parameters for training the student.
Run the following commands to generate the synthetic dataset and train the student:
# First generate the training dataset for student models
python ./src/gen_instructions.py --config_file ./configs/gen_training_dataset.yml
python ./src/gen_responses.py --config_file ./configs/gen_training_dataset.yml
# Then train the student model
accelerate launch --config_file ./configs/ddp.yml ./src/train_student.py --config_file ./configs/train_student.yml
We evaluate the instruction-following ability of the student trained with Montessori-Instruct synthetic data using the Alpaca Eval 2.0 (in-domain) and MT-Bench (out-of-domain) benchmarks. Additionally, we employ the lm-evaluation-harness to evaluate the student's general performance on MMLU, GSM8K, and other benchmarks.
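For example, a harness run on the trained student might look like the following. This is a sketch using the lm-evaluation-harness Python API; argument names can vary across harness versions, and the checkpoint path is a placeholder.

```python
import lm_eval

# Placeholder path to the trained student checkpoint.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=./checkpoints/your_project_name/student,dtype=bfloat16",
    tasks=["mmlu", "gsm8k"],
    batch_size=8,
)
print(results["results"])
```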
If you have any questions about the code or the paper, feel free to contact Xiaochuan ([email protected]). If you encounter any problems or want to report a bug, please open an issue with detailed information so we can assist you more effectively.
Please cite our paper if you use Montessori-Instruct in your work:
@misc{li2024montessoriinstructgenerateinfluentialtraining,
title={Montessori-Instruct: Generate Influential Training Data Tailored for Student Learning},
author={Xiaochuan Li and Zichun Yu and Chenyan Xiong},
year={2024},
eprint={2410.14208},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2410.14208},
}