
Commit

wip
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon committed Dec 22, 2024
1 parent c2d1b07 commit 1aaced5
Showing 11 changed files with 283 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -193,13 +193,15 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/copy_subranges.cu"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
8 changes: 8 additions & 0 deletions csrc/cuda_compat.h
@@ -47,3 +47,11 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

// #ifndef USE_ROCM
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
// #else
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// hipHostGetDevicePointer(device_ptr, host_ptr, flags)
// #endif
43 changes: 43 additions & 0 deletions csrc/cuda_view.cu
@@ -0,0 +1,43 @@
#include <torch/all.h>
#include <torch/cuda.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");

// Get raw host pointer from CPU tensor
void* host_ptr = cpu_tensor.data_ptr();

// Get a device pointer corresponding to the pinned host memory
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

// Construct a CUDA tensor from the device pointer.
// We'll use the same sizes, strides, and dtype as the CPU tensor.
auto sizes = cpu_tensor.sizes();
auto strides = cpu_tensor.strides();
auto options =
cpu_tensor.options().device(torch::kCUDA); // Change device to CUDA

// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto deleter = [](void*) {
// no-op, since the memory is owned by the original CPU tensor
};

torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, deleter, options);

TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");

return cuda_tensor;
}
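
(Not part of the commit: a minimal usage sketch of the new op, assuming the compiled _C extension, a CUDA device, and UVA support. It goes through the get_cuda_view_from_cpu_tensor helper added to vllm/utils.py below. Because the CUDA tensor is a zero-copy view of the pinned host buffer, a write made on the CPU side should be observable through the device-side view.)

import torch
from vllm.utils import get_cuda_view_from_cpu_tensor

# Pinned (page-locked) host memory is required for the UVA mapping.
cpu_t = torch.zeros(4, 8, dtype=torch.int32, pin_memory=True)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_t)

cpu_t[0, 0] = 42                        # write through the host tensor
assert cuda_view.device.type == "cuda"  # same memory, exposed as a CUDA tensor
assert cuda_view[0, 0].item() == 42     # read back through the device-side view
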
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -115,6 +115,11 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n);

torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
63 changes: 63 additions & 0 deletions csrc/prepare_inputs/copy_subranges.cu
@@ -0,0 +1,63 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm {
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
const int* __restrict__ matrix_diff,
int* __restrict__ matrix_tgt, int64_t M) {
int row_id = blockIdx.x;
int row_offset = row_id * M;

int start = matrix_diff[row_id * 2];
int length = matrix_diff[row_id * 2 + 1];
int end = start + length;
int thread_idx = threadIdx.x;
for (int i = start + thread_idx; i < end; i += blockDim.x) {
int idx = row_offset + i;
matrix_tgt[idx] = matrix_src[idx];
}
}
} // namespace vllm

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n) {
// Check tensor properties
TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

auto src_sizes = matrix_src.sizes();
auto diff_sizes = matrix_diff.sizes();
auto tgt_sizes = matrix_tgt.sizes();

TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

int64_t N = src_sizes[0];
int64_t M = src_sizes[1];

TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
"matrix_tgt must have same shape as matrix_src");

TORCH_CHECK(n <= N, "n must be <= N");

const int* d_matrix_src = matrix_src.data_ptr<int>();
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
int* d_matrix_tgt = matrix_tgt.data_ptr<int>();

// One thread block per row.
int blocks = n;
int threads = 1024;
const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}
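
(For reference, a plain-Python sketch of what the kernel above computes, not the CUDA path itself: each row i of matrix_diff holds a (start, length) pair, and one thread block copies only matrix_src[i, start:start+length] into matrix_tgt; rows at index n and beyond are left untouched.)

import torch

def copy_subranges_reference(src: torch.Tensor, diff: torch.Tensor,
                             tgt: torch.Tensor, n: int) -> None:
    # Mirrors vllm::copy_subranges_kernel: per-row subrange copy.
    for i in range(n):
        start = int(diff[i, 0])
        length = int(diff[i, 1])
        tgt[i, start:start + length] = src[i, start:start + length]
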
10 changes: 10 additions & 0 deletions csrc/torch_bindings.cpp
@@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);

// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
@@ -98,6 +102,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

ops.def(
"copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
"matrix_tgt, "
"int n) -> ()");
ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
11 changes: 11 additions & 0 deletions vllm/_custom_ops.py
@@ -249,6 +249,17 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
block_table_bound)


# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
def copy_subranges(
src_matrix: torch.Tensor,
diff_matrix: torch.Tensor,
tgt_matrix: torch.Tensor,
num_subranges: int,
) -> None:
torch.ops._C.copy_subranges(src_matrix, diff_matrix, tgt_matrix,
num_subranges)


# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,
7 changes: 7 additions & 0 deletions vllm/utils.py
@@ -1523,6 +1523,13 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")


def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)


def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule
116 changes: 116 additions & 0 deletions vllm/v1/worker/gpu_block_table.py
@@ -0,0 +1,116 @@
from typing import List

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.utils import get_cuda_view_from_cpu_tensor


class BlockTable:

def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device

self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()

# Pinned memory is required to use UVA.
# TODO(woosuk): Add other requirements for UVA.
self.use_uva = pin_memory
if self.use_uva:
self.block_table_diff = torch.zeros((max_num_reqs, 2),
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.block_table_diff_np = self.block_table_diff.numpy()

self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor(
self.block_table_cpu)
self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor(
self.block_table_diff)

def add_row(self, row_idx: int, block_ids: List[int]) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, :num_blocks] = block_ids
if self.use_uva:
self.block_table_diff_np[row_idx, 0] = 0
self.block_table_diff_np[row_idx, 1] = num_blocks

def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
if self.use_uva:
self.block_table_diff_np[row_idx, 0] = start
# Move-and-append is not allowed.
assert self.block_table_diff_np[row_idx, 1] == 0
self.block_table_diff_np[row_idx, 1] = num_blocks

def move_row(self, src: int, tgt: int) -> None:
self.block_table_np[tgt] = self.block_table_np[src]
if self.use_uva:
# Append-and-move is allowed.
self.block_table_diff_np[tgt] = self.block_table_diff_np[src]
# Clear the source row.
self.block_table_diff_np[src].fill(0)

def apply_diff(self, num_reqs: int) -> None:
if self.use_uva:
# Only copy the diff to the GPU.
ops.copy_subranges(
self.block_table_cpu_cuda_view,
self.block_table_diff_cuda_view,
self.block_table,
num_reqs,
)
else:
# Copy the entire block table to the GPU.
# NOTE(woosuk): This can be a performance bottleneck when the block
# table is large.
self.block_table[:num_reqs].copy_(
self.block_table_cpu[:num_reqs], non_blocking=True)

def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
if self.use_uva:
self.block_table_diff.fill_(0)

def clear_diff(self) -> None:
if self.use_uva:
self.block_table_diff_np.fill(0)

def cuda(self) -> torch.Tensor:
return self.block_table

def cpu(self) -> torch.Tensor:
return self.block_table_cpu

def numpy(self) -> np.ndarray:
return self.block_table_np
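
(A hypothetical driver snippet, not from the commit, showing how BlockTable is meant to be used across two scheduling steps; it assumes pinned memory, a CUDA device, and the compiled _C extension so the UVA fast path is taken. Note that append_row asserts the row's diff length entry is zero, so apply_diff and clear_diff are expected to run once per step before a row grows again.)

import torch
from vllm.v1.worker.gpu_block_table import BlockTable

block_table = BlockTable(max_num_reqs=256, max_model_len=4096,
                         max_num_blocks_per_req=256,
                         pin_memory=True, device=torch.device("cuda"))

# Step 1: a new request arrives with three KV-cache blocks.
block_table.add_row(row_idx=0, block_ids=[3, 7, 9])
block_table.apply_diff(num_reqs=1)  # copy only the changed subranges to the GPU
block_table.clear_diff()

# Step 2: the same request grows by one block.
block_table.append_row(row_idx=0, start=3, block_ids=[12])
block_table.apply_diff(num_reqs=1)
block_table.clear_diff()
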
26 changes: 10 additions & 16 deletions vllm/v1/worker/gpu_input_batch.py
@@ -9,6 +9,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_block_table import BlockTable

if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
@@ -64,19 +65,14 @@ def __init__(
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

# Attention-related.
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
# Block table.
self.block_table = BlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_blocks_per_req=max_num_blocks_per_req,
pin_memory=pin_memory,
device=device,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()

# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
@@ -141,8 +137,7 @@ def add_request(
start_idx:end_idx] = request.output_token_ids

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
self.block_table.add_row(req_index, request.block_ids)

sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
@@ -221,13 +216,12 @@ def condense(self, empty_req_indices: List[int]) -> None:
self.req_id_to_index[req_id] = empty_index

# TODO(woosuk): Optimize the copy of token_ids_cpu and
# block_table_cpu.
# block_table.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
