diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index bcd170411e7cb..c53cda16d4714 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -739,6 +739,9 @@ void paged_attention_v1_launcher(
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
+    case 32:
+      LAUNCH_PAGED_ATTENTION_V1(32);
+      break;
     case 64:
       LAUNCH_PAGED_ATTENTION_V1(64);
       break;
@@ -903,6 +906,9 @@ void paged_attention_v2_launcher(
     // NOTE(woosuk): To reduce the compilation time, we only compile for the
     // head sizes that we use in the model. However, we can easily extend this
     // to support any head size which is a multiple of 16.
+    case 32:
+      LAUNCH_PAGED_ATTENTION_V2(32);
+      break;
     case 64:
       LAUNCH_PAGED_ATTENTION_V2(64);
       break;
diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
index abb4e3bea14bb..1baa16a2e0243 100644
--- a/csrc/cpu/attention.cpp
+++ b/csrc/cpu/attention.cpp
@@ -375,6 +375,9 @@ void paged_attention_v1_impl_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   switch (head_size) {
+    case 32:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+      break;
     case 64:
       LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;
@@ -692,6 +695,9 @@ void paged_attention_v2_impl_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   switch (head_size) {
+    case 32:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+      break;
     case 64:
       LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
       break;
diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py
index 6b270ffd5bc00..8df6d4ced9dc6 100644
--- a/vllm/attention/ops/ipex_attn.py
+++ b/vllm/attention/ops/ipex_attn.py
@@ -10,7 +10,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [32, 64, 80, 96, 112, 128, 256]
 
     @staticmethod
     def get_kv_cache_shape(
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 92023d5b75f5a..076f151ffcb61 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -34,7 +34,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 120, 128, 192, 256]
+        return [32, 64, 80, 96, 112, 120, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
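A minimal sketch of how the Python-side change can be exercised, assuming a vLLM checkout with this diff applied; the assertion and print below are illustrative and not part of the patch:

```python
# Sanity check: head size 32 should now appear in the supported list
# returned by the paged-attention helper modified in this diff.
from vllm.attention.ops.paged_attn import PagedAttention

supported = PagedAttention.get_supported_head_sizes()
assert 32 in supported, f"expected head size 32 in {supported}"
print(supported)  # expected: [32, 64, 80, 96, 112, 120, 128, 192, 256]
```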