From 50601bf4588f1277512bf19696c2643da86fb573 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 19 May 2024 07:06:38 +0000 Subject: [PATCH] Use int64_t for page pointer arth --- csrc/flash_attn/src/utils.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 4f999a6b7..4f655a4c3 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -296,21 +296,21 @@ void cp_async_wait() { // assumes that the tensor has already been positioned at the correct head. template <typename Kernel_traits> __forceinline__ __device__ -int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, +int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, const int* block_table, const int page_stride, const int row_stride) { constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; constexpr int kBlockN = Kernel_traits::kBlockN; - const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad; - const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; - const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; - const int page_offset = global_row_offset % page_block_size; - const int virtual_page_idx = global_row_offset / page_block_size; - - return block_table[virtual_page_idx] * page_stride - + page_offset * row_stride + const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad; + const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; + const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; + const int64_t page_offset = global_row_offset % page_block_size; + const int64_t virtual_page_idx = global_row_offset / page_block_size; + + return
((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride) + + page_offset * ((int64_t) row_stride) + col_offset; }