Implementing online-softmax #1

FSSRepo opened this issue Feb 29, 2024 · 3 comments

FSSRepo commented Feb 29, 2024

I'm opening this issue in this project to make communication easier; I need clarification on a few points.

@ggerganov Could you take a quick look at my code in flash-matrix.cu? I've implemented a kernel that parallelizes along the sequence dimension (increasing occupancy and reducing the tail effect) for a query batch size of 1 (inference) with a head_dim of 128, a case where it is particularly hard to improve performance because the batch is so small. I've already implemented (QK^T*scale + mask)V and it works very well for sequence lengths up to 16K.

However, implementing the softmax is proving very difficult for me (I can't seem to find a way to fit it into this scheme). Since you have already worked on creating the kernel and have a better understanding of the FlashAttention-2 paper, perhaps you can guide me a bit.
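
To make the question concrete, this is the recurrence I'm trying to map onto the kernel, as I understand it from the FlashAttention-2 paper (a minimal sequential sketch for one query row; BLOCK, the function name and the argument layout are illustrative assumptions, not the actual code in flash-matrix.cu):

#include <math.h>

// Sequential sketch of the online-softmax recurrence for one query row.
// Whenever a KV chunk raises the running max M, the previously accumulated
// sum S and output are rescaled by exp(M_old - M_new), so no second pass
// over the scores is needed.
void attn_row_online_softmax(const float *Q, const float *K, const float *V,
                             const float *mask, float *out,
                             int kv_size, int head_dim, float scale) {
    const int BLOCK = 32;                 // KV positions per iteration (illustrative)
    float M = -INFINITY;                  // running max of the scores seen so far
    float S = 0.0f;                       // running sum of exp(score - M)
    for (int d = 0; d < head_dim; ++d) out[d] = 0.0f;  // running unnormalized output

    for (int start = 0; start < kv_size; start += BLOCK) {
        int end = start + BLOCK < kv_size ? start + BLOCK : kv_size;

        // scores for this chunk: scale * dot(Q, K_j) + mask_j
        float scores[32];
        float block_max = -INFINITY;
        for (int j = start; j < end; ++j) {
            float s = 0.0f;
            for (int d = 0; d < head_dim; ++d) s += Q[d] * K[j*head_dim + d];
            s = s*scale + mask[j];
            scores[j - start] = s;
            block_max = fmaxf(block_max, s);
        }

        // rescale the previous state to the new running max
        float M_new = fmaxf(M, block_max);
        float corr  = expf(M - M_new);
        S *= corr;
        for (int d = 0; d < head_dim; ++d) out[d] *= corr;

        // accumulate this chunk
        for (int j = start; j < end; ++j) {
            float p = expf(scores[j - start] - M_new);
            S += p;
            for (int d = 0; d < head_dim; ++d) out[d] += p * V[j*head_dim + d];
        }
        M = M_new;
    }

    for (int d = 0; d < head_dim; ++d) out[d] /= S;  // final normalization
}

The point is that the final division by S can be deferred to the very end, which is what allows the (QK^T*scale + mask)V part to be accumulated chunk by chunk.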

@ggerganov

I cloned and ran the repo, but it looks like the CUDA kernel produces just 0s:

$ ▶ bash run.sh 
src/flash-matrix.cu(320): warning #177-D: variable "warp_buffer" was declared but never referenced
      float* warp_buffer = (float*)(shmem + tensor_elements*sizeof(half) + (kv_size/reduce_block + 2) * sizeof(float) + warp_index*(tensor_elements + 2)*sizeof(float));
             ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

src/cuda_info.h: In function ‘void print_cuda_info()’:
src/cuda_info.h:40:8: warning: format ‘%d’ expects argument of type ‘int’, but argument 2 has type ‘size_t’ {aka ‘long unsigned int’} [-Wformat=]
   40 |             printf("Max Shared Memory Per MP: %d\n", deviceProp.sharedMemPerMultiprocessor);
      |        ^    ~~~~~~~~~~~~~~~~~~~~~~~~~~~  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      |                                                     |
      |                                                     size_t {aka long unsigned int}
Device Count: 1

GPU Information 0:
Name: NVIDIA GeForce RTX 2060 SUPER
Architecture: 7.5
Number of SMs: 34
L2 Cache size: 4194304 bytes
VRAM: 7971 MB
Max Threads per Block: 1024
Max Register Per Block: 65536
Max Shared Memory Per Block: 49152
Max Blocks Per MP: 16
Max Threads Per MP: 1024
Max Register Per MP: 65536
Max Shared Memory Per MP: 65536
Max Grid Size: (2147483647, 65535, 65535)
Max Block Size: (1024, 1024, 64)

M= 1.8393, S=407.8494
M= 2.0483, S=335.9495
M= 1.6821, S=483.4723
M= 1.8194, S=416.5851
M= 1.6709, S=485.1035
M= 1.9016, S=388.8950
M= 1.8026, S=427.3427
M= 1.7276, S=456.1953
M= 1.7920, S=441.3259
M= 2.0294, S=342.5456
M= 1.7775, S=435.9819
M= 2.0669, S=324.5927
M= 1.6992, S=476.6263
M= 1.7908, S=437.2062
M= 2.0039, S=349.1781
M= 1.8033, S=428.0863
M= 1.9781, S=356.0716
M= 1.9747, S=357.9479
M= 1.8237, S=414.7466
M= 2.0712, S=326.3984
M= 1.7997, S=431.9955
M= 2.0420, S=339.2790
M= 2.1695, S=295.2199
M= 1.9372, S=373.4118
M= 2.0962, S=319.4227
M= 1.8040, S=426.7111
M= 1.5584, S=543.4086
M= 2.1283, S=310.8639
M= 1.7439, S=450.5404
M= 2.0426, S=335.2083
M= 1.6806, S=473.5671
M= 1.9723, S=365.3445
---------------- QKV ------------------
-0.01777f, -0.01389f, -0.01455f, -0.03189f, -0.00678f, 0.00009f, 0.01625f, 0.04121f, 0.01610f, -0.00425f, -0.02576f, -0.00738f, -0.00771f, 0.01742f, 0.00726f, 0.03700f, 
-0.00148f, 0.01505f, -0.01111f, -0.00342f, -0.00525f, 0.02027f, 0.01050f, 0.00946f, -0.01828f, 0.01539f, 0.01486f, -0.00015f, 0.01740f, 0.01277f, 0.00252f, -0.01090f, 


Shared memory: 9.57 KB

reduce block: 16
cuda time: 0.0348 ms
R (-0.0562) CUDA(0.0000) diff: 0.0562 - head = 7, dim = 35
0.9722 -0.5815 0.6025 0.5415 0.0069 0.6006 0.6587 0.4626 [0.0494] [-0.8633] [-0.7080] [0.2666] [-0.7412] [0.0293] [0.3000] [0.7026]  = 2.2990
---------------- CUDA QKV ------------------
0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 
0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 0.00000f, 

Is this expected? This run seems to be using the flash_attn_ext_f16 kernel.


FSSRepo commented Mar 4, 2024

Try it now; it should work as shown below, using the flash_attn_row kernel. To use the flash_attn_ext_f16 kernel you can specify --no-kv-parallel:

Device Count: 1

GPU Information 0:
Name: NVIDIA GeForce RTX 3050 Laptop GPU
Architecture: 8.6
Number of SMs: 16
L2 Cache size: 1572864 bytes
VRAM: 4095 MB
Max Threads per Block: 1024
Max Register Per Block: 65536
Max Shared Memory Per Block: 49152
Max Blocks Per MP: 16
Max Threads Per MP: 1536
Max Register Per MP: 65536
Max Shared Memory Per MP: 102400
Max Grid Size: (2147483647, 65535, 65535)
Max Block Size: (1024, 1024, 64)

---------------- Reference ------------------
0.00752f, -0.00013f, -0.02614f, -0.02851f, 0.00029f, -0.00679f, 0.01165f, 0.00011f, -0.00125f, -0.01256f, 0.00305f, 0.01789f, -0.00508f, 0.00248f, -0.01243f, -0.01494f,
0.00310f, 0.02219f, 0.00115f, 0.01248f, -0.01541f, 0.00151f, 0.00781f, -0.01498f, -0.00929f, -0.01956f, 0.00898f, -0.01847f, 0.00419f, 0.00576f, 0.01077f, 0.01651f,

---------------- Parallel KV CUDA ------------------
0.00752f, -0.00013f, -0.02614f, -0.02852f, 0.00028f, -0.00680f, 0.01165f, 0.00012f, -0.00124f, -0.01256f, 0.00304f, 0.01790f, -0.00509f, 0.00249f, -0.01243f, -0.01495f,
0.00310f, 0.02221f, 0.00115f, 0.01248f, -0.01542f, 0.00151f, 0.00780f, -0.01499f, -0.00929f, -0.01955f, 0.00898f, -0.01847f, 0.00419f, 0.00575f, 0.01077f, 0.01651f,


cuda time: 0.5146 ms
Shared memory: 9.57 KB
R (0.0394) CUDA(0.0395) diff: 0.0000 - head = 19, dim = 80
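
For reference, the way the parallel-KV path has to combine the per-chunk partial results is conceptually like this (a sketch with illustrative names and a fixed HEAD_DIM; the real kernel works on half-precision data in shared memory):

#include <math.h>

#define HEAD_DIM 128   // head_dim used in the tests above (assumption for the sketch)

// Each KV chunk produces a partial state for a query row: its own max M,
// sum S of exp(score - M), and unnormalized output accumulator.
struct Partial {
    float M;
    float S;
    float acc[HEAD_DIM];
};

// Two partial states merge with just the correction factors exp(M_i - M_new),
// so only M, S and the accumulator need to cross block boundaries.
static struct Partial merge(struct Partial a, struct Partial b) {
    struct Partial out;
    out.M = fmaxf(a.M, b.M);
    float ca = expf(a.M - out.M);   // rescale chunk a to the common max
    float cb = expf(b.M - out.M);   // rescale chunk b to the common max
    out.S = a.S*ca + b.S*cb;
    for (int d = 0; d < HEAD_DIM; ++d)
        out.acc[d] = a.acc[d]*ca + b.acc[d]*cb;
    return out;
}
// After all chunks are merged, the final row is out.acc[d] / out.S.

This merge is associative, so the chunks can be reduced in any order across the grid.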

@ggerganov

So both the parallel and non-parallel versions seem to match the reference numbers. Does that mean it works correctly?
