Commit 1aaced5
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent c2d1b07
Showing 11 changed files with 283 additions and 24 deletions.
@@ -0,0 +1,43 @@
#include <torch/all.h>
#include <torch/cuda.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
  TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");

  // Get raw host pointer from CPU tensor
  void* host_ptr = cpu_tensor.data_ptr();

  // Get a device pointer corresponding to the pinned host memory
  void* device_ptr = nullptr;
  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

  // Construct a CUDA tensor from the device pointer.
  // We'll use the same sizes, strides, and dtype as the CPU tensor.
  auto sizes = cpu_tensor.sizes();
  auto strides = cpu_tensor.strides();
  auto options =
      cpu_tensor.options().device(torch::kCUDA);  // Change device to CUDA

  // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
  // const TensorOptions &). Provide a no-op deleter: the CPU tensor owns the
  // memory, so we don't free it here.
  auto deleter = [](void*) {
    // no-op, since the memory is owned by the original CPU tensor
  };

  torch::Tensor cuda_tensor =
      torch::from_blob(device_ptr, sizes, strides, deleter, options);

  TORCH_CHECK(cuda_tensor.device().is_cuda(),
              "Resulting tensor is not on CUDA device");
  TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
  TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
  TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");

  return cuda_tensor;
}
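For reference, this is how the helper is consumed from Python once it is exposed as vllm.utils.get_cuda_view_from_cpu_tensor (the import used by the block-table code later in this commit). A minimal sketch, assuming a CUDA device with UVA support; shapes and values are illustrative only:

import torch

from vllm.utils import get_cuda_view_from_cpu_tensor

# Pinned (page-locked) allocation is required; cudaHostGetDevicePointer only
# works on pinned host memory.
cpu_buf = torch.zeros((4, 8), dtype=torch.int32, pin_memory=True)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_buf)
assert cuda_view.device.type == "cuda"

# Writes made through the CPU tensor are visible to kernels that later read
# the CUDA view, without an explicit host-to-device copy.
cpu_buf[0, :3] = torch.tensor([10, 20, 30], dtype=torch.int32)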
@@ -0,0 +1,63 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm {

// Copies, for each row, the subrange [start, start + length) from matrix_src
// to matrix_tgt. matrix_diff holds one (start, length) pair per row; M is the
// row width of matrix_src/matrix_tgt.
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
                                      const int* __restrict__ matrix_diff,
                                      int* __restrict__ matrix_tgt, int64_t M) {
  int row_id = blockIdx.x;
  int row_offset = row_id * M;

  int start = matrix_diff[row_id * 2];
  int length = matrix_diff[row_id * 2 + 1];
  int end = start + length;
  int thread_idx = threadIdx.x;
  for (int i = start + thread_idx; i < end; i += blockDim.x) {
    int idx = row_offset + i;
    matrix_tgt[idx] = matrix_src[idx];
  }
}
}  // namespace vllm

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
                    torch::Tensor& matrix_tgt, int64_t n) {
  // Check tensor properties
  TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
  TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
  TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
  TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
  TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
  TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

  auto src_sizes = matrix_src.sizes();
  auto diff_sizes = matrix_diff.sizes();
  auto tgt_sizes = matrix_tgt.sizes();

  TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
  TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
  TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

  int64_t N = src_sizes[0];
  int64_t M = src_sizes[1];

  TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
  TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
  TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
              "matrix_tgt must have same shape as matrix_src");

  TORCH_CHECK(n <= N, "n must be <= N");

  const int* d_matrix_src = matrix_src.data_ptr<int>();
  const int* d_matrix_diff = matrix_diff.data_ptr<int>();
  int* d_matrix_tgt = matrix_tgt.data_ptr<int>();

  // One thread block per row.
  int blocks = n;
  int threads = 1024;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
      d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}
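The op copies, per row, only the subrange described by a (start, length) pair, which is what lets the block table upload just its per-step diff instead of the whole table. A minimal sketch of calling it through vllm._custom_ops.copy_subranges (the binding the block-table code in this commit uses); the tensors here are ordinary CUDA tensors rather than UVA views, which the checks above also accept:

import torch

from vllm import _custom_ops as ops

# 3 rows of width 6; copy a different subrange of each row.
src = torch.arange(18, dtype=torch.int32, device="cuda").reshape(3, 6)
tgt = torch.zeros_like(src)

# Each diff row is (start, length); a length of 0 means the row is unchanged
# and the kernel's copy loop never executes for it.
diff = torch.tensor([[0, 4], [2, 3], [0, 0]],
                    dtype=torch.int32,
                    device="cuda")

# One thread block per row is launched for the first n=3 rows.
ops.copy_subranges(src, diff, tgt, 3)
torch.cuda.synchronize()
# Expected: tgt[0, :4] == src[0, :4], tgt[1, 2:5] == src[1, 2:5],
# and tgt[2] remains all zeros.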
@@ -0,0 +1,116 @@
from typing import List

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_cuda_view_from_cpu_tensor


class BlockTable:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        pin_memory: bool,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.pin_memory = pin_memory
        self.device = device

        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()

        # Pinned memory is required to use UVA.
        # TODO(woosuk): Add other requirements for UVA.
        self.use_uva = pin_memory
        if self.use_uva:
            self.block_table_diff = torch.zeros((max_num_reqs, 2),
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=True)
            self.block_table_diff_np = self.block_table_diff.numpy()

            self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor(
                self.block_table_cpu)
            self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor(
                self.block_table_diff)

    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, :num_blocks] = block_ids
        if self.use_uva:
            self.block_table_diff_np[row_idx, 0] = 0
            self.block_table_diff_np[row_idx, 1] = num_blocks

    def append_row(
        self,
        row_idx: int,
        start: int,
        block_ids: List[int],
    ) -> None:
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        if self.use_uva:
            self.block_table_diff_np[row_idx, 0] = start
            # Move-and-append is not allowed.
            assert self.block_table_diff_np[row_idx, 1] == 0
            self.block_table_diff_np[row_idx, 1] = num_blocks

    def move_row(self, src: int, tgt: int) -> None:
        self.block_table_np[tgt] = self.block_table_np[src]
        if self.use_uva:
            # Append-and-move is allowed.
            self.block_table_diff_np[tgt] = self.block_table_diff_np[src]
            # Clear the source row.
            self.block_table_diff_np[src].fill(0)

    def apply_diff(self, num_reqs: int) -> None:
        if self.use_uva:
            # Only copy the diff to the GPU.
            ops.copy_subranges(
                self.block_table_cpu_cuda_view,
                self.block_table_diff_cuda_view,
                self.block_table,
                num_reqs,
            )
        else:
            # Copy the entire block table to the GPU.
            # NOTE(woosuk): This can be a performance bottleneck when the block
            # table is large.
            self.block_table[:num_reqs].copy_(
                self.block_table_cpu[:num_reqs], non_blocking=True)

    def clear(self) -> None:
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)
        if self.use_uva:
            self.block_table_diff.fill_(0)

    def clear_diff(self) -> None:
        if self.use_uva:
            self.block_table_diff_np.fill(0)

    def cuda(self) -> torch.Tensor:
        return self.block_table

    def cpu(self) -> torch.Tensor:
        return self.block_table_cpu

    def numpy(self) -> np.ndarray:
        return self.block_table_np
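Putting the pieces together: CPU-side mutations (add_row, append_row, move_row) record a (start, length) diff per row, apply_diff pushes only those subranges to the GPU block table through the UVA views, and clear_diff resets the bookkeeping. The sketch below shows one plausible per-step call order; the actual caller is not part of this excerpt, and the constructor arguments are illustrative. Note that append_row asserts the row's diff is empty, so the diff must be applied and cleared between an add_row and a later append_row on the same row:

import torch

# Assuming BlockTable (defined above) is importable; its module path is not
# shown in this diff.
bt = BlockTable(
    max_num_reqs=32,
    max_model_len=2048,
    max_num_blocks_per_req=128,
    pin_memory=True,  # pinned memory enables the UVA fast path
    device=torch.device("cuda"),
)

# Step 1: a new request gets three blocks; its whole prefix is the diff.
bt.add_row(0, [1, 2, 3])
bt.apply_diff(num_reqs=1)   # copies only block_table[0, :3] to the GPU
bt.clear_diff()

# Step 2: the request grows by two blocks; only columns [3, 5) are copied.
bt.append_row(0, start=3, block_ids=[4, 5])
bt.apply_diff(num_reqs=1)
bt.clear_diff()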