forked from ishwnews/MASS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nmt_unsupervised_with_bt_multigpu.sh
executable file
·29 lines (28 loc) · 1.42 KB
/
nmt_unsupervised_with_bt_multigpu.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Unsupervised en<->fr NMT training with back-translation, launched on
# multiple GPUs through torch.distributed.launch.
#
# Edit the two placeholders below before running:
#   MODEL      path to the checkpoint produced in the pre-training stage
#   DATA_PATH  directory handed to train.py via --data_path
MODEL=checkpoint_path_in_pre-training_stage
DATA_PATH=YOUR_DATA_PATH

# Number of worker processes started per node. Exported so that child
# processes inherit it. Keep it equal to the number of devices listed in
# CUDA_VISIBLE_DEVICES on the launch line below.
export NGPU=4

# NOTE(review): the same checkpoint is passed twice to --reload_model,
# presumably once for the encoder and once for the decoder -- confirm
# against train.py.
#
# Expansions are quoted so a path containing spaces or glob characters is
# passed as a single argument. The last argument deliberately carries no
# trailing backslash: a dangling continuation would swallow whatever line
# follows into this command.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
  --nproc_per_node="$NGPU" train.py \
  --exp_name unsupMT_enfr \
  --dump_path ./models/en-fr/ \
  --exp_id unsupMT_enfr_bt_multigpu \
  --reload_model "$MODEL,$MODEL" \
  --data_path "$DATA_PATH" \
  --lgs 'en-fr' \
  --ae_steps 'en,fr' \
  --bt_steps 'en-fr-en,fr-en-fr' \
  --lambda_bt '1.0' \
  --lambda_ae '0:0,100000:0' \
  --encoder_only false \
  --emb_dim 1024 \
  --n_layers 6 \
  --n_heads 8 \
  --dropout 0.1 \
  --attention_dropout 0.1 \
  --gelu_activation true \
  --tokens_per_batch 2000 \
  --batch_size 32 \
  --bptt 256 \
  --optimizer adam_inverse_sqrt,beta1=0.9,beta2=0.98,lr=0.0001 \
  --epoch_size 200000 \
  --eval_bleu true \
  --stopping_criterion 'valid_en-fr_mt_bleu,10' \
  --validation_metrics 'valid_en-fr_mt_bleu'