# oxe-64-goal-cond.sh
# Pre-train the tokenizer on four A100-40GB GPUs.
# Arguments:
# oxe_data_mixes_type: 'select_sthsth' for OXE+SSv2, 'select' for OXE only
# dataset_path: path to preprocessed OXE dataset
# sthsth_root_path: path to preprocessed SSv2 dataset
accelerate launch train_tokenizer.py \
--exp_name oxe-64-goal-cond-tokenizer --output_dir log_vqgan --seed 0 --mixed_precision bf16 \
--model_type ctx_vqgan \
--learning_rate 1e-4 --discr_learning_rate 1e-4 \
--train_batch_size 16 --gradient_accumulation_steps 1 --disc_start 1000005 \
--oxe_data_mixes_type select_sthsth --resolution 64 --dataloader_num_workers 16 \
--rand_shuffle --video_stepsize 1 --segment_horizon 16 --segment_length 8 --context_length 2 \
--dataset_path {path to preprocessed_OXE} \
--sthsth_root_path {path to preprocessed_SSv2} \
--pretrained_model_name_or_path {pretrained model, e.g., the one from oxe-64-act-free.sh}
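
# A minimal sketch for running on the four A100s, assuming they are visible
# as devices 0-3 (adjust to your machine); the process count can also be
# passed inline instead of relying on a prior `accelerate config`:
#
#   export CUDA_VISIBLE_DEVICES=0,1,2,3
#   accelerate launch --multi_gpu --num_processes 4 train_tokenizer.py ...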

# Pre-train the transformer on four A100-40GB GPUs.
# Arguments:
# pretrained_model_name_or_path: path to the pre-trained tokenizer
accelerate launch train_gpt.py \
--exp_name oxe-64-goal-cond-transformer --output_dir log_trm --seed 0 --mixed_precision bf16 \
--vqgan_type ctx_vqgan \
--pretrained_model_name_or_path {log directory of the pre-trained tokenizer}/unwrapped_model \
--config_name configs/llama/config.json \
--per_device_train_batch_size 16 --gradient_accumulation_steps 1 \
--learning_rate 1e-4 --lr_scheduler_type cosine \
--oxe_data_mixes_type select --resolution 64 --dataloader_num_workers 16 \
--dataset_path {path to preprocessed_OXE} \
--goal_conditioned --video_stepsize 1 --segment_length 17 --context_length 2 \
--weight_decay 0.01 --llama_attn_drop 0.1 --embed_no_wd
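
# Note: with four GPUs, a per-device batch size of 16, and gradient
# accumulation of 1, the effective global batch size is 4 * 16 * 1 = 64
# for both stages.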