Hi, I'm experiencing performance degradation when using multi-node training with `pretrain.py`.
I followed the continual pretraining tutorial using TinyLlama on the OpenWebMath 14B dataset.
I'm working with a bare-bones multi-node setup, where each node has 8 GPUs. For each node, I used the following commands:
According to the wandb logs, the total number of training tokens is the same across runs, but the number of iterations decreases in proportion to the number of nodes.
However, the final results show lower performance as the number of nodes increases.
|        | gsm8k | math | svamp | asdiv | mawps | tabmwp | mathqa | mmlu_stemm | sat_math | avg  |
|--------|-------|------|-------|-------|-------|--------|--------|------------|----------|------|
| 32node | 2.9   | 3.2  | 15.1  | 22.1  | 27.9  | 15.3   | 12.1   | 14.2       | 18.8     | 14.6 |
| 2node  | 4.1   | 3.6  | 17.9  | 29.7  | 38.7  | 15.9   | 12.3   | 15.8       | 18.8     | 17.4 |
| 1node  | 4.1   | 3.0  | 19.6  | 29.9  | 39.4  | 15.7   | 9.8    | 16.5       | 31.2     | 18.8 |
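For reference, here is the back-of-the-envelope arithmetic behind the iteration counts I mentioned above. It's a rough sketch only; the variable names and values are illustrative and not taken from `pretrain.py` or my actual config.

```python
# Rough sanity-check arithmetic (illustrative values; not the actual
# variable names or hyperparameters from pretrain.py).
seq_len = 2048                  # assumed block size
global_batch_size = 512         # assumed samples per optimizer step
total_tokens = 14_000_000_000   # ~14B tokens (OpenWebMath)

tokens_per_iter = global_batch_size * seq_len
expected_iters = total_tokens // tokens_per_iter
print(f"expected iterations: {expected_iters:,}")

# With total_tokens and global_batch_size both fixed, expected_iters does
# not depend on the node count. The fact that the logged iteration count
# drops roughly in proportion to the number of nodes is exactly what I
# can't reconcile with a fixed global_batch_size.
```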
In wandb, the loss seems to be logged only from rank 0, so I understand why the loss curves might look different across runs.
However, I can't figure out why the final evaluation results get worse.
To clarify, all runs use the same learning rate and the same `global_batch_size`.
I'd appreciate any advice on what might be causing this issue and what adjustments I should consider.
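Regarding the rank-0 loss logging above, this is roughly what I mean by averaging the loss across ranks before logging. It's a minimal sketch assuming a standard `torch.distributed` process group, not the existing `pretrain.py` logging code:

```python
import torch
import torch.distributed as dist

def global_mean_loss(local_loss: torch.Tensor) -> torch.Tensor:
    """Average a scalar loss across all ranks so the logged value reflects
    the whole global batch rather than only rank 0's shard."""
    if dist.is_available() and dist.is_initialized():
        reduced = local_loss.detach().clone()
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
        return reduced / dist.get_world_size()
    return local_loss.detach()

# Usage in the training loop (only rank 0 actually calls wandb.log,
# but every rank must enter the all_reduce):
# logged = global_mean_loss(loss)
# if dist.get_rank() == 0:
#     wandb.log({"train/loss": logged.item()})
```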