Performance degradation on multi-node pretrain #1836

Open
HaebinShin opened this issue Nov 25, 2024 · 0 comments
Labels
question Further information is requested

Comments

@HaebinShin

Hi, I'm experiencing performance degradation when using multi-node training with pretrain.py.
I followed the continual pretraining tutorial, training TinyLlama on the OpenWebMath 14B-token dataset.

I'm working with a bare-bones multi-node setup, where each node has 8 GPUs. On each node, I ran the following command:

fabric run --node-rank=$RANK \
    --main-address=$IP \
    --main-port=$PORT \
    --num-nodes=$NODE_COUNT \
    --devices=8 --accelerator=cuda \
    /codes/litgpt/litgpt/__main__.py pretrain \
    --config=config/tinyllama-openwebmath.yaml --train.micro_batch_size=4 \
    --out_dir=/checkpoints/$CKPT_DIR \
    --logger_name=wandb \
    --train.log_interval=1 \
    --data.init_args.data_path=/dataset/processed/open-web-math \
    --train.save_interval=2000

According to the wandb logs, the total number of tokens trained is the same across runs, but the number of logged iterations decreases proportionally with the number of nodes (my reading of the batch math is sketched below).
[wandb screenshot]
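For reference, this is how I'm reasoning about the iteration counts. It's only a rough sketch: the batch and sequence-length values below are placeholders (not my actual config), and the gradient-accumulation formula is my own assumption about how the per-rank batch is derived from global_batch_size, not something taken from the litgpt source.

```python
# Sketch of the batch arithmetic as I understand it.
# All numeric values are hypothetical placeholders.
GLOBAL_BATCH_SIZE = 1024   # hypothetical; identical for every run
MICRO_BATCH_SIZE = 4       # --train.micro_batch_size=4
SEQ_LEN = 2048             # hypothetical block size
TOTAL_TOKENS = 14e9        # roughly the OpenWebMath token count

# Optimizer steps depend only on the global batch, so they match across runs.
optimizer_steps = int(TOTAL_TOKENS // (GLOBAL_BATCH_SIZE * SEQ_LEN))

for nodes in (1, 2, 32):
    world_size = nodes * 8
    # Assumed: per-rank gradient accumulation shrinks as the world size grows.
    grad_accum = GLOBAL_BATCH_SIZE // (world_size * MICRO_BATCH_SIZE)
    per_rank_iters = optimizer_steps * grad_accum  # what I think wandb shows per rank
    print(f"{nodes:>2} nodes: world_size={world_size:>3}  "
          f"grad_accum={grad_accum:>2}  per-rank iterations={per_rank_iters}")
```

If that assumption is right, it would explain why the logged iterations shrink with node count while the total tokens and optimizer steps stay the same.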

However, the final results show lower performance as the number of nodes increases.

| nodes | gsm8k | math | svamp | asdiv | mawps | tabmwp | mathqa | mmlu_stem | sat_math | avg |
|---|---|---|---|---|---|---|---|---|---|---|
| 32 nodes | 2.9 | 3.2 | 15.1 | 22.1 | 27.9 | 15.3 | 12.1 | 14.2 | 18.8 | 14.6 |
| 2 nodes | 4.1 | 3.6 | 17.9 | 29.7 | 38.7 | 15.9 | 12.3 | 15.8 | 18.8 | 17.4 |
| 1 node | 4.1 | 3.0 | 19.6 | 29.9 | 39.4 | 15.7 | 9.8 | 16.5 | 31.2 | 18.8 |

In wandb, it seems that the loss is recorded only for rank 0, so I understand why the loss curves might look different.
However, I can't figure out why the final downstream performance decreases.
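As a sanity check, I've been considering logging the rank-averaged loss instead of only rank 0's, so the curves are comparable across node counts. A minimal sketch of what I mean, assuming a plain torch.distributed setup (this is not litgpt's actual logging code):

```python
import torch
import torch.distributed as dist

def global_mean_loss(local_loss: torch.Tensor) -> torch.Tensor:
    """Average a scalar loss across all ranks so the logged curve
    reflects the whole global batch, not just rank 0's shard."""
    loss = local_loss.detach().clone()
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    return loss / dist.get_world_size()

# Inside the training loop (only rank 0 sends the value to wandb):
#     avg = global_mean_loss(loss)
#     if dist.get_rank() == 0:
#         wandb.log({"train/loss_global_mean": avg.item()})
```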

To clarify, all runs use the same learning rate and the same global_batch_size.

I'd appreciate any advice on what might be causing this issue and what adjustments I should consider.
