flash-attention-like gpu kernel #23
Draft
chengchingwen wants to merge 2 commits into master from flash
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master      #23       +/-   ##
===========================================
- Coverage   74.51%   48.07%   -26.45%
===========================================
  Files          30       38        +8
  Lines        2052     3162     +1110
===========================================
- Hits         1529     1520        -9
- Misses        523     1642     +1119
☔ View full report in Codecov by Sentry.
(with 4090)
julia> using NeuralAttentionlib, CUDA, BenchmarkTools; using NeuralAttentionlib: Flash
julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32);
julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
604.971 μs (211 allocations: 12.69 KiB)
julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
578.523 μs (59 allocations: 3.03 KiB)
julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128);
julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
8.371 ms (206 allocations: 12.30 KiB)
julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
6.397 ms (59 allocations: 3.03 KiB)
julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128);
julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
131.801 ms (208 allocations: 12.33 KiB)
julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z); <=
52.394 ms (59 allocations: 3.03 KiB)
(Float32[0.012302159 0.017965356 … -0.00015686426 0.014433293; -0.018675877 0.006785178 … 0.0022356245 0.0063521285; … ; -0.032304276 -0.042408943 … -0.039765503 -0.011026613; -0.008006868 0.03263003 … -0.012576684 0.007417702;;; -0.018318532 -0.0031435888 … 0.007511898 -0.034395427; -0.039739955 0.0067583104 … 0.001627093 0.004926813; … ; -0.014008925 -0.027526634 … -0.022851078 -0.03130795; 0.031661086 0.016136607 … 0.028245382 0.010535367;;; 0.008197796 0.026282633 … -0.0089175515 -0.03374061; -0.049456023 -0.033448916 … -0.03887646 -0.0071907463; … ; 0.025568059 -0.009142641 … 0.024252633 -0.0011169636; -0.00068929675 0.031708676 … 0.015330821 0.040038884;;; … ;;; 0.013630414 0.04991551 … -0.0006895974 -0.0067350264; -0.0010544751 -0.018107645 … 0.0016674573 0.026153116; … ; -0.021417249 0.0038626934 … 0.0025041953 0.012506626; -0.020112598 -0.014542421 … 0.012148654 -0.012378455;;; 0.0074338443 0.04256689 … -0.024805214 -0.03777187; -0.004147524 -0.0038875814 … -0.046487812 -0.010641718; … ; 0.0037521122 0.007813704 … 0.008547867 -0.0024645203; 0.021574477 0.0029059318 … -0.02713483 -0.019076949;;; -0.05634497 0.041607294 … 0.012316624 -0.0070284223; 0.015494817 -0.018000303 … 0.03147435 -0.050554577; … ; -0.04466497 -0.0039102207 … -0.036508568 -0.04374122; 0.048409604 -0.0018546274 … 0.020939333 0.010200917], Float32[8.759159 8.817464 … 8.727321 8.78187;;; 8.877047 8.600131 … 8.591627 8.652889;;; 9.015136 8.857832 … 8.761169 8.694989;;; … ;;; 8.665132 8.844432 … 8.560938 8.978898;;; 8.554172 8.873826 … 8.87549 9.010774;;; 8.905191 8.768375 … 8.803723 8.813686])
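The printed tuple above suggests that Flash.flash_attention_forward returns the attention output together with a per-query logsumexp. A quick sanity check against the existing implementation could look like the sketch below; this is illustrative and not part of the PR, it assumes the first tuple element is the attention output, and the tolerance is just a loose guess for Float32 accumulation differences.

O_flash, lse = Flash.flash_attention_forward(x, y, z)   # (output, logsumexp), per the printed tuple above
O_naive = NeuralAttentionlib.naive_qkv_attention(x, y, z)
isapprox(O_flash, O_naive; rtol = 1f-3)                  # should hold if the kernel matches the naive path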
julia> x = CUDA.randn(Float32, 128, 4096, 128); y = CUDA.randn(Float32, 128, 4096, 128); z = CUDA.randn(Float32, 128, 4096, 128); <=
julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
138.517 ms (208 allocations: 12.33 KiB)
julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
192.907 ms (59 allocations: 3.03 KiB)
julia> CUDA.@time NeuralAttentionlib.naive_qkv_attention(x, y, z);
0.149221 seconds (6.77 k CPU allocations: 455.070 KiB) (5 GPU allocations: 8.252 GiB, 0.02% memmgmt time)
julia> CUDA.@time Flash.flash_attention_forward(x, y, z);
0.208802 seconds (60 CPU allocations: 3.062 KiB) (2 GPU allocations: 258.000 MiB, 0.01% memmgmt time)
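The allocation gap above is consistent with the naive path materializing the full attention-score matrix while the flash kernel only allocates the output and the per-query statistics. A back-of-the-envelope check for the 128×4096×128 inputs used here (assuming one Float32 score matrix of size 4096×4096 per batch element):

julia> 4096 * 4096 * 128 * sizeof(Float32) / 2^30   # score matrix: 8 GiB, the bulk of the 8.252 GiB above
8.0

julia> 128 * 4096 * 128 * sizeof(Float32) / 2^20    # attention output: 256 MiB
256.0

julia> 4096 * 128 * sizeof(Float32) / 2^20          # per-query logsumexp: 2 MiB; 256 + 2 = 258 MiB matches the flash numbers
2.0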
julia> using Zygote
julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32); dO = CUDA.randn(Float32, 64, 512, 32);
julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
602.745 μs (209 allocations: 12.53 KiB)
#75 (generic function with 1 method)
julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
1.724 ms (132 allocations: 7.28 KiB)
julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128); dO = CUDA.randn(Float32, 64, 1024, 128);
julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
16.974 ms (324 allocations: 19.52 KiB)
julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
24.202 ms (132 allocations: 7.28 KiB)
julia> CUDA.@time Zygote.pullback(NeuralAttentionlib.naive_qkv_attention, x, y, z)[2](dO);
0.051180 seconds (5.96 k CPU allocations: 335.906 KiB) (11 GPU allocations: 1.626 GiB, 0.07% memmgmt time)
julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
0.033187 seconds (136 CPU allocations: 7.391 KiB) (5 GPU allocations: 128.500 MiB, 0.05% memmgmt time)
julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128); dO = CUDA.randn(Float32, 32, 4096, 128);
julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
ERROR: Out of GPU memory trying to allocate 8.000 GiB
Effective GPU memory usage: 99.85% (23.607 GiB/23.642 GiB)
Memory pool usage: 16.312 GiB (22.875 GiB reserved)
julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
171.479 ms (132 allocations: 7.28 KiB)
julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z); <=
0.189891 seconds (136 CPU allocations: 7.391 KiB) (5 GPU allocations: 258.000 MiB, 0.01% memmgmt time)
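The backward kernel above is invoked by hand; one way it could eventually plug into the existing AD path is a ChainRulesCore rule around the two kernels. The sketch below is hypothetical: the `flash_attention` wrapper name and the assumption that flash_attention_backward returns gradients for (q, k, v) in that order are not from this PR; only the call signatures shown in the transcript are.

using ChainRulesCore
using NeuralAttentionlib: Flash

# Hypothetical wrapper; returns only the attention output, dropping the logsumexp.
flash_attention(q, k, v) = first(Flash.flash_attention_forward(q, k, v))

function ChainRulesCore.rrule(::typeof(flash_attention), q, k, v)
    O, lse = Flash.flash_attention_forward(q, k, v)
    function flash_attention_pullback(dO)
        # Assumes the backward kernel returns (dq, dk, dv); its argument order
        # is taken from the calls in the benchmark above.
        dq, dk, dv = Flash.flash_attention_backward(unthunk(dO), O, lse, q, k, v)
        return (NoTangent(), dq, dk, dv)
    end
    return O, flash_attention_pullback
end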
This is an initial attempt to adopt the technique used in flash attention. The implementation basically follows the pseudo-code in the FlashAttention-2 paper. The code is written in a CUDA WMMA style, so we should be able to opt in or out of WMMA instructions.
Currently this is only a draft for testing. I'll see whether it is possible to merge it with the existing attention interface; it should at least support arbitrary masks, dropout, and the backward function.
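For context, the forward loop from the FlashAttention-2 paper can be sketched as a plain-Julia CPU reference for a single batch slice in the feature-first (d × N) layout used in the benchmarks. The function name and block sizes below are illustrative; this is a reference for the algorithm, not the PR's kernel, and it omits masking and dropout.

# Minimal CPU sketch of the FlashAttention-2 forward pass (single batch slice).
function flash_forward_reference(Q, K, V; Br = 64, Bc = 64)
    d, N = size(Q)
    scale = inv(sqrt(Float32(d)))
    O = zeros(Float32, d, N)
    L = zeros(Float32, N)                    # per-query logsumexp, as in the printed tuple above
    for i in 1:Br:N                          # outer loop over query blocks
        qi = i:min(i + Br - 1, N)
        Qi = Q[:, qi] .* scale
        Oi = zeros(Float32, d, length(qi))
        mi = fill(-Inf32, length(qi))        # running row maximum
        li = zeros(Float32, length(qi))      # running softmax denominator
        for j in 1:Bc:N                      # inner loop over key/value blocks
            kj = j:min(j + Bc - 1, N)
            S = K[:, kj]' * Qi               # (Bc, Br) block of scores
            mnew = max.(mi, vec(maximum(S, dims = 1)))
            P = exp.(S .- mnew')             # unnormalized probabilities
            α = exp.(mi .- mnew)             # rescale factor for the running sums
            li = α .* li .+ vec(sum(P, dims = 1))
            Oi = Oi .* α' .+ V[:, kj] * P
            mi = mnew
        end
        O[:, qi] = Oi ./ li'
        L[qi] = mi .+ log.(li)
    end
    return O, L
end

In the kernel version the per-block buffers (Qi, Oi, mi, li) would live in shared memory and registers, with the two matrix products mapped onto WMMA fragments, which is presumably why only the output and the logsumexp show up in the GPU allocation counts in the benchmarks.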
Some simple benchmarks (with an RTX 3090 24GB):
(with A100 80GB):