Skip to content

Commit

Permalink
Use int64_t for page pointer arth
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon committed May 19, 2024
1 parent f80aa0f commit 50601bf
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions csrc/flash_attn/src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,21 +296,21 @@ void cp_async_wait() {
// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;

const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size;

return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride
const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int64_t page_offset = global_row_offset % page_block_size;
const int64_t virtual_page_idx = global_row_offset / page_block_size;

return ((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride)
+ page_offset * ((int64_t) row_stride)
+ col_offset;
}

Expand Down

0 comments on commit 50601bf

Please sign in to comment.