
Train Stucked #44

Open

YinengXiong opened this issue Mar 29, 2021 · 4 comments
@YinengXiong

Hi ~
I use

```python
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
    model = bnconvert(model)
model.cuda()
```

to enable sync-BN during multi-GPU training, but the training procedure gets stuck at the final batch of an epoch.

[Screenshot: training log showing the run hanging at the last batch of an epoch]

@vacancy
Owner

vacancy commented Mar 29, 2021

The implementation requires that the module replicas on different devices invoke the batchnorm exactly the SAME number of times in each forward pass. For example, you cannot call the batchnorm on GPU0 but not on GPU1. The #i (i = 1, 2, 3, ...) call of the batchnorm on each device is treated as a group, and its statistics are reduced across devices. This is tricky, but it is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, it is usually not an issue for most models.

Can you check this?
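
For illustration, here is a minimal, hypothetical sketch (not taken from this issue) of the kind of forward pass that can violate this requirement under DataParallel; the model and the data-dependent condition are made up for the example:

```python
import torch
import torch.nn as nn

class ConditionalBNModel(nn.Module):
    """Hypothetical model whose forward() may skip a BatchNorm call
    depending on the data it receives."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)        # would be converted to sync-BN
        self.extra_bn = nn.BatchNorm2d(16)  # only used on some inputs

    def forward(self, x):
        x = self.bn(self.conv(x))
        # Problem: under DataParallel each GPU sees a different slice of the
        # batch, so this condition can be True on GPU0 and False on GPU1.
        # The replicas then call batchnorm a different number of times, the
        # synchronized reduction waits forever, and training appears stuck.
        if x.mean() > 0:
            x = self.extra_bn(x)
        return x
```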

@YinengXiong
Author

So if I want to use a model from e.g. torchvision, which is built with nn.BatchNorm, I should:

0. model = torchvision.models.resnet50()
1. model = bnconvert(model)
2. model = DataParallelWithCallback(model)
3. model.cuda()

Am I right?
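
As a minimal sketch of that ordering, assuming `bnconvert` corresponds to the `convert_model` helper and `DataParallelWithCallback` to the wrapper shipped in this repository's `sync_batchnorm` package:

```python
import torch
import torchvision
from sync_batchnorm import convert_model, DataParallelWithCallback

# 0. Build a model that uses plain nn.BatchNorm layers.
model = torchvision.models.resnet50()

# 1. Replace every nn.BatchNorm layer with its synchronized counterpart.
model = convert_model(model)

# 2. Wrap the converted model so the sync-BN callbacks run on replication.
if torch.cuda.device_count() > 1:
    model = DataParallelWithCallback(model)

# 3. Move the (wrapped) model to the GPUs.
model = model.cuda()
```

Converting before wrapping ensures that every replica created by the data-parallel wrapper already contains the synchronized batchnorm layers.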

@vacancy
Owner

vacancy commented Apr 7, 2021

Correct. I suspect the reason for the original hang is the requirement described in my earlier comment: each replica must invoke the batchnorm the same number of times in each forward pass.

@YinengXiong
Author

Thanks a lot!
