-
Notifications
You must be signed in to change notification settings - Fork 409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable autograd graph to propagate after multi-device syncing for loss functions in ddp
#2754
Conversation
That sounds good to me, but can we add a test for this enhancement? |
Thanks for the prompt response @Borda. I'm thinking that I can make an additional unittest in |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #2754 +/- ##
=======================================
- Coverage 69% 69% -0%
=======================================
Files 344 330 -14
Lines 18824 18653 -171
=======================================
- Hits 12971 12801 -170
+ Misses 5853 5852 -1 |
yeah, that sounds good to me :) |
6c926d7
to
1d0dabe
Compare
Update: to accommodate both cases where tensors from different ranks have the same/different shape, the line to put the original tensor (holding the AD graph) back into the gathered list was added in two places in the code. Because of the two cases, I wrote two unittests to account for each. Interestingly, both pass |
that is strange and worse some more investigation... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I looked briefly why the tests do not pass on older versions of Pytorch but could not find a reason.
I think we should just only support this for Pytorch > 2.0 and then add this to the documentation.
dc35370
to
e693ace
Compare
ce5dca1
to
ffc67f6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seeems the two test functions are now included twice in the test_ddp.py
file, please check
Alright, I finally sat down to understand what was going on here. The non-deterministic behavior was really strange to me, so I tried a lot of debugging and realized that order of output of the
I found out that the reason for this is that when the torchmetrics/tests/unittests/conftest.py Lines 66 to 68 in abdd2c4
and then pytest.pool.starmap is called when we want to run a test in ddp mode. The solution was to call the setup function during the test function and everything is in the expected order. See this commit for details: 48e699b.This has not been a problem before because we normally reduce all the states in some way e.g. sum them and then the order does not matter at all. Hopefully, this also means that it works regardless of Pytorch version. @cw-tan sorry for the headache this must have been to debug. I have significantly simplified the tests you had for me to understand what was going on. Hope this still is fine with you. |
@SkafteNicki Fantastic, thank you so much! I'm just excited to see this feature released so I can remove the monkeypatch in my own code to achieve the same effects. I think the docs are the remaining change -- still some details about only being PyTorch 2 compatible, but the tests have passed for other versions. |
ddp
)ddp
What does this PR do?
Fixes #2745
Single-line enhancement proposed in #2745, that is, to enable the propagation of the autograd graph after the
all_gather
operation. This is useful for syncing loss functions in addp
setting.Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
📚 Documentation preview 📚: https://torchmetrics--2754.org.readthedocs.build/en/2754/