Progressively decreasing attention windows #80

Open
Vorlent opened this issue May 18, 2024 · 0 comments

Vorlent commented May 18, 2024

In the spirit of the paper "The Unreasonable Ineffectiveness of the Deeper Layers" (https://arxiv.org/abs/2403.17887v1), it should be possible to use progressively decreasing attention windows without any loss of performance.

Llama 3 70B has 80 layers in total and a context window of 8k tokens. The idea is that each layer has access to half the context of the previous layer, with the later layers clamped to some minimum window size.

| Layer | Context window |
| ----- | -------------- |
| 0     | 8k             |
| 1     | 4k             |
| 2     | 2k             |
| 3     | 1k             |
| 4     | 512            |
| 5     | 256            |
| 6     | 128            |
| 7–79  | 64             |

This adds up to 8k^2 + 4k^2 + 2k^2 + 1k^2 + 512^2 + 256^2 + 128^2 + 73*64^2 attention entries, versus 80 * 8k^2 for full attention in every layer. Under this estimate, the computational cost of prompt processing would drop by roughly 98%.
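For reference, here is a minimal Python sketch (not code from this project, just a back-of-the-envelope check) that reproduces these numbers, assuming the halving schedule above with a 64-token floor and counting window^2 attention-score entries per layer as the cost proxy:

```python
# Rough cost estimate for the schedule above: each layer's attention is
# counted as window^2 score entries, as in the estimate in this issue.
# Feed-forward blocks are ignored, since the window does not affect them.

CONTEXT = 8192      # Llama 3 context length
NUM_LAYERS = 80     # Llama 3 70B layer count
MIN_WINDOW = 64     # floor for the later layers

def window_schedule(context, num_layers, min_window):
    """Yield the attention window per layer: halved each layer, clamped to the floor."""
    window = context
    for _ in range(num_layers):
        yield window
        window = max(window // 2, min_window)

full_cost = NUM_LAYERS * CONTEXT ** 2
reduced_cost = sum(w ** 2 for w in window_schedule(CONTEXT, NUM_LAYERS, MIN_WINDOW))

print(f"full attention:     {full_cost:>13,}")
print(f"decreasing windows: {reduced_cost:>13,}")
print(f"cost reduction:     {1 - reduced_cost / full_cost:.1%}")  # ~98%
```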

I came up with this idea by looking at these diagrams in the paper:

(Screenshots of two figures from the paper.)

The initial layers have already arranged the information in such a way that it becomes accessible "locally" within the sliding window. In other words, the model already implements a form of hierarchical attention, with the initial layers doing the heavy lifting of global attention. If the optimization described above is feasible, the need for linear attention mechanisms disappears, because the quadratic attention in the initial layers is unavoidable for good answer quality anyway.

Choosing whether a token is worth looking at requires some initial sweeping pass. If you have a token and want to know which tokens in the preceding context "resonate" with it, you have to do a linear pass over that context. Repeating this linear pass for every generated token results in quadratic attention, and there is no obvious way to avoid it. One could plausibly group every k tokens of the context into a block, but block-wise attention is still quadratic: you merely go from O(n^2) to O((n/k)^2). That is not a big win compared to a decreasing window, where the effective reduction factor doubles with each layer, to the point where the first three layers account for well over 80% of the computation and the remaining layers contribute almost nothing.
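To make the per-layer windowing concrete, here is a small NumPy sketch of one possible realization: a causal sliding-window mask whose width halves with the layer index down to a 64-token floor. This is only an illustration of the scheme under the assumptions above, not an implementation from this project:

```python
import numpy as np

def layer_window_mask(seq_len, layer, context=8192, min_window=64):
    """Boolean attention mask (True = may attend) for one layer.

    The causal window halves with each layer index and is clamped to
    `min_window`, so layer 0 sees the full context, layer 1 half of it,
    and deep layers only the most recent 64 tokens.
    """
    window = max(context >> layer, min_window)
    q = np.arange(seq_len)[:, None]   # query positions (rows)
    k = np.arange(seq_len)[None, :]   # key positions (columns)
    causal = k <= q                   # never attend to future tokens
    in_window = (q - k) < window      # only the last `window` keys
    return causal & in_window

# Layer 3 has a 1k window: the last query position sees 1024 keys.
mask = layer_window_mask(seq_len=2048, layer=3)
print(mask.shape, int(mask[-1].sum()))   # (2048, 2048) 1024
```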
