-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix sdpfa_fwd to not assume the presence of DIDs. #3116
Conversation
!build |
!build |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LTGM
self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x) | ||
|
||
torch.cuda.set_device(mpi_test.local_rank) | ||
torch.manual_seed(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't we make the seed manual for all tests, and put this line in the test's constructor ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes and cc @Priya2698 who briefly mentioned how this should be done which I forgot :) We do have
Line 387 in 61a77e0
class NVFuserTest(TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are using this with pytest
:
Fuser/tests/python/test_python_frontend.py
Lines 41 to 42 in e01e23e
@pytest.mark.skipif(is_pre_volta(), reason="Only supported on Volta and newer devices.") | |
class TestNvFuserFrontend(NVFuserTest): |
What do you mean by doing it differently for pytest?
!build |
Similar to #3073,
sdpfa_fwd
shouldn't assume DIDs are available at definition time. Instead, treat extra preceding dimensions as batch at definition time and check they are device parallel at evaluation time.This is required to land #3115.