Memory bandwidth efficient sparse tree attention
- (precompute) chunk the tree into query blocks
- (precompute) compute unique ancestors, attention mask, and leaves for each block (see the sketch after this list)
- (runtime) only load keys and values for the query block's unique ancestors and leaves
- (runtime) go fast
- A100 Colab Benchmark
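A minimal sketch of the precompute step, assuming the tree is given as a parent-pointer array (`parent[i]` is the parent of node `i`, `-1` at the root) and that each block's KV set is the union of its queries' lineages (ancestors plus the nodes themselves); `precompute_block_metadata` and `block_size` are illustrative names, not this repo's API:

```python
import numpy as np

def precompute_block_metadata(parent, block_size):
    """For each query block, collect the unique keys it attends to and the
    per-block attention mask (query q may attend to key k only if k is an
    ancestor of q or q itself)."""
    n = len(parent)
    blocks = []
    for start in range(0, n, block_size):
        rows = range(start, min(start + block_size, n))
        # Walk each query node's lineage up to the root (ancestors + self).
        lineages = []
        for i in rows:
            lineage, j = set(), i
            while j != -1:
                lineage.add(j)
                j = parent[j]
            lineages.append(lineage)
        # Unique keys this block needs: the union of its queries' lineages.
        kv_indices = sorted(set().union(*lineages))
        # Dense mask over the gathered keys for this block.
        mask = np.zeros((len(lineages), len(kv_indices)), dtype=bool)
        for q, lineage in enumerate(lineages):
            for k, key in enumerate(kv_indices):
                mask[q, k] = key in lineage
        blocks.append({"rows": list(rows), "kv_indices": kv_indices, "mask": mask})
    return blocks
```

At runtime the kernel then only gathers the K/V rows listed in each block's `kv_indices` and applies that block's mask inside the attention computation, so memory traffic scales with lineage size rather than with the full tree.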
go forth, search the tree of possible futures
Notes on precomputation:
- Can probably be made fast enough to run at inference time with a bit more work: for a dynamic tree structure (i.e. one that depends on the model's output), these kernel inputs only need to be computed once per tree, after which they are reused by every attention layer in the model
- Static tree structures are still useful: Medusa uses a static, left-weighted tree of 256 nodes, populated via Cartesian products of its multiple top-k output heads, to accelerate batch-size-1 inference by ~3x (sketched below)
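A hedged sketch of populating candidate paths from multiple top-k heads via a Cartesian product, in the spirit of the Medusa note above but not its actual code; `build_candidate_paths`, `head_logits`, and `topk_per_head` are hypothetical names:

```python
from itertools import product
import torch

def build_candidate_paths(head_logits, topk_per_head):
    """head_logits: one [vocab_size] tensor per speculative head.
    Returns every path formed by picking one top-k token from each head."""
    topk_ids = [logits.topk(k).indices.tolist()
                for logits, k in zip(head_logits, topk_per_head)]
    # Cartesian product over heads yields the candidate continuations;
    # paths that share a prefix share the corresponding tree nodes.
    return list(product(*topk_ids))

# e.g. three heads with top-k of (4, 3, 2) give 4 * 3 * 2 = 24 candidate paths;
# a real tree keeps only the most promising nodes (e.g. 256 in Medusa's case).
```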
Todo:
- Organize blocks based on DFS ordering to minimize the number of blocks that need to load the same ancestor KVs (i.e. maximize the shared lineage within each block); see the sketch below
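A minimal sketch of that ordering, assuming the tree is available as a child-list mapping; `dfs_order` and `children` are illustrative names:

```python
def dfs_order(children, root=0):
    """children: dict mapping node -> list of child nodes.
    Returns node indices in depth-first order; relabeling queries in this
    order keeps each block's rows on closely related paths, so they share
    most of their ancestor KVs."""
    order, stack = [], [root]
    while stack:
        node = stack.pop()
        order.append(node)
        # Reverse so the leftmost child is visited first.
        stack.extend(reversed(children.get(node, [])))
    return order
```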
Credits: