Description

Using a start position index in the fused attention kernel does not work.

Steps to reproduce

import math
import torch

START_IDX = 128


def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Shifted causal mask: query i may attend to keys j <= i + START_IDX.
    mask_y = torch.full((1, 1, q.size(2), q.size(2)), float("-inf"))
    mask_y = torch.triu(mask_y, diagonal=START_IDX + 1).float()
    att_y = (q @ k.transpose(-2, -1)) * scale
    att_y = att_y + mask_y.to(att_y)
    att_y = torch.nn.functional.softmax(att_y, dim=-1)
    return att_y @ v


q = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
k = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
v = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
scale = 1 / math.sqrt(128)

x = triton_fa(q, k, v, scale, True, START_IDX)  # fused attention kernel under test (not shown)
y = attention_reference(q, k, v)
print(torch.max(torch.abs(x - y)))
print(torch.sum(x - y))
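For reference, the shifted causal mask above can also be cross-checked against PyTorch's built-in scaled_dot_product_attention with an equivalent boolean mask. This is only a sketch, not part of the original repro; it relies on SDPA's default softmax scale being 1/sqrt(head_dim), which equals the scale value used here.

# Optional cross-check of the mask semantics (sketch, not part of the original repro).
# True entries are attended to, so tril(..., diagonal=START_IDX) is the complement
# of the -inf triu mask built in attention_reference above.
seq_len = q.size(2)
allowed = torch.tril(torch.ones(seq_len, seq_len, device=q.device), diagonal=START_IDX).bool()
y_sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
print(torch.max(torch.abs(y_sdpa - attention_reference(q, k, v))))  # should be small (fp16 rounding)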
Expected Behavior

Output nearly identical to the vanilla (reference) implementation for any start position index.

Actual Behavior

The fused kernel returns nan for any START_IDX != 0.
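To make explicit where the nan values show up, a quick finiteness check can be appended to the repro. This is a sketch reusing the tensors and functions defined above; the comments reflect the behaviour reported here, not independently verified output.

# The reference path never calls the Triton kernel, so comparing finiteness of the
# two outputs isolates the problem to the start-index path of triton_fa.
y = attention_reference(q, k, v)
x = triton_fa(q, k, v, scale, True, START_IDX)

print(torch.isfinite(y).all())  # expected True: every query row has at least one unmasked key
print(torch.isnan(x).any())     # True for any START_IDX != 0, per the behaviour above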
Your environment

torch==2.0.0
triton==2.0.0