[Kernel] fix types used in aqlm and ggml kernels to support dynamo (v…
bnellnm authored Aug 16, 2024
1 parent 7759ae9 commit 37fd47e
Showing 7 changed files with 39 additions and 53 deletions.
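The likely motivation, in brief: torch.compile / Dynamo traces registered custom ops through their schemas, where scalar integers are declared as int64_t and integer lists as int64_t[] (List[int] on the Python side). Declaring the GGUF quant type as int8_t, or shipping the AQLM partition sizes as a CPU tensor, does not fit that model, so the commit widens the scalar types and switches the metadata to plain int lists. A minimal sketch of the intended effect, assuming a built vllm._C extension; the arguments are placeholders, not working values:

```python
import torch
import vllm._custom_ops as ops


def dequant(W: torch.Tensor, quant_type: int, m: int, n: int) -> torch.Tensor:
    # With the op schema taking int64_t, these plain Python ints can be traced
    # by Dynamo instead of tripping over a narrower C++ scalar type.
    return ops.ggml_dequantize(W, quant_type, m, n)


# Real W / quant_type / m / n come from a GGUF checkpoint loaded by vLLM.
compiled_dequant = torch.compile(dequant, fullgraph=True)
```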
16 changes: 8 additions & 8 deletions csrc/ops.h
@@ -63,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias);
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes);
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes);
 
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
@@ -107,13 +107,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
 
-torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
+torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                               int64_t n);
 
-torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
-                                  int64_t row);
+torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
+                                  int64_t type, int64_t row);
 
-torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
+torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
25 changes: 13 additions & 12 deletions csrc/quantization/aqlm/gemm_kernels.cu
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
 }
 
 // Accumulate the partition sizes.
-int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
+int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes;
   auto cumulative_size = &cumulative_sizes.x;
-  int i = 0;
+  size_t i = 0;
   int last = 0;
-  assert(codebook_partition_sizes.size(0) <= 4);
-  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
-    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
+  assert(codebook_partition_sizes.size() <= 4);
+  for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
+    *cumulative_size = codebook_partition_sizes[i] + last;
     last = *cumulative_size;
   }
   // fill in the rest with unreachable.
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   if (nbooks == 1 && entries == (1 << 16)) {
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
   return {};
 }
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes) {
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
   auto in_features = codes.size(1) * 8;
   auto out_features = codes.size(0);
 
-  assert(out_features = codebook_partition_sizes.sum().item<int>());
+  assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
+                                         codebook_partition_sizes.end(), 0));
 
   auto weights = torch::empty({out_features, in_features},
                               torch::TensorOptions()
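A rough Python mirror of the accumulate_sizes() logic above, for illustration only (the sentinel value and the int4 packing used by the CUDA version are assumptions here):

```python
from typing import List


def accumulate_sizes(codebook_partition_sizes: List[int]) -> List[int]:
    """Running sums of the partition sizes, padded to four entries."""
    assert len(codebook_partition_sizes) <= 4
    sizes: List[int] = []
    last = 0
    for size in codebook_partition_sizes:
        last += size
        sizes.append(last)
    # The CUDA version fills the unused slots with an unreachable value so no
    # row index can land in an empty partition; a large stand-in is used here.
    while len(sizes) < 4:
        sizes.append(1 << 30)
    return sizes


print(accumulate_sizes([1024, 2048]))  # [1024, 3072, 1073741824, 1073741824]
```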
2 changes: 1 addition & 1 deletion csrc/quantization/gguf/dequantize.cuh
@@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
     switch (type) {
         case 2:
             return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
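The integer type argument selects the GGUF quantization format; the case labels in this switch follow the numeric ggml type ids. Only case 2 (Q4_0) is visible in this hunk, so the wider mapping below is an assumption based on the upstream ggml enum, shown purely for orientation:

```python
# Illustrative GGUF-format-name -> ggml type id mapping; only Q4_0 == 2 is
# confirmed by the switch above, the remaining ids follow the upstream enum.
GGML_TYPE = {
    "F32": 0, "F16": 1,
    "Q4_0": 2, "Q4_1": 3,
    "Q5_0": 6, "Q5_1": 7,
    "Q8_0": 8,
}

quant_type = GGML_TYPE["Q4_0"]  # passed to the kernels as a plain int (int64_t)
```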
8 changes: 4 additions & 4 deletions csrc/quantization/gguf/gguf_kernel.cu
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
-                              int8_t type, int64_t m, int64_t n) {
+                              int64_t type, int64_t m, int64_t n) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
   auto options =
       torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
 
 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                   torch::Tensor X,  // input
-                                  int8_t type, int64_t row) {
+                                  int64_t type, int64_t row) {
   int col = X.sizes()[1];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
 
 torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                               torch::Tensor X,  // input
-                              int8_t type, int64_t row) {
+                              int64_t type, int64_t row) {
   int col = X.sizes()[1];
   int padded = (col + 512 - 1) / 512 * 512;
   int batch = X.sizes()[0];
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
       break;
   }
   return Y;
-}
\ No newline at end of file
+}
28 changes: 7 additions & 21 deletions vllm/_custom_ops.py
@@ -17,13 +17,7 @@
     logger.warning("Failed to import from vllm._C with %r", e)
 
 with contextlib.suppress(ImportError):
-    # ruff: noqa: F401
-    import vllm._moe_C
-
-
-def is_custom_op_supported(op_name: str) -> bool:
-    op, overloads = torch._C._jit_get_operation(op_name)
-    return op is not None
+    import vllm._moe_C  # noqa: F401
 
 
 def hint_on_error(fn):
@@ -280,14 +274,14 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
-              codebook_partition_sizes: torch.Tensor,
+              codebook_partition_sizes: List[int],
               bias: Optional[torch.Tensor]) -> torch.Tensor:
     return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                   codebook_partition_sizes, bias)
 
 
 def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
-                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+                 codebook_partition_sizes: List[int]) -> torch.Tensor:
     return torch.ops._C.aqlm_dequant(codes, codebooks,
                                      codebook_partition_sizes)
 
@@ -434,25 +428,17 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # gguf
-def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
+def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
+                    n: int) -> torch.Tensor:
     return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
 
 
-def ggml_mul_mat_vec(
-    W: torch.Tensor,
-    X: torch.Tensor,
-    quant_type: int,
-    row: int,
-):
-    return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
-
-
 def ggml_mul_mat_vec_a8(
     W: torch.Tensor,
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
 
 
@@ -461,7 +447,7 @@ def ggml_mul_mat_a8(
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
 
 
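With these wrappers the partition sizes are ordinary Python ints, matching both the new int64_t[] schema and what Dynamo expects. A heavily hedged call sketch: the shapes, dtypes and zero-filled data are placeholders chosen only to satisfy the kernel's layout checks, not values from a real AQLM checkpoint, and it needs CUDA plus a built vllm._C:

```python
from typing import List

import torch
import vllm._custom_ops as ops

# Placeholder AQLM layout: one codebook with 2**16 entries, in_group_size 8.
codes = torch.zeros((1024, 128, 1), dtype=torch.int16, device="cuda")
codebooks = torch.zeros((1, 2**16, 1, 8), dtype=torch.half, device="cuda")
partition_sizes: List[int] = [1024]  # plain ints, no CPU tensor needed anymore

weights = ops.aqlm_dequant(codes, codebooks, partition_sizes)
print(weights.shape)  # expected: torch.Size([1024, 1024])
```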
12 changes: 5 additions & 7 deletions vllm/model_executor/layers/quantization/aqlm.py
@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     output_shape = input.shape[:-1] + (scales.shape[0], )
@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
@@ -288,10 +288,8 @@ def create_weights(self, layer: torch.nn.Module,
             codebooks,
             {
                 # metadata indicates fixed size concatenated along dim 0
-                "is_metadata":
-                True,
-                "output_partition_sizes":
-                torch.tensor(output_partition_sizes, device='cpu'),
+                "is_metadata": True,
+                "output_partition_sizes": output_partition_sizes
             },
         )
 
@@ -334,7 +332,7 @@ def apply(
         codes = layer.codes
         scales = layer.scales
         output_partition_sizes = getattr(codebooks, "output_partition_sizes",
-                                         None)
+                                         [])
 
         nbooks = codes.shape[2]
         ingroups = codebooks.shape[3]
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -212,6 +212,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
                 layer.g_idx.data = torch.empty((0, ),
+                                               dtype=torch.int,
                                                device=layer.g_idx.device)
             layer.exllama_state = ExllamaState.READY
             ops.gptq_shuffle(layer.qweight, layer.g_idx,
