Sharded SDPAFwdOp #2565

Merged: 5 commits merged into NVIDIA:main on Jul 12, 2024

Conversation

cowanmeg (Collaborator):

Temporarily add support for a sharded forward scaled dot product attention.
Currently, we only support DID parallelization on the logical domain, which requires us to split and parallelize an axis at the logical level (see #2563). This is a hack until we support DID parallelization on the loop domain, after which this PR can be reverted.

Restrictions:

  1. The q, k, v inputs are manually sharded before the SDPAFwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks run when the op is created.
  2. Only the head dimension is sharded; all inputs and outputs either have a sharded head dimension or are unsharded.
  3. The DID axis is the outermost axis. During evaluation, if an input has 5 dimensions, the first is assumed to be the DID axis; it is squeezed from the inputs and unsqueezed back onto the outputs (see the sketch after this list).
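
For illustration, here is a minimal ATen sketch of the evaluation-time convention in restriction 3: a rank-5 input is treated as carrying the DID axis outermost, the axis is squeezed off before calling SDPA, and it is added back onto the output. This is not the actual nvFuser implementation; the helper name and the use of at::scaled_dot_product_attention (which returns only the attention output) are assumptions for this example.

    #include <ATen/ATen.h>

    // Hypothetical helper, not the nvFuser code path: q, k, v are
    // [DID, N, H/d, L, E] when the head dimension is sharded across d devices
    // (the local DID extent is 1).
    at::Tensor sharded_sdpa_fwd(at::Tensor q, at::Tensor k, at::Tensor v) {
      const bool handle_device_dim = q.dim() == 5;
      if (handle_device_dim) {
        q = q.squeeze(0);
        k = k.squeeze(0);
        v = v.squeeze(0);
      }
      // Run the regular per-device SDPA on [N, H/d, L, E] tensors.
      at::Tensor output = at::scaled_dot_product_attention(q, k, v);
      if (handle_device_dim) {
        // Add the DID axis back so sharded outputs stay rank-5.
        output = output.unsqueeze(0);
      }
      return output;
    }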

cowanmeg requested review from Priya2698 and wujingyue on July 10, 2024, 16:20
csrc/ir/nodes.cpp (outdated, resolved)

    // Add back the device dim axis for outputs with a head dimension.
    if (handle_device_dim) {
      output = output.unsqueeze(0);
      log_sumexp = log_sumexp.unsqueeze(0);
Collaborator:

Why do we need the DID axis in log_sumexp?

Collaborator Author (cowanmeg):

Good question: we only need it for the backward pass, so we don't need to track the DID axis. I'll remove it!

wujingyue (Collaborator) left a comment:

Nice! I'd wait for Priya's approval.

csrc/root_domain_map.cpp (outdated, resolved)

      }
    }
    // Map D from any input (query/key/value) to output, logsumexp only.

cowanmeg requested a review from Priya2698 on July 11, 2024, 16:19
cowanmeg (Collaborator Author):

!build

cowanmeg (Collaborator Author):

!build

Priya2698 (Collaborator) left a comment:

LGTM.

cowanmeg (Collaborator Author):

Note: the failing tests are unrelated; they are caused by a "'nlohmann/json.hpp' file not found" error.

cowanmeg merged commit 114e21d into NVIDIA:main on Jul 12, 2024
17 of 20 checks passed
jacobhinkle pushed a commit that referenced this pull request Jul 19, 2024
Temporarily add support for a sharded forward scaled dot product
attention.
Currently, we only support DID parallelization on the logical domain, which requires us to split and parallelize an axis at the logical level (see #2563). This is a hack until we support DID parallelization on the loop domain, after which this PR can be reverted.

Restrictions:
1. The q, k, v inputs are manually sharded _before_ the SDPAFwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks run when the op is created.
2. Only the head dimension is sharded; all inputs and outputs either have a sharded head dimension or are unsharded.
3. The DID axis is the outermost axis. During evaluation, if an input has 5 dimensions, the first is assumed to be the DID axis; it is squeezed from the inputs and unsqueezed back onto the outputs.
cowanmeg added a commit that referenced this pull request Aug 23, 2024
Adds temporary support for sharded backward scaled dot product attention, until #2563 is completed.

Similar to #2565, with similar restrictions:
1. All necessary sharded inputs are manually sharded before the SDPABwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks run when the op is created.
2. Only the head dimension is sharded; all inputs and outputs either have a sharded head dimension or are unsharded.
3. The DID axis is the outermost axis. During evaluation, if an input has 5 dimensions, the first is assumed to be the DID axis; it is squeezed from the inputs and unsqueezed back onto the outputs.