
Fix bug in scatter #2245

Merged: merged 2 commits into from May 15, 2024
Conversation

@samnordmann (Collaborator) commented May 14, 2024

Fixes a subtle bug, exposed by #2168

@samnordmann samnordmann requested review from wujingyue and cowanmeg May 14, 2024 18:47
@samnordmann (Collaborator, Author) commented May 14, 2024

I find this bug really counter-intuitive!

I thought it was already checked here that the buffers were contiguous...

@samnordmann (Collaborator, Author)

!build --dist

@cowanmeg (Collaborator)

> I find this bug really counter-intuitive!
>
> I thought it was already checked here that the buffers were contiguous...

Huh...I don't know how this didn't come up earlier...

@cowanmeg (Collaborator)

We have assumed that the input tensor is contiguous when lowering comms (we probably should have added this contiguous() call before), but our tests all have contiguous aten inputs, so I'm not certain where the non-contiguity got introduced...
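For illustration, here is a minimal PyTorch sketch of the guard such a fix adds; the helper name is hypothetical, not the PR's actual code:

```python
import torch

# Hypothetical helper illustrating the guard a comm lowering would add.
def ensure_contiguous_for_comm(t: torch.Tensor) -> torch.Tensor:
    # torch.Tensor.contiguous() returns the tensor itself when it is
    # already contiguous, and otherwise copies into a dense buffer,
    # which is what collectives such as scatter expect.
    return t.contiguous()
```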

@samnordmann samnordmann requested a review from wujingyue May 14, 2024 20:09
@wujingyue (Collaborator)

> We have assumed that the input tensor is contiguous when lowering comms (we probably should have added this contiguous() call before), but our tests all have contiguous aten inputs, so I'm not certain where the non-contiguity got introduced...

This is exposed by #2168, but the root cause, I believe, is that insertReshardings doesn't set the allocation domain properly. IIRC, you and @jjsjann123 noticed this potential problem in another PR, but we never got a chance to fix it properly.

Below is the fusion IR for the two stages. Although a Set.Permute is inserted, the output of the first stage doesn't have an allocation domain, so Set.Permute is allowed to produce a non-contiguous tensor (in this case, strides=[3, 12, 1]) for speed.

Inputs:
  T0_g[ iS0{3}, iS1{4}, iS2{3}, iS3{5} ] (DeviceMesh{1, }), float
Outputs:
  T4_g[ iS15{4}, iS14{3}, iS16{3} ] (DeviceMesh{1, }), float

%kernel_math {
T1_l[ iS4{3}, iS5{4}, iS6{3}, rS7{5} ] (DeviceMesh{1, })
   = reduction( T0_g[ iS0{3}, iS1{4}, iS2{3}, iS3{5} ] (DeviceMesh{1, }), op = add, initial value = float(0), allreduce = false )
T4_g[ iS15{4}, iS14{3}, iS16{3} ] (DeviceMesh{1, })
   = Set.Permute( T1_l[ iS4{3}, iS5{4}, iS6{3}, rS7{5} ] (DeviceMesh{1, }), cache_op=Streaming )
}

sizes:   [4, 3, 3]
strides: [3, 12, 1]
Inputs:
  T5_g[ ideviceIdx.x17{4}, iS18{3}, iS19{3} ] (DeviceMesh{0, 1, 2, 3, }), float
Outputs:
  T3_g[ iS11{3}, ideviceIdx.x12{4}, iS13{3} ] (DeviceMesh{0, 1, 2, 3, }), float

%kernel_math {
T6_l[ iS21{3}, ideviceIdx.x20{4}, iS22{3} ] (DeviceMesh{0, 1, 2, 3, })
   = Set.Permute( T5_g[ ideviceIdx.x17{4}, iS18{3}, iS19{3} ] (DeviceMesh{0, 1, 2, 3, }), cache_op=Streaming )
T3_g[ iS11{3}, ideviceIdx.x12{4}, iS13{3} ] (DeviceMesh{0, 1, 2, 3, })
   = T6_l[ iS21{3}, ideviceIdx.x20{4}, iS22{3} ] (DeviceMesh{0, 1, 2, 3, })
   + T6_l[ iS21{3}, ideviceIdx.x20{4}, iS22{3} ] (DeviceMesh{0, 1, 2, 3, });
}
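
The non-contiguous strides in the dump can be reproduced directly; here is a minimal sketch in plain PyTorch (assuming nothing beyond stock torch) showing how permuting the first stage's contiguous [3, 4, 3] output into [4, 3, 3] yields exactly the strides [3, 12, 1] printed above:

```python
import torch

# Shape of the first stage's output after the reduction (T1 above).
t1 = torch.randn(3, 4, 3)          # contiguous, strides (12, 3, 1)

# With no allocation domain pinned on the output, Set.Permute is free to
# return a strided view rather than a dense copy; permute() does the same.
t4 = t1.permute(1, 0, 2)           # shape [4, 3, 3]

assert t4.stride() == (3, 12, 1)   # the strides printed above
assert not t4.is_contiguous()      # what the scatter lowering tripped over
```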

I think our options are:

  1. Merge this PR, which now sounds like a band-aid because it leaves performance on the table. In the above example, doing reduction+permute in one kernel is likely faster than in two kernels.
  2. Fix insertResharding to set allocation domains properly. This sounds like the right fix.
  3. Revert Allocation order refactor #2168, which isn't the root cause but is unfortunately the trigger.

I suspect option 2 will take a while, and it's a bad idea to leave CI broken, so I'll cross that out. Option 3 is the safest because #2168 could have triggered other failure cases that we are not aware of just yet. Option 1 is suboptimal, but as long as we (👀 @cowanmeg) fix the root cause soon, we should be fine. Wdyt?

@jjsjann123 (Collaborator)

Oops. sorry about that. 😛

@samnordmann ignore the thunder tests; there's something going on with transformer_engine.

@wujingyue wujingyue merged commit 1a7c6f6 into NVIDIA:main May 15, 2024
34 of 37 checks passed
@cowanmeg (Collaborator)

Yes, I agree option 2 is the correct fix. Plus, we need to set the allocation domain for DID parallelism on the leaf domain to work. Hopefully, I can find some time soon to work on this!

cowanmeg added a commit that referenced this pull request May 30, 2024
Sets allocation domain of sharded tensors during the pass
`propagateShardingsAndSetAllocationDomain`.
The two passes are merged in an attempt to reduce the number of passes over
all expressions in the fusion.

Allocation domain is set to the tv's leaf domain. Since presegmentation
passes and scheduling occur after the sharding passes, the leaf domain
is identical to the rfactor domain. Once DID parallelization of the leaf
domain is allowed, the leaf and rfactor domains will no longer be the same.

This will avoid issues such as
#2245 (comment) and
allow turning on the `AllocationDomainPass` presegmentation pass for
distributed matmul tests.
protonu pushed a commit that referenced this pull request May 30, 2024 (same commit message as above).
4 participants