📋 We are still cleaning up this repo. Use it with caution, and please report any errors or questions to us.

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

This repository is the official implementation of our paper:

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

Ekin Akyürek, Mehul Damani, Linlu Qiu, Han Guo, Yoon Kim, Jacob Andreas

Requirements

To install the requirements, start a fresh conda environment and install the following with pip:

git clone --recursive https://github.com/ekinakyurek/marc
cd marc/
# For the TTT pipeline, we use a fork of the torchtune library,
# which you need to install first.
conda create -n arc python=3.10
conda activate arc
# Install our fork of torchtune in editable mode: we use some files
# under third_party/torchtune/recipes/ that are not included
# in a regular pip install
cd third_party/torchtune
pip install -e .
# install other required libraries for torchtune
pip install torch torchao --pre --upgrade --index-url https://download.pytorch.org/whl/nightly/cu121

# The remaining requirements can then be installed with:
pip install -r requirements.txt

📋 You need to download the ARC dataset from Kaggle: https://www.kaggle.com/competitions/arc-prize-2024/data
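
Before launching anything, it can help to confirm that the two Kaggle files referenced later in this README are where you expect them. The following is a minimal sketch: `check_arc_data` is a hypothetical helper that is not part of this repo, and the directory argument is whatever location you downloaded the data to.

```shell
# Hypothetical helper (not part of this repo): verify that the ARC files
# used by the commands in this README exist in a given directory.
check_arc_data() {
    data_dir="$1"
    missing=0
    for name in arc-agi_evaluation_challenges.json arc-agi_evaluation_solutions.json; do
        if [ ! -f "$data_dir/$name" ]; then
            # Report each missing file on stderr and remember the failure
            echo "missing: $data_dir/$name" >&2
            missing=1
        fi
    done
    # Exit status 0 when both files are present, 1 otherwise
    return "$missing"
}
```

For example, `check_arc_data /kaggle/input/arc-prize-2024 || exit 1` would stop a script early if either file is absent.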

📋 You can access the fine-tuned models and TTT checkpoints from the following links:

Test Time Training

To train the model(s) in the paper, run this command:

# Specify data path
data_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json
# Specify finetuned path
base_checkpoint_dir=/path/to/finetuned/model/folder/
# Specify where TTT adapters should be saved
ttt_folder=/path/to/ttt/folder
mkdir -p $ttt_folder


# You need to provide an initial config file that is compatible with torchtune configs.
# One is provided in this repo.
lora_config_file=configs/ttt/8B_lora_single_device.yaml
# lora_config_file=configs/ttt/8.1B_lora_single_device.yaml # for barc
# But you can override some of the variables
batch_size=2
epochs=2
learning_rate=5e-5
lora_rank=128
lora_alpha=16.0
lora_to_output=False # doesn't apply for Llama3.2 models for now.
# You can specify how many tasks you want to train on.
num_tasks=15

# You can run the main script
python test_time_train.py --lora_config=$lora_config_file \
--base_checkpoint_dir=$base_checkpoint_dir \
--experiment_folder=$ttt_folder \
--data_file=$data_file \
--batch_size=$batch_size \
--epochs=$epochs \
--num_tasks=${num_tasks} \
--lora_rank=$lora_rank \
--lora_alpha=$lora_alpha \
--lora_to_output=$lora_to_output \
--new_format # use --barc_format for barc

📋 If you are using BARC checkpoints with unmask_outputs=True in the program arguments, you need to uncomment these lines in my torchtune clone here

📋 TTT training will save adapter checkpoints under the ttt_folder you specified above.
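
After a run, a quick sanity check is to count what landed directly under ttt_folder and compare it with num_tasks. This sketch assumes one top-level entry per task, which is an assumption about the output layout rather than something this README states, so verify it against your own run. `count_entries` is a hypothetical helper, not part of this repo.

```shell
# Hypothetical helper (not part of this repo): count the immediate
# children (files or directories) of a folder.
count_entries() {
    # -mindepth 1 excludes the folder itself; -maxdepth 1 stops at its children.
    # tr strips the padding that some wc implementations add.
    find "$1" -mindepth 1 -maxdepth 1 | wc -l | tr -d ' '
}
```

For example, `echo "entries under ttt_folder: $(count_entries "$ttt_folder")"` prints the count for the folder used in the training command.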

Inference

To run inference with TTT, run predict.py:

# Specify data path
data_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json
# Specify where your fine-tuned (referred to as base) and TTT checkpoints are
base_checkpoint_dir=/path/to/finetuned/model/folder/
ttt_folder=/path/to/ttt/folder

# If a solution file is given, predict.py will also evaluate the model
solution_file=/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json


temperature=0
n_sample=1

# This should be the same as the lora_rank used for TTT
max_lora_rank=128
# Specify where predictions and submissions should be saved
tti_folder=/path/to/tti/folder
mkdir -p $tti_folder


python predict.py \
--experiment_folder=$tti_folder \
--pretrained_checkpoint=$base_checkpoint_dir \
--lora_checkpoints_folder=$ttt_folder \
--temperature=$temperature \
--n_sample=$n_sample \
--data_file=$data_file \
--solution_file=$solution_file \
--max_lora_rank=$max_lora_rank \
--include_n=1 \
--new_format

📋 For Llama-3 and Llama-3.2 we used different versions of vLLM, and these are not compatible with the torchtune version that we use. So, for reproducibility, we give setup instructions for vllm for llama3 and vllm for llama3-2. We use separate conda environments for the inference pipeline.

# For Llama3 and 3.1 models
conda create -n vllm python=3.10
conda activate vllm
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/torchtunecompat
pip install -r requirements.txt
# For Llama3.2 models
conda create -n vllmnew python=3.10
conda activate vllmnew
pip install vllm@git+https://github.com/ekinakyurek/vllm.git@ekin/ekin/newvllm
pip install -r requirements.txt
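
Since the inference environment depends on the model family, a small lookup can keep scripts from activating the wrong one. This is a minimal sketch: `vllm_env_for` and the family labels `llama3`, `llama3.1`, and `llama3.2` are hypothetical names chosen for illustration, while the environment names `vllm` and `vllmnew` come from the setup commands above.

```shell
# Hypothetical helper (not part of this repo): map a model family label to
# the conda environment name created in the setup commands above.
vllm_env_for() {
    case "$1" in
        llama3|llama3.1) echo "vllm" ;;
        llama3.2)        echo "vllmnew" ;;
        *)
            # Unknown label: report on stderr and signal failure
            echo "unknown model family: $1" >&2
            return 1
            ;;
    esac
}
```

For example, `conda activate "$(vllm_env_for llama3.2)"` would activate the `vllmnew` environment before running predict.py.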

Predictions from models