This feature request is to create a drop-in replacement for https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html that's sharded.
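For reference, the unsharded op this would replace behaves as below; the shapes are just illustrative, not taken from this issue:

```python
import torch
import torch.nn.functional as F

# torch.nn.functional.linear computes out = x @ w.T + b
x = torch.randn(16, 1024)    # [batch, in_features]
w = torch.randn(4096, 1024)  # [out_features, in_features] -- 2D weight
b = torch.randn(4096)        # [out_features]
out = F.linear(x, w, b)      # [batch, out_features]
```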
A linear layer can be sharded in several ways. For example, with data parallelism, expect `b` to be DIDy-parallel in addition to the hidden dimension.

Due to #2563, we have to manually split the device dimension in the logical domain instead of having it as logical-to-loop transforms. This has prevented us from having a drop-in replacement: for example, the weight has to be 3D, whereas `torch.linear` takes a 2D weight.
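A minimal sketch of that mismatch, with a made-up device count and shapes: once the device dimension is split out in the logical domain, the weight being carried around is 3D, so a plain `torch.nn.functional.linear` call no longer accepts it and each device has to slice out its local 2D shard.

```python
import torch
import torch.nn.functional as F

d = 4                                    # hypothetical device count
x = torch.randn(16, 1024)                # [batch, in_features]
w_3d = torch.randn(d, 4096 // d, 1024)   # device dim split out explicitly -> 3D weight

# F.linear(x, w_3d)                      # not a drop-in call: linear expects a 2D weight

rank = 0                                 # stand-in for this device's index
w_local = w_3d[rank]                     # [out_features / d, in_features] -- 2D again
out_local = F.linear(x, w_local)         # [batch, out_features / d]
```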
#3073 is an attempt to support case 1 under the limitation of #2563. I did this first because having case 1 fused turns out to be important for performance.
The other cases are yet to be done, and should probably wait until after #2563 to avoid accumulating too much tech debt. Note that for case 2 in particular, we'll also need to decompose a sharded linear into matmul + collective + biasadd.
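A rough sketch of that decomposition for case 2, assuming the weight is sharded along the reduction (`in_features`) dimension so each rank produces a partial product and the collective is an all-reduce; the sharding choice, shapes, and process-group setup here are assumptions, not details from this issue:

```python
import torch
import torch.distributed as dist

def sharded_linear_case2(x_local: torch.Tensor,
                         w_local: torch.Tensor,
                         bias: torch.Tensor) -> torch.Tensor:
    """matmul + collective + biasadd (assumed row-parallel layout).

    x_local: [batch, in_features / world_size]        local input shard
    w_local: [out_features, in_features / world_size] local weight shard
    bias:    [out_features]                           replicated on every rank
    """
    # 1. matmul on the local shards -> partial sums over the full out_features
    partial = x_local @ w_local.t()                 # [batch, out_features]
    # 2. collective -- assumed here to be an all-reduce summing the partial products
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    # 3. biasadd after the reduction so the bias is added exactly once
    return partial + bias
```

Ordering the biasadd after the collective matters: adding the bias to each partial product before the all-reduce would accumulate it world-size times.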