diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 877f2e467..55817bea0 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -18,8 +18,9 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
+
 void set_params_fprop(Flash_fwd_params &params,
-                      // sizes
+    // sizes
                       const size_t b,
                       const size_t seqlen_q,
                       const size_t seqlen_k,
@@ -29,7 +30,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       const size_t h_k,
                       const size_t d,
                       const size_t d_rounded,
-                      // device pointers
+    // device pointers
                       const at::Tensor q,
                       const at::Tensor k,
                       const at::Tensor v,
@@ -44,8 +45,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_left,
                       int window_size_right,
                       const float softcap,
-                      bool seqlenq_ngroups_swapped = false,
-                      const bool unpadded_lse = false) {
+                      bool seqlenq_ngroups_swapped=false,
+                      const bool unpadded_lse=false) {
 
     // Reset the parameters
     params = {};
@@ -73,8 +74,8 @@ void set_params_fprop(Flash_fwd_params &params,
         params.v_batch_stride = v.stride(0);
         params.o_batch_stride = out.stride(0);
         if (seqlenq_ngroups_swapped) {
-             params.q_batch_stride *= seqlen_q;
-             params.o_batch_stride *= seqlen_q;
+            params.q_batch_stride *= seqlen_q;
+            params.o_batch_stride *= seqlen_q;
         }
     }
 
@@ -101,14 +102,14 @@ void set_params_fprop(Flash_fwd_params &params,
     params.d_rounded = d_rounded;
 
     // Set the different scale values.
-#ifdef FLASHATTENTION_DISABLE_SOFTCAP
-    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
-#endif
+    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
+        TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+    #endif
     if (softcap > 0.0) {
         params.softcap = softmax_scale / softcap;
         params.scale_softmax = softcap;
         params.scale_softmax_log2 = softcap * M_LOG2E;
-    } else {
+    } else{
         // Remove potential NaN
         params.softcap = 0.0;
         params.scale_softmax = softmax_scale;
@@ -125,9 +126,9 @@ void set_params_fprop(Flash_fwd_params &params,
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     TORCH_CHECK(p_dropout < 1.f);
-#ifdef FLASHATTENTION_DISABLE_DROPOUT
-    TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
-#endif
+    #ifdef FLASHATTENTION_DISABLE_DROPOUT
+        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
+    #endif
 
     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
@@ -138,22 +139,22 @@ void set_params_fprop(Flash_fwd_params &params,
     params.window_size_left = window_size_left;
     params.window_size_right = window_size_right;
 
-#ifdef FLASHATTENTION_DISABLE_LOCAL
-    TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
-                "This flash attention build does not support local attention.");
-#endif
+    #ifdef FLASHATTENTION_DISABLE_LOCAL
+        TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
+            "This flash attention build does not support local attention.");
+    #endif
 
     params.is_seqlens_k_cumulative = true;
 
-#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
-    TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
-#endif
+    #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+        TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+    #endif
 
     params.unpadded_lse = unpadded_lse;
     params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }
 
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel = false) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
@@ -210,9 +211,9 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
 }
 
 void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
-                        const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
-                        const int head_size_rounded, const float p_dropout,
-                        const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
+        const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
+        const int head_size_rounded, const float p_dropout,
+        const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
@@ -224,15 +225,11 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks,
-                                                     dprops->multiProcessorCount * 2, num_n_blocks, 128);
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
-            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q},
-                                                        opts.dtype(at::kFloat));
-            at::Tensor out_accum = torch::empty(
-                {params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded},
-                opts.dtype(at::kFloat));
+            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
@@ -240,9 +237,7 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     }
 }
 
-void
-set_params_alibi(Flash_fwd_params &params, const c10::optional<at::Tensor> &alibi_slopes_, int batch_size,
-                 int num_heads) {
+void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
 #ifdef FLASHATTENTION_DISABLE_ALIBI
     TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
     params.alibi_slopes_ptr = nullptr;
@@ -252,8 +247,7 @@ set_params_alibi(Flash_fwd_params &params, const c10::optional<at::Tensor> &alib
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
         CHECK_DEVICE(alibi_slopes);
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) ||
-                    alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
         params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
@@ -262,7 +256,7 @@ set_params_alibi(Flash_fwd_params &params, const c10::optional<at::Tensor> &alib
 #endif
 }
 
-std::array<at::Tensor, 8>
+std::vector<at::Tensor>
 mha_fwd(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
@@ -294,9 +288,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 
-    CHECK_DEVICE(q);
-    CHECK_DEVICE(k);
-    CHECK_DEVICE(v);
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -325,9 +317,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped =
-        seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 &&
-        p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -373,7 +363,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char) q.get_device()};
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
 
@@ -382,7 +372,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     // Only return softmax if there's dropout to reduce compilation time
     if (return_softmax) {
         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
-        p = torch::empty({batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
+        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
     }
 
     Flash_fwd_params params;
@@ -393,9 +383,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      q_padded, k_padded, v_padded, out,
-                     /*cu_seqlens_q_d=*/nullptr,
-                     /*cu_seqlens_k_d=*/nullptr,
-                     /*seqused_k=*/nullptr,
+        /*cu_seqlens_q_d=*/nullptr,
+        /*cu_seqlens_k_d=*/nullptr,
+        /*seqused_k=*/nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -403,7 +393,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      window_size_left,
                      window_size_right,
                      softcap
-    );
+        );
 
     set_params_splitkv(params, batch_size, num_heads,
@@ -417,11 +407,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
     auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
     // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
-    if (p_dropout > 0.0)  {
+    if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
+                gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
@@ -453,7 +443,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
-std::array<at::Tensor, 8>
+std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
@@ -494,9 +484,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
 
-    CHECK_DEVICE(q);
-    CHECK_DEVICE(k);
-    CHECK_DEVICE(v);
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
@@ -529,21 +517,17 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int page_block_size = !paged_KV ? 1 : k.size(1);
     TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
 
-    if (max_seqlen_q == 1 &&
-        !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
 
     void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped =
-        max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 &&
-        p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape(
-            {batch_size * ngroups, num_heads_k, head_size_og});
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
         max_seqlen_q = ngroups;
         num_heads = num_heads_k;
         cu_seqlens_q_d = nullptr;
@@ -571,7 +555,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
 
-    if (seqused_k.has_value()) {
+    if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
         TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
@@ -598,8 +582,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
         if (seqlenq_ngroups_swapped) {
-            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape(
-                {batch_size * ngroups, num_heads_k, head_size_og});
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
         }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
@@ -614,7 +597,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char) q.get_device()};
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
     auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
@@ -622,13 +605,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     // Only return softmax if there's dropout to reduce compilation time
     if (return_softmax) {
         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
-        p = torch::empty({batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
+        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
     }
 
     if (zero_tensors) {
         out.zero_();
         softmax_lse.fill_(-std::numeric_limits<float>::infinity());
-        if (return_softmax) { p.zero_(); }
+        if (return_softmax) {p.zero_();}
     }
 
     Flash_fwd_params params;
@@ -650,7 +633,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_right,
                      softcap,
                      seqlenq_ngroups_swapped,
-                     /*unpadded_lse*/true);
+        /*unpadded_lse*/true);
     params.total_q = total_q;
 
     if (paged_KV) {
@@ -674,11 +657,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
     auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
     // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
-    if (p_dropout > 0.0)  {
+    if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-            gen_, at::cuda::detail::getDefaultCUDAGenerator());
+                gen_, at::cuda::detail::getDefaultCUDAGenerator());
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
@@ -713,7 +696,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
-std::array<at::Tensor, 8>
+std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
                 const at::Tensor &kcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                 const at::Tensor &vcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
@@ -752,9 +735,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
 
-    CHECK_DEVICE(q);
-    CHECK_DEVICE(kcache);
-    CHECK_DEVICE(vcache);
+    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -796,9 +777,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped =
-        seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 &&
-        head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -822,10 +801,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     at::Tensor q_padded, kcache_padded, vcache_padded;
     if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        kcache_padded = torch::nn::functional::pad(kcache,
-                                                   torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
-        vcache_padded = torch::nn::functional::pad(vcache,
-                                                   torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
     } else {
         q_padded = q;
         kcache_padded = kcache;
@@ -856,7 +833,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
-    at::cuda::CUDAGuard device_guard{(char) q.get_device()};
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
 
@@ -870,17 +847,17 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      q_padded, kcache_padded, vcache_padded, out,
-                     /*cu_seqlens_q_d=*/nullptr,
-                     /*cu_seqlens_k_d=*/nullptr,
-                     /*seqused_k=*/nullptr,
-                     /*p_ptr=*/nullptr,
+        /*cu_seqlens_q_d=*/nullptr,
+        /*cu_seqlens_k_d=*/nullptr,
+        /*seqused_k=*/nullptr,
+        /*p_ptr=*/nullptr,
                      softmax_lse.data_ptr(),
-                     /*p_dropout=*/0.f,
+        /*p_dropout=*/0.f,
                      softmax_scale,
                      window_size_left,
                      window_size_right,
                      softcap
-    );
+        );
 
     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
@@ -891,8 +868,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         v = v_.value();
         TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
         TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
-        CHECK_DEVICE(k);
-        CHECK_DEVICE(v);
+        CHECK_DEVICE(k); CHECK_DEVICE(v);
         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
         int seqlen_knew = k.size(1);
@@ -928,8 +904,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
 
     if (rotary_cos_.has_value()) {
-        TORCH_CHECK(k_.has_value(),
-                    "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
+        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
         auto rotary_cos = rotary_cos_.value();
         CHECK_DEVICE(rotary_cos);
         params.rotary_dim = rotary_cos.size(1) * 2;