forked from AI-Hypercomputer/maxtext
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_checkpointing.sh
67 lines (57 loc) · 2.11 KB
/
test_checkpointing.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/bin/bash
# Checkpoint save/restore smoke test: runs MaxText training twice with the
# same run_name (the first run writes checkpoints every 3 steps), then calls
# eval_assert.py on the recorded learning/loss metrics.
set -ex

# Drop metrics files left over from a previous invocation; both are written
# again by the training runs further down (metrics_file=...).
for stale_metrics in saved_metrics.txt restored_metrics.txt; do
  [[ -f "$stale_metrics" ]] || continue
  rm "$stale_metrics"
  echo "removed existing $stale_metrics"
done
# Positional arguments (none are validated here):
#   $1  run-name prefix
#   $2  base output directory (passed as base_output_directory=)
#   $3  dataset path (passed as dataset_path=); for c4-array_record this is
#       the GCS bucket that gets mounted locally
#   $4  collect_stack_trace= value; also embedded in the run name
#   $5  dataset type, e.g. c4-array_record
#   $6  attention= value; 'autoselected' when unset or empty
OUTPUT_PATH=${2}
DATASET_PATH=${3}
COLLECT_STACK_TRACE=${4}
DATASET_TYPE=${5}
ATTENTION=${6:-autoselected}
# A minute-resolution timestamp distinguishes separate invocations; both
# training runs reuse this same run name.
RUN_NAME="${1}-${COLLECT_STACK_TRACE}-$(date +%Y-%m-%d-%H-%M)"
# Passed as the first argument to eval_assert.py at the end of the script;
# overridden below for the c4-array_record dataset.
eval_metrics=checkpoint_save_restore
# Small-model overrides appended to both training commands. The leading space
# is intentional: this string is appended directly onto a command line.
model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128"
# Extra dataset-specific arguments (same leading-space convention); empty by default.
CMD_DATA=""

if [[ "$DATASET_TYPE" == "c4-array_record" ]]; then
  gcsfuse_mount_path=/tmp/gcsfuse/
  eval_metrics=grain_checkpoint_save_restore
  echo "Using c4-array_record dataset type"
  echo "Mounting $DATASET_PATH to ${gcsfuse_mount_path}"
  # Quoted so the bucket value always reaches setup_gcsfuse.sh as one argument.
  bash setup_gcsfuse.sh "DATASET_GCS_BUCKET=${DATASET_PATH}" "MOUNT_PATH=${gcsfuse_mount_path}"
  # Point dataset_path at the local mount instead of the bucket.
  DATASET_PATH=${gcsfuse_mount_path}
  CMD_DATA=" grain_worker_count=0 dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1"
fi
# Train, then train again under the same run name, then assert on the metrics.
#
# Each command is an array with one element per argument. Substituted values
# are therefore never re-split or glob-expanded. There is also no quoted
# backslash-newline continuation: bash deletes that sequence inside double
# quotes, which would glue neighbouring arguments together.

# model_params / CMD_DATA are space-separated "key=value" lists (possibly empty
# or with a leading space). Split them into words on purpose; `read -a` splits
# on IFS without pathname expansion.
read -r -a model_args <<< "$model_params"
read -r -a data_args <<< "$CMD_DATA"

# First run: 5 steps, checkpoint every 3 steps, metrics to saved_metrics.txt.
CMD1=(
  python3 MaxText/train.py MaxText/configs/base.yml
  "run_name=${RUN_NAME}" steps=5 max_target_length=128 per_device_batch_size=1
  metrics_file=saved_metrics.txt checkpoint_period=3
  "base_output_directory=${OUTPUT_PATH}" "dataset_path=${DATASET_PATH}"
  async_checkpointing=false "collect_stack_trace=${COLLECT_STACK_TRACE}"
  "attention=${ATTENTION}"
  "${model_args[@]}" "${data_args[@]}"
)

# Second run: same run_name and output directory, metrics to
# restored_metrics.txt. NOTE(review): presumably train.py resumes from the
# first run's checkpoint because run_name matches; confirm.
CMD2=(
  python3 MaxText/train.py MaxText/configs/base.yml
  "run_name=${RUN_NAME}" steps=5 max_target_length=128 per_device_batch_size=1
  metrics_file=restored_metrics.txt
  "base_output_directory=${OUTPUT_PATH}" "dataset_path=${DATASET_PATH}"
  async_checkpointing=false "collect_stack_trace=${COLLECT_STACK_TRACE}"
  "attention=${ATTENTION}"
  "${model_args[@]}" "${data_args[@]}"
)

echo
echo "Start the first training run"
echo "Command is:"
printf '%q ' "${CMD1[@]}"
echo
"${CMD1[@]}"
# The run above executes in the foreground, so it has finished at this point.
echo
echo "First training run done"
echo "Start the second training run"
echo "Command is:"
printf '%q ' "${CMD2[@]}"
echo
"${CMD2[@]}"

# Final check over learning/loss, using the check name chosen in eval_metrics.
python3 end_to_end/tpu/eval_assert.py "$eval_metrics" metrics.txt learning/loss