Sharded SDPAFwdOp #2565
Conversation
csrc/ir/nodes.cpp
Outdated
// Add back the device dim axis for outputs with a head dimension.
if (handle_device_dim) {
  output = output.unsqueeze(0);
  log_sumexp = log_sumexp.unsqueeze(0);
Why do we need the DID axis in `log_sumexp`?
Good question: `log_sumexp` is only needed for the backward pass, so we don't need to track the DID axis on it here. I'll remove it!
Nice! I'd wait for Priya's approval.
csrc/root_domain_map.cpp
Outdated
  }
}
// Map D from any input (query/key/value) to output, logsumexp only.
Still needed given https://github.com/NVIDIA/Fuser/pull/2565/files#r1672758030?
!build
!build
LGTM.
Note: the failing tests are unrelated.
Temporarily adds support for a sharded forward scaled dot product attention.

Currently, we only support DID parallelization on the logical domain, which requires us to split and parallelize an axis at the logical level (see #2563). This is a hack until we support DID parallelization on the loop domain, after which this PR can be reverted.

Restrictions:
1. The q, k, v inputs are manually sharded _before_ the SDPAFwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks run when the op is created.
2. Only the head dimension is sharded; all inputs and outputs have either a sharded head dimension or an unsharded one.
3. The DID axis is the outermost axis. During evaluation, if we see 5 dimensions, the first is assumed to be the DID axis and is squeezed from the inputs and unsqueezed onto the outputs (see the sketch below).
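For illustration only, here is a minimal sketch of the evaluation-time handling described in restriction 3. It assumes ATen tensors and uses the composite `at::scaled_dot_product_attention` as a stand-in for the attention kernel the real op dispatches to; the helper name `sdpa_fwd_sharded_sketch` and the placeholder `log_sumexp` handling are hypothetical and not taken from the nvFuser sources.

```cpp
#include <ATen/ATen.h>
#include <tuple>

// Hypothetical helper, not the nvFuser implementation: shows how a leading
// DID axis would be squeezed off 5-D inputs before the local attention call
// and unsqueezed back onto the output afterwards.
std::tuple<at::Tensor, at::Tensor> sdpa_fwd_sharded_sketch(
    at::Tensor query,
    at::Tensor key,
    at::Tensor value) {
  // A 5-D input is assumed to carry the DID axis as its outermost dimension
  // (size 1 on each rank), so strip it before running attention.
  const bool handle_device_dim = query.dim() == 5;
  if (handle_device_dim) {
    query = query.squeeze(0);
    key = key.squeeze(0);
    value = value.squeeze(0);
  }

  // Local (per-rank) attention over the sharded heads.
  at::Tensor output = at::scaled_dot_product_attention(query, key, value);
  // Placeholder only: the real op also produces log_sumexp from the kernel.
  at::Tensor log_sumexp = at::empty({0}, query.options());

  // Add back the device dim axis for outputs with a head dimension.
  if (handle_device_dim) {
    output = output.unsqueeze(0);
  }
  return std::make_tuple(output, log_sumexp);
}
```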
Adds temporary support for sharded backward scaled dot product attention until #2563 is completed. Similar to #2565, with similar restrictions:
1. All necessary sharded inputs are manually sharded before the SDPABwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks run when the op is created.
2. Only the head dimension is sharded; all inputs and outputs have either a sharded head dimension or an unsharded one.
3. The DID axis is the outermost axis. During evaluation, if we see 5 dimensions, the first is assumed to be the DID axis and is squeezed from the inputs and unsqueezed onto the outputs. A hedged sketch of the per-rank head-dimension sharding follows this list.
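As a hedged illustration of restrictions 2 and 3, the sketch below keeps only one rank's slice of the head dimension of a concrete tensor and places the DID axis outermost, using plain ATen calls. The helper name `shard_heads` and its parameters are made up for this example; in the actual fusion the sharding is presumably expressed on the TensorViews (DID parallelization) rather than on ATen tensors, per the description above.

```cpp
#include <ATen/ATen.h>

// Hypothetical helper, not part of nvFuser: keeps only this rank's slice of
// the head dimension and prepends the DID axis so the result matches the
// 5-D layout the evaluator expects (device axis outermost).
at::Tensor shard_heads(
    const at::Tensor& t,  // e.g. [N, H, L, E] for query/key/value
    int64_t num_devices,
    int64_t rank,
    int64_t head_dim = 1) {
  // H -> H / num_devices on each rank.
  at::Tensor local = t.chunk(num_devices, head_dim)[rank];
  // [1, N, H / num_devices, L, E]: the leading 1 is the DID axis.
  return local.unsqueeze(0);
}
```

For example, with 2 devices, rank 0 would hold the first half of the heads and rank 1 the second half, each as a 5-D tensor whose leading size-1 axis is the DID axis.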