diff --git a/CMakeLists.txt b/CMakeLists.txt
index 83c8033434f3b..6e1f4296bca5c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -193,6 +193,7 @@ set(VLLM_EXT_SRC
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
   "csrc/layernorm_quant_kernels.cu"
+  "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
@@ -200,6 +201,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
+  "csrc/prepare_inputs/copy_subranges.cu"
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index 82e55613d915a..736788f8d664c 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -47,3 +47,11 @@
   #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
     hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
 #endif
+
+// #ifndef USE_ROCM
+//   #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
+//     cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
+// #else
+//   #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
+//     hipHostGetDevicePointer(device_ptr, host_ptr, flags)
+// #endif
diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu
new file mode 100644
index 0000000000000..5035cf482a36e
--- /dev/null
+++ b/csrc/cuda_view.cu
@@ -0,0 +1,43 @@
+#include <torch/all.h>
+#include <torch/cuda.h>
+
+// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
+// memory, and that UVA (Unified Virtual Addressing) is enabled.
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
+  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
+  TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");
+
+  // Get raw host pointer from CPU tensor
+  void* host_ptr = cpu_tensor.data_ptr();
+
+  // Get a device pointer corresponding to the pinned host memory
+  void* device_ptr = nullptr;
+  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
+  TORCH_CHECK(err == cudaSuccess,
+              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
+
+  // Construct a CUDA tensor from the device pointer.
+  // We'll use the same sizes, strides, and dtype as the CPU tensor.
+  auto sizes = cpu_tensor.sizes();
+  auto strides = cpu_tensor.strides();
+  auto options =
+      cpu_tensor.options().device(torch::kCUDA);  // Change device to CUDA
+
+  // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
+  // const TensorOptions &). Provide a no-op deleter: the CPU tensor holds the
+  // memory, so we don't free it here.
+  auto deleter = [](void*) {
+    // no-op, since the memory is owned by the original CPU tensor
+  };
+
+  torch::Tensor cuda_tensor =
+      torch::from_blob(device_ptr, sizes, strides, deleter, options);
+
+  TORCH_CHECK(cuda_tensor.device().is_cuda(),
+              "Resulting tensor is not on CUDA device");
+  TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
+  TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
+  TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");
+
+  return cuda_tensor;
+}
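
As a quick illustration of what this kernel-side helper enables, here is a minimal usage sketch from the Python side, going through the `vllm.utils.get_cuda_view_from_cpu_tensor` wrapper added later in this patch. It is not part of the patch; it assumes a CUDA device and an extension rebuilt with this change, and the shapes and values are made up.

```python
import torch

from vllm.utils import get_cuda_view_from_cpu_tensor

# UVA requires pinned (page-locked) host memory; a pageable tensor would make
# cudaHostGetDevicePointer fail.
cpu_tensor = torch.zeros(4, 8, dtype=torch.int32, pin_memory=True)

# The returned tensor aliases the same physical memory but carries a CUDA
# device, so it can be handed directly to CUDA kernels.
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)

# A write on the CPU side is visible through the CUDA view without an
# explicit host-to-device copy.
cpu_tensor[0, :4] = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
assert torch.equal(cuda_view[0, :4].cpu(), cpu_tensor[0, :4])
```
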
diff --git a/csrc/ops.h b/csrc/ops.h
index 347c502845d8f..74a31401c499e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -115,6 +115,11 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
+                    torch::Tensor& matrix_tgt, int64_t n);
+
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
diff --git a/csrc/prepare_inputs/copy_subranges.cu b/csrc/prepare_inputs/copy_subranges.cu
new file mode 100644
index 0000000000000..082b37000c65b
--- /dev/null
+++ b/csrc/prepare_inputs/copy_subranges.cu
@@ -0,0 +1,72 @@
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+namespace vllm {
+__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
+                                      const int* __restrict__ matrix_diff,
+                                      int* __restrict__ matrix_tgt, int64_t M) {
+  int row_id = blockIdx.x;
+  int row_offset = row_id * M;
+
+  int start = matrix_diff[row_id * 2];
+  int length = matrix_diff[row_id * 2 + 1];
+  int end = start + length;
+  int thread_idx = threadIdx.x;
+  for (int i = start + thread_idx; i < end; i += blockDim.x) {
+    int idx = row_offset + i;
+    matrix_tgt[idx] = matrix_src[idx];
+  }
+}
+}  // namespace vllm
+
+void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
+                    torch::Tensor& matrix_tgt, int64_t n) {
+  // Check tensor properties
+  TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
+  TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
+  TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
+  TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
+  TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
+  TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
+
+  auto src_sizes = matrix_src.sizes();
+  auto diff_sizes = matrix_diff.sizes();
+  auto tgt_sizes = matrix_tgt.sizes();
+
+  TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
+  TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
+  TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
+
+  int64_t N = src_sizes[0];
+  int64_t M = src_sizes[1];
+
+  TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
+  TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
+  TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
+              "matrix_tgt must have same shape as matrix_src");
+
+  TORCH_CHECK(n <= N, "n must be <= N");
+
+  const int* d_matrix_src = matrix_src.data_ptr<int>();
+  const int* d_matrix_diff = matrix_diff.data_ptr<int>();
+  int* d_matrix_tgt = matrix_tgt.data_ptr<int>();
+
+  // One thread block per row.
+  int blocks = n;
+  int threads;
+  if (blocks < 128) {
+    threads = 1024;
+  } else if (blocks < 256) {
+    threads = 512;
+  } else if (blocks < 512) {
+    threads = 256;
+  } else {
+    threads = 128;
+  }
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
+      d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
+}
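
For readers skimming the kernel above: each thread block handles one row, and `matrix_diff[row] = (start, length)` selects the slice of that row to copy. A pure-NumPy reference of the same semantics (illustrative only, not part of the patch) looks like this:

```python
import numpy as np


def copy_subranges_reference(matrix_src: np.ndarray, matrix_diff: np.ndarray,
                             matrix_tgt: np.ndarray, n: int) -> None:
    """NumPy reference for vllm::copy_subranges_kernel (illustration only)."""
    for row in range(n):
        start, length = matrix_diff[row]
        # Rows with length == 0 are untouched, which is how the block table
        # code marks "no change" for a request.
        matrix_tgt[row, start:start + length] = \
            matrix_src[row, start:start + length]
```
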
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 956258c1001d3..31394c04597a5 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
 
+  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
+  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
+           &get_cuda_view_from_cpu_tensor);
+
   // Attention ops
   // Compute the attention between an input query and the cached
   // keys/values using PagedAttention.
@@ -98,6 +102,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
+  ops.def(
+      "copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
+      "matrix_tgt, int n) -> ()");
+  ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 19f31b8ec419d..b5b0115f4cf3b 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -249,6 +249,17 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                                                 block_table_bound)
 
 
+# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
+def copy_subranges(
+    src_matrix: torch.Tensor,
+    diff_matrix: torch.Tensor,
+    tgt_matrix: torch.Tensor,
+    num_subranges: int,
+) -> None:
+    torch.ops._C.copy_subranges(src_matrix, diff_matrix, tgt_matrix,
+                                num_subranges)
+
+
 # fused quant layer norm ops
 def rms_norm_dynamic_per_token_quant(
     input: torch.Tensor,
diff --git a/vllm/utils.py b/vllm/utils.py
index 1b90eca1cd6cc..6d7f268bfd6b7 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1523,6 +1523,13 @@ def weak_ref_tensors(
     raise ValueError("Invalid type for tensors")
 
 
+def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
+    """
+    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
+
+
 def is_in_doc_build() -> bool:
     try:
         from sphinx.ext.autodoc.mock import _MockModule
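
Putting the two new ops together, the intended calling pattern, which the `BlockTable` class introduced below wraps, is roughly the following sketch. It assumes a CUDA build of the extension; the sizes and block IDs are invented for illustration.

```python
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_cuda_view_from_cpu_tensor

num_reqs, max_blocks = 4, 16

# Pinned CPU buffers that the host-side code mutates in place.
src_cpu = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, pin_memory=True)
diff_cpu = torch.zeros(num_reqs, 2, dtype=torch.int32, pin_memory=True)
tgt_gpu = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")

# UVA views let the CUDA kernel read the pinned CPU buffers directly.
src_view = get_cuda_view_from_cpu_tensor(src_cpu)
diff_view = get_cuda_view_from_cpu_tensor(diff_cpu)

# Request 0 gets blocks [7, 8, 9] starting at column 0.
src_cpu[0, :3] = torch.tensor([7, 8, 9], dtype=torch.int32)
diff_cpu[0] = torch.tensor([0, 3], dtype=torch.int32)

# Copy only the dirty subranges into the GPU-resident table.
ops.copy_subranges(src_view, diff_view, tgt_gpu, num_reqs)
```
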
+ """ + return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + + def is_in_doc_build() -> bool: try: from sphinx.ext.autodoc.mock import _MockModule diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py new file mode 100644 index 0000000000000..25a927c371566 --- /dev/null +++ b/vllm/v1/worker/gpu_block_table.py @@ -0,0 +1,116 @@ +from typing import List + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm.utils import get_cuda_view_from_cpu_tensor + + +class BlockTable: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + pin_memory: bool, + device: torch.device, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.pin_memory = pin_memory + self.device = device + + self.block_table = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32, + ) + self.block_table_cpu = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_np = self.block_table_cpu.numpy() + + # Pinned memory is required to use UVA. + # TODO(woosuk): Add other requirements for UVA. + self.use_uva = pin_memory + if self.use_uva: + self.block_table_diff = torch.zeros((max_num_reqs, 2), + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.block_table_diff_np = self.block_table_diff.numpy() + + self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor( + self.block_table_cpu) + self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor( + self.block_table_diff) + + def add_row(self, row_idx: int, block_ids: List[int]) -> None: + num_blocks = len(block_ids) + self.block_table_np[row_idx, :num_blocks] = block_ids + if self.use_uva: + self.block_table_diff_np[row_idx, 0] = 0 + self.block_table_diff_np[row_idx, 1] = num_blocks + + def append_row( + self, + row_idx: int, + start: int, + block_ids: List[int], + ) -> None: + num_blocks = len(block_ids) + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + if self.use_uva: + self.block_table_diff_np[row_idx, 0] = start + # Move-and-append is not allowed. + assert self.block_table_diff_np[row_idx, 1] == 0 + self.block_table_diff_np[row_idx, 1] = num_blocks + + def move_row(self, src: int, tgt: int) -> None: + self.block_table_np[tgt] = self.block_table_np[src] + if self.use_uva: + # Append-and-move is allowed. + self.block_table_diff_np[tgt] = self.block_table_diff_np[src] + # Clear the source row. + self.block_table_diff_np[src].fill(0) + + def apply_diff(self, num_reqs: int) -> None: + if self.use_uva: + # Only copy the diff to the GPU. + ops.copy_subranges( + self.block_table_cpu_cuda_view, + self.block_table_diff_cuda_view, + self.block_table, + num_reqs, + ) + else: + # Copy the entire block table to the GPU. + # NOTE(woosuk): This can be a performance bottleneck when the block + # table is large. 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 5c113c74778df..4f9890ad6a630 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -9,6 +9,7 @@
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange
@@ -64,19 +65,14 @@ def __init__(
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
 
-        # Attention-related.
-        self.block_table = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device=self.device,
-            dtype=torch.int32,
-        )
-        self.block_table_cpu_tensor = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device="cpu",
-            dtype=torch.int32,
+        # Block table.
+        self.block_table = BlockTable(
+            max_num_reqs=max_num_reqs,
+            max_model_len=max_model_len,
+            max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
+            device=device,
         )
-        self.block_table_cpu = self.block_table_cpu_tensor.numpy()
 
         # Sampling-related.
         self.temperature = torch.empty((max_num_reqs, ),
@@ -141,8 +137,7 @@ def add_request(
                 start_idx:end_idx] = request.output_token_ids
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        num_blocks = len(request.block_ids)
-        self.block_table_cpu[req_index, :num_blocks] = request.block_ids
+        self.block_table.add_row(req_index, request.block_ids)
 
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
@@ -221,13 +216,12 @@ def condense(self, empty_req_indices: List[int]) -> None:
             self.req_id_to_index[req_id] = empty_index
 
             # TODO(woosuk): Optimize the copy of token_ids_cpu and
-            # block_table_cpu.
+            # block_table.
             self.token_ids_cpu[empty_index] = self.token_ids_cpu[
                 last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
-            self.block_table_cpu[empty_index] = self.block_table_cpu[
-                last_req_index]
+            self.block_table.move_row(last_req_index, empty_index)
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
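
The `condense()` change above delegates the row copy to `BlockTable.move_row`, which also carries over any not-yet-applied diff so the next `apply_diff` still uploads the moved row. A small, self-contained sketch of that behavior (row indices invented):

```python
import torch

from vllm.v1.worker.gpu_block_table import BlockTable

bt = BlockTable(max_num_reqs=8, max_model_len=2048, max_num_blocks_per_req=4,
                pin_memory=True, device=torch.device("cuda"))

bt.add_row(3, [20, 21])     # the last active request lives at row 3
bt.move_row(src=3, tgt=0)   # condense() fills the hole left at row 0
assert bt.numpy()[0, :2].tolist() == [20, 21]
```
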
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ace62d8978bea..f4dc0d945407d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -162,6 +162,9 @@ def __init__(
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+        # Clean up diffs.
+        self.input_batch.block_table.clear_diff()
+
         # Remove stopped requests from the cached states.
         # Keep the states of the pre-empted requests.
         for req_id in scheduler_output.finished_req_ids:
@@ -203,10 +206,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             if num_new_blocks == 0:
                 continue
             start_index = len(req_state.block_ids)
-            end_index = start_index + num_new_blocks
             req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table_cpu[
-                req_index, start_index:end_index] = req_data.new_block_ids
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
@@ -267,9 +269,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
 
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table[:num_reqs].copy_(
-            self.input_batch.block_table_cpu_tensor[:num_reqs],
-            non_blocking=True)
+        self.input_batch.block_table.apply_diff(num_reqs)
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
@@ -325,7 +325,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
-        block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
+        block_numbers = (self.input_batch.block_table.cpu().flatten()
                          [block_table_indices].numpy())
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -360,7 +360,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_start_loc=seq_start_loc,
-            block_table=self.input_batch.block_table[:num_reqs],
+            block_table=self.input_batch.block_table.cuda()[:num_reqs],
             slot_mapping=slot_mapping,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
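
Finally, a sanity check one could run to convince oneself that the diff-based path agrees with the blanket copy it replaces. This is not part of the patch; it assumes a CUDA device and an extension built with this change, and it marks every row as fully dirty so both paths must produce identical tables.

```python
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_cuda_view_from_cpu_tensor

num_reqs, max_blocks = 16, 32

src_cpu = torch.randint(0, 1000, (num_reqs, max_blocks),
                        dtype=torch.int32).pin_memory()
diff_cpu = torch.zeros(num_reqs, 2, dtype=torch.int32, pin_memory=True)
diff_cpu[:, 1] = max_blocks  # (start=0, length=max_blocks) for every row

# New path: copy only the (here, fully dirty) subranges via UVA.
tgt_diff = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
ops.copy_subranges(get_cuda_view_from_cpu_tensor(src_cpu),
                   get_cuda_view_from_cpu_tensor(diff_cpu),
                   tgt_diff, num_reqs)

# Old path: copy the whole block table.
tgt_full = src_cpu.to("cuda")

torch.cuda.synchronize()
assert torch.equal(tgt_diff, tgt_full)
```
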