
question about masking #92

Open
esason opened this issue Dec 18, 2024 · 2 comments

Comments

@esason

esason commented Dec 18, 2024

Hi,

I would like to use a mask that keeps the top-k entries of each attention row (A = Q @ K^T)
and compute the full attention based on it.

  • Is it possible to compute it using "score" within the score_mod function? (It seems that this is a scalar... :-( )
  • If not, is it possible to pass in my masked_score and use flex attention only for the remainder of the calculation, namely Softmax(masked_score) @ V?
  • Lastly, I wonder if I can use a binary mask tensor with dims [B, H, T, T] as the mask.

Thanks!
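For reference, the computation being asked about can be written densely in plain PyTorch. This is a minimal, unfused sketch assuming q, k, v of shape [B, H, T, D] and the usual 1/sqrt(D) scaling (which flex_attention also applies by default); `topk_attention_reference` is a hypothetical name, not part of any library. It makes the obstacle visible: deciding whether a single score survives requires the k-th largest value of its whole row, i.e. a reduction over all keys.

```python
import torch
import torch.nn.functional as F


def topk_attention_reference(q, k, v, topk):
    """Dense top-k attention: keep the `topk` largest scores in each row of
    Q @ K^T (scaled by 1/sqrt(D)), then compute Softmax(masked_score) @ V.

    q, k, v: [B, H, T, D]. Materializes the full [B, H, T, T] score matrix.
    Returns the output [B, H, T, D] and the boolean keep-mask [B, H, T, T].
    """
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    # Row-wise reduction over *all* keys: each score's fate depends on its
    # whole row, which a per-element score_mod never sees.
    top_idx = scores.topk(topk, dim=-1).indices
    keep = torch.zeros_like(scores, dtype=torch.bool).scatter(
        -1, top_idx, torch.ones_like(top_idx, dtype=torch.bool)
    )
    masked_score = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(masked_score, dim=-1) @ v, keep


torch.manual_seed(0)
B, H, T, D = 1, 2, 16, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
out, keep = topk_attention_reference(q, k, v, topk=4)
print(out.shape)              # torch.Size([1, 2, 16, 8])
print(keep.sum(-1).unique())  # tensor([4])

# Sanity check: with topk == T nothing is pruned, so this is plain attention.
full, _ = topk_attention_reference(q, k, v, topk=T)
print(torch.allclose(full, F.scaled_dot_product_attention(q, k, v), atol=1e-5))
```

This version materializes the full T x T matrix, which is exactly what a fused attention kernel is designed to avoid.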

@drisspg
Contributor

drisspg commented Dec 18, 2024

This is not possible today. You cannot have a score or a mask modification that performs a reduction within the kernel.

How would you put it into the mask_mod without knowing the full scores a priori?

@esason
Author

esason commented Dec 19, 2024

@drisspg

Hi,
In practice, there are some approximators for the mask (I research dynamic rather than fixed masks), so I wanted to combine these methods efficiently with the flex pipeline.
So, given inputs q, k, v, suppose I have mask = func_approx(q, k); I thought of using it with flex.

There are more useful scenarios to research...

  • If I also have the masked_score, can I still benefit from flex for the remainder of the calculation?
    According to the FlashAttention work, the score calculation itself (even the full one) is quite cheap relative to the whole attention calculation.
    So if I compute even the full score, just as a start, and heavily prune it, what benefit can I get for the remainder of the calculation?
    Thanks!
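On the last question: FlexAttention's speedup from a mask comes from skipping tiles that its BlockMask marks as empty, while partially masked tiles are still computed in full and masked elementwise. A rough estimate of the benefit of a pruned mask is therefore the fraction of empty tiles. The sketch below counts them in plain PyTorch, assuming the default 128x128 block size and sequence lengths divisible by it; `block_stats` is a hypothetical helper, not part of the library, and the causal mask is included only as a structured point of comparison.

```python
import torch


def block_stats(keep, block=128):
    """Fraction of (block x block) tiles of a bool mask [B, H, Q, KV] that are
    empty (all False, skippable), full (all True) or partial.

    Assumes Q and KV are multiples of `block`.
    """
    B, H, Q, KV = keep.shape
    tiles = keep.reshape(B, H, Q // block, block, KV // block, block)
    any_kept = tiles.any(dim=5).any(dim=3)  # [B, H, Q // block, KV // block]
    all_kept = tiles.all(dim=5).all(dim=3)
    n = any_kept.numel()
    empty = (~any_kept).sum().item() / n
    full = all_kept.sum().item() / n
    return {"empty": empty, "full": full, "partial": 1.0 - empty - full}


torch.manual_seed(0)
T, D, TOPK = 1024, 16, 32
q, k = torch.randn(1, 1, T, D), torch.randn(1, 1, T, D)
scores = q @ k.transpose(-2, -1)

# Per-row top-k: keeps 32 of 1024 keys (about 3%) in every query row.
kth = scores.topk(TOPK, dim=-1).values[..., -1:]
topk_mask = scores >= kth

# Causal mask for comparison: structured, so whole tiles drop out.
causal_mask = torch.ones(T, T, dtype=torch.bool).tril()[None, None]

print(block_stats(topk_mask))
print(block_stats(causal_mask))  # {'empty': 0.4375, 'full': 0.4375, 'partial': 0.125}
```

With random inputs, a per-row top-k scatters the kept keys across nearly every tile, so almost no tiles are skippable even though only about 3% of the entries survive; the causal mask, by contrast, has 28 of its 64 tiles empty.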
